Pytorch中nn.RNN()基本用法和輸入輸出
阿新 • • 發佈:2020-12-20
以下均為單向RNN。
0. RNN模型結構
網上教程的標準RNN結構如下圖,其實是有輸入層x、隱藏層h和輸出層y三層結構的。
但是在Pytorch中定義的RNN,其實是沒有y這個輸出層的。例如下圖中,Pytorch版本的兩個輸出,output=[h1, h2, h3, h4], hn = h4。如果想要得到輸出層y,可以自行加一個全連線層。
1. 初始化RNN
rnn = nn.RNN(input_size, hidden_size, num_layers)
2. RNN的輸入
- input:(seq_len, batch_size, input_size)
- h0:(num_layers, batch_size, hidden_size)
注:
a. h0如果沒有被提供,則預設設定為全0
b. 實際上h0維度是(num_layers*num_directions, batch_size, hidden_size),如果RNN是單向的,則num_directions=1;如果RNN是雙向的,則num_directions=2,此處取單向RNN。
3. RNN輸出
- output:(seq_len, batch_size, hidden_size)
為每個時間步得到的hidden_state - hn:(num_layers, batch_size, hidden_size)
注 :
a. 實際output維度是(seq_len, batch_size, num_directions * hidden_size)
b. 實際hn維度是(num_layers * num_directions, batch_size, hidden_size)
4. 例項
#########定義模型和輸入#########
# (input_size, hidden_size, num_layers)
rnn = nn.RNN(10, 20, 1)
# (seq_len, batch_size, input_size)
input = torch. randn(5, 3, 10)
# (num_layers, batch_size, hidden_size)
h0 = torch.randn(1, 3, 20)
#########將輸入喂入模型#########
output, hn = rnn(input, h0)
#########檢視模型引數#########
rnn._parameters