22. Deep Recurrent Neural Networks

22.1. Concise Implementation

Fortunately, many of the logistical details required to implement multiple layers of an RNN are readily available in high-level APIs. Our concise implementation uses this built-in functionality.

using Downloads, IterTools, CUDA, Flux
using StatsBase: wsample

device = Flux.get_device(; verbose=true)

# Download "The Time Machine", keep only letters and spaces, and build a character-level vocabulary.
file_path = Downloads.download("http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt")
raw_text = open(io -> read(io, String), file_path)
str = lowercase(replace(raw_text, r"[^A-Za-z]+" => " "))
tokens = [str...]            # split into individual characters
vocab = unique(tokens)
vocab_len = length(vocab)
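After preprocessing, only the 26 lowercase letters and the space character remain, so the vocabulary should contain 27 symbols, matching the 27-dimensional input of the model defined below. As a quick sanity check, each character maps to a one-hot vector over this vocabulary; a minimal sketch:

vocab_len                                   # expected: 27 (26 letters + space)
Flux.onehot('t', vocab)                     # a 27-element one-hot vector
Flux.onehotbatch(['t', 'h', 'e'], vocab)    # a 27×3 one-hot matrix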

# n mini-batches, each [seq_length][vocab_len x batch_size]
function getdata(str::String, vocab::Vector{Char}, seq_length::Int, batch_size::Int)::Tuple
    # Overlapping windows of seq_length characters, stepping one character at a time.
    data = collect.(partition(str, seq_length, 1))
    # Inputs use windows 1:end-1; targets use the same windows shifted one character ahead.
    x = [[Flux.onehotbatch(i, vocab) for i in d] for d in Flux.batchseq.(Flux.chunk(data[begin:end-1]; size = batch_size))]
    y = [[Flux.onehotbatch(i, vocab) for i in d] for d in Flux.batchseq.(Flux.chunk(data[2:end]; size = batch_size))]
    return x, y
end
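To make the nested layout explicit: getdata returns a vector of mini-batches, each mini-batch is a vector with one entry per time step, and each time step is a one-hot matrix of size vocab_len × batch_size (the last mini-batch may be smaller). A hedged shape check, assuming the hyperparameters used later (sequence length 32, batch size 1024):

xs, ys = getdata(str, vocab, 32, 1024)
length(xs)        # number of mini-batches
length(xs[1])     # 32: one one-hot matrix per time step
size(xs[1][1])    # (27, 1024): vocab_len × batch_size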

function loss(model, xs, ys)
    Flux.reset!(model)   # start from a fresh hidden state for each sequence
    # Sum the per-time-step cross-entropy over the whole sequence.
    return sum(Flux.logitcrossentropy.([model(x) for x in xs], ys))
end
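Flux.logitcrossentropy averages over the batch dimension by default, so the broadcast above produces one scalar per time step; summing them gives the sequence loss. A tiny hedged check with made-up logits:

ŷ = randn(Float32, vocab_len, 4)                    # hypothetical logits: vocab_len × batch
y = Flux.onehotbatch(['t', 'i', 'm', 'e'], vocab)   # hypothetical targets
Flux.logitcrossentropy(ŷ, y)                        # a single scalar (mean over the batch)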

function predict(model::Chain, prefix::String, num_preds::Int)
    model = cpu(model)   # sample on the CPU
    Flux.reset!(model)
    buf = IOBuffer()
    write(buf, prefix)

    # Warm up the hidden state on the prefix, then sample characters one at a time.
    c = wsample(vocab, softmax([model(Flux.onehot(c, vocab)) for c in collect(prefix)][end]))
    for i in 1:num_preds
        write(buf, c)
        c = wsample(vocab, softmax(model(Flux.onehot(c, vocab))))
    end
    return String(take!(buf))
end
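The sampling step combines softmax and wsample: the output logits are turned into a probability distribution over the vocabulary, and the next character is drawn from it. A self-contained illustration with made-up numbers:

logits = [2.0, 0.5, -1.0]          # hypothetical scores for three symbols
probs = softmax(logits)            # non-negative, sums to 1
wsample(['a', 'b', 'c'], probs)    # draws 'a' most often, 'c' rarely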
# Two stacked GRU layers form the deep RNN, followed by a dense output layer.
model = Chain(GRUv3(vocab_len => 32), GRUv3(32 => 32), Dense(32 => vocab_len)) |> device
Chain(
  Recur(
    GRUv3Cell(27 => 32),                # 5_792 parameters
  ),
  Recur(
    GRUv3Cell(32 => 32),                # 6_272 parameters
  ),
  Dense(32 => 27),                      # 891 parameters
)         # Total: 12 trainable arrays, 12_955 parameters,
          # plus 2 non-trainable, 64 parameters, summarysize 1.938 KiB.
opt_state = Flux.setup(Adam(1e-2), model)
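Gradient clipping is often combined with Adam when training RNNs. If desired, the rule above could be wrapped in an OptimiserChain; this is an optional sketch, not part of the setup used below:

# Hypothetical variant: clip the gradient norm to 1 before each Adam step.
clipped_rule = OptimiserChain(ClipNorm(1.0), Adam(1e-2))
# opt_state = Flux.setup(clipped_rule, model)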

x, y = getdata(str, vocab, 32, 1024) |> device
data = zip(x, y)
loss_train = []

for epoch in 1:50
    @info "Epoch $epoch"
    Flux.reset!(model)
    Flux.train!(loss, model, data, opt_state)
    # Record the (approximately per-character) training loss after each epoch.
    push!(loss_train, sum(loss.(Ref(model), x, y)) / length(str))
end
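Once training finishes, the recorded losses can be inspected and the model asked to continue a prompt. A hedged usage sketch; Plots.jl is an assumption (it is not loaded above), and the generated text varies from run to run:

using Plots
plot(loss_train; xlabel = "epoch", ylabel = "loss", legend = false)

println(predict(model, "time traveller ", 50))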