22. Deep Recurrent Neural Networks#
22.1. Concise Implementation#
Fortunately, many of the logistical details required to implement multiple layers of an RNN are readily available in high-level APIs. Our concise implementation will use this built-in functionality.
using Downloads,IterTools,CUDA,Flux
using StatsBase: wsample
# Pick the compute backend; `verbose=true` logs which one was selected.
device = Flux.get_device(; verbose=true)
# Fetch the corpus (timemachine.txt) and load it into memory as one string.
file_path = Downloads.download("http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt")
raw_text = read(file_path, String)
# Replace every run of non-letter characters with a single space, then lowercase,
# so the text consists only of lowercase letters and spaces.
str = lowercase(replace(raw_text, r"[^A-Za-z]+" => " "))
# `collect` yields the same Vector{Char} as `[str...]` without splatting the whole
# corpus into a single varargs call.
tokens = collect(str)
# Distinct characters in order of first appearance; they serve as one-hot labels.
vocab = unique(tokens)
vocab_len = length(vocab)
"""
    getdata(str, vocab, seq_length, batch_size) -> (x, y)

Build one-hot encoded input/target batches for next-character prediction.

Sliding windows of `seq_length` characters (step 1) are taken from `str`. The inputs
are all windows except the last and the targets are all windows except the first, so
every target window is its input window shifted one character ahead.

# Arguments
- `str`: text to cut into windows.
- `vocab`: characters used as one-hot labels.
- `seq_length`: number of characters (time steps) per window.
- `batch_size`: number of windows grouped into one batch.

# Returns
A tuple `(x, y)` with layout `n * [seq_length x feature x batch_size]`: each of `x` and
`y` is a vector with one entry per batch, and each entry is a vector holding one
one-hot matrix (`feature x batch`) per time step.
"""
function getdata(
    str::AbstractString, vocab::AbstractVector{Char}, seq_length::Int, batch_size::Int
)::Tuple
    windows = collect.(partition(str, seq_length, 1))
    # One shared encoder so inputs and targets are always batched the same way.
    function encode(seqs)
        batches = Flux.batchseq.(Flux.chunk(seqs; size=batch_size))
        return [[Flux.onehotbatch(step, vocab) for step in batch] for batch in batches]
    end
    return encode(windows[begin:end-1]), encode(windows[2:end])
end
"""
    loss(model, xs, ys)

Return the logit cross-entropy between the model outputs and the targets, summed over
the steps of one sequence batch.

`xs` holds the per-step inputs and `ys` the matching per-step targets. The model state
is reset before the sequence is processed.
"""
function loss(model, xs, ys)
    # Every sequence starts from a fresh hidden state.
    Flux.reset!(model)
    # The comprehension feeds the steps to the stateful model one at a time, in order.
    logits = [model(step) for step in xs]
    stepwise = Flux.logitcrossentropy.(logits, ys)
    return sum(stepwise)
end
"""
    predict(model::Chain, prefix::String, num_preds::Int)

Generate text by sampling `num_preds` characters that continue `prefix`.

A CPU copy of `model` is reset and fed the characters of `prefix` one at a time to
build up its state. Each following character is then drawn from the global `vocab`,
weighted by the softmax of the model output, and fed back in as the next input.

# Arguments
- `model`: the network to sample from; it is copied to the CPU, so the argument itself
  is not moved.
- `prefix`: non-empty seed text; its characters are looked up in the global `vocab`.
- `num_preds`: number of characters to append after `prefix`.

# Returns
A `String` consisting of `prefix` followed by the sampled characters.

# Throws
- `ArgumentError` if `prefix` is empty, because there would be no model output to
  sample the first character from.
"""
function predict(model::Chain, prefix::String, num_preds::Int)
    isempty(prefix) &&
        throw(ArgumentError("prefix must contain at least one character"))
    model = cpu(model)
    Flux.reset!(model)
    buf = IOBuffer()
    write(buf, prefix)
    # Warm up on the prefix; only the output of its last character is needed.
    logits = model(Flux.onehot(first(prefix), vocab))
    for ch in Iterators.drop(prefix, 1)
        logits = model(Flux.onehot(ch, vocab))
    end
    c = wsample(vocab, softmax(logits))
    for _ in 1:num_preds
        write(buf, c)
        # Feed the sampled character back in to obtain the next distribution.
        c = wsample(vocab, softmax(model(Flux.onehot(c, vocab))))
    end
    return String(take!(buf))
end
┌ Info: Using backend: CUDA.
└ @ Flux /home/nero/.julia/packages/Flux/Wz6D4/src/functor.jl:662
predict (generic function with 1 method)
# Two stacked GRU layers with 32 hidden units each, followed by a dense layer that maps
# back onto the vocabulary; the whole chain is moved to the selected device.
model = device(
    Chain(
        GRUv3(vocab_len => 32),
        GRUv3(32 => 32),
        Dense(32 => vocab_len),
    ),
)
Chain(
Recur(
GRUv3Cell(27 => 32), # 5_792 parameters
),
Recur(
GRUv3Cell(32 => 32), # 6_272 parameters
),
Dense(32 => 27), # 891 parameters
) # Total: 12 trainable arrays, 12_955 parameters,
# plus 2 non-trainable, 64 parameters, summarysize 1.938 KiB.
# Optimiser state for Adam with learning rate 1e-2.
opt_state = Flux.setup(Adam(1e-2), model)
# Windows of 32 characters in batches of 1024, moved to the selected device.
x, y = getdata(str, vocab, 32, 1024) |> device
data = zip(x, y)
# Typed container for the per-epoch loss history (instead of an untyped `[]`).
loss_train = Float64[]
for epoch in 1:50
    @info "$epoch start"
    Flux.reset!(model)
    Flux.train!(loss, model, data, opt_state)
    # Total loss over all batches after this epoch, scaled by the corpus length.
    push!(loss_train, sum(loss(model, xb, yb) for (xb, yb) in data) / length(str))
end