The Relationship Between output, h_n, and c_n in PyTorch's LSTM
Posted by 阿新 on 2020-12-24
Tags: PyTorch basics
LSTM Overview
- Official documentation: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
- h_n: the hidden state at the last time step. For a single-layer, unidirectional LSTM with batch_first=True, this means h_n = output[:, -1, :]. It can generally be fed directly into a subsequent fully-connected layer (in Keras, the equivalent is obtained by setting return_sequences=False).
- c_n: the cell state of the LSTM at the last time step (rarely needed directly).
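To make the "feed h_n into a fully-connected layer" point concrete, here is a minimal sketch; the layer sizes and the 4-class output are hypothetical choices for illustration, not from the original post:

```python
import torch
import torch.nn as nn

# Single-layer, unidirectional LSTM (hypothetical sizes for illustration)
rnn = nn.LSTM(input_size=2, hidden_size=3, batch_first=True)
fc = nn.Linear(3, 4)  # hypothetical classifier head with 4 classes

x = torch.randn(5, 4, 2)        # (batch, seq_len, input_size)
output, (h_n, c_n) = rnn(x)     # h_n: (num_layers * num_directions, batch, hidden_size)

# h_n[-1] is the final hidden state of the last layer, shape (batch, hidden_size);
# here it equals output[:, -1, :], and is what you typically classify on.
logits = fc(h_n[-1])
print(logits.shape)             # torch.Size([5, 4])
```

Note that h_n keeps a leading layer/direction dimension, so it is indexed with h_n[-1] before going into the linear layer.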
Example

Example: the red boxes in the original figure show directly that h_n is the output at the last time step, i.e. h_n = output[:, -1, :]. If that is still not intuitive, compare the values in the transcript below; the relationship is easy to read off.

Example code:
```python
>>> import torch
>>> import torch.nn as nn
>>> rnn = nn.LSTM(input_size=2, hidden_size=3, batch_first=True)
>>> input = torch.randn(5, 4, 2)
>>> h0 = torch.randn(1, 5, 3)
>>> c0 = torch.randn(1, 5, 3)
>>> output, (hn, cn) = rnn(input, (h0, c0))
>>> output
tensor([[[-0.1046, -0.0316, -0.2261],
         [ 0.0702,  0.0756, -0.2856],
         [ 0.1146,  0.0666, -0.1841],
         [ 0.1137,  0.0508, -0.3966]],

        [[ 0.3702, -0.1192, -0.3513],
         [ 0.3964, -0.0513, -0.1744],
         [ 0.3144,  0.0564, -0.2114],
         [ 0.3056,  0.1312, -0.1656]],

        [[ 0.1581, -0.3509,  0.0068],
         [ 0.2391, -0.0308,  0.0773],
         [ 0.2420,  0.0607, -0.0652],
         [ 0.2854,  0.0656, -0.0306]],

        [[-0.0562, -0.0229,  0.1600],
         [-0.2156, -0.0006,  0.0898],
         [ 0.0700,  0.2200, -0.0068],
         [ 0.1903,  0.3120,  0.0253]],

        [[ 0.1025, -0.0167,  0.3068],
         [ 0.2028,  0.0652,  0.1738],
         [ 0.3324,  0.1645,  0.1908],
         [ 0.2594,  0.0896, -0.0507]]], grad_fn=<TransposeBackward0>)
>>> hn
tensor([[[ 0.1137,  0.0508, -0.3966],
         [ 0.3056,  0.1312, -0.1656],
         [ 0.2854,  0.0656, -0.0306],
         [ 0.1903,  0.3120,  0.0253],
         [ 0.2594,  0.0896, -0.0507]]], grad_fn=<StackBackward>)
>>> cn
tensor([[[ 0.3811,  0.2079, -0.7427],
         [ 0.9059,  0.2375, -0.3272],
         [ 0.5819,  0.1175, -0.0766],
         [ 0.5059,  0.5022,  0.0446],
         [ 0.7312,  0.2270, -0.0970]]], grad_fn=<StackBackward>)
>>> output[-1]
tensor([[ 0.1025, -0.0167,  0.3068],
        [ 0.2028,  0.0652,  0.1738],
        [ 0.3324,  0.1645,  0.1908],
        [ 0.2594,  0.0896, -0.0507]], grad_fn=<SelectBackward>)
>>> output[:,:,-1]
tensor([[-0.2261, -0.2856, -0.1841, -0.3966],
        [-0.3513, -0.1744, -0.2114, -0.1656],
        [ 0.0068,  0.0773, -0.0652, -0.0306],
        [ 0.1600,  0.0898, -0.0068,  0.0253],
        [ 0.3068,  0.1738,  0.1908, -0.0507]], grad_fn=<SelectBackward>)
>>> output[:,-1,:]
tensor([[ 0.1137,  0.0508, -0.3966],
        [ 0.3056,  0.1312, -0.1656],
        [ 0.2854,  0.0656, -0.0306],
        [ 0.1903,  0.3120,  0.0253],
        [ 0.2594,  0.0896, -0.0507]], grad_fn=<SliceBackward>)
>>> output[:,-1,:].shape
torch.Size([5, 3])
>>> output.shape
torch.Size([5, 4, 3])
>>> hn.shape
torch.Size([1, 5, 3])
>>> cn.shape
torch.Size([1, 5, 3])
```
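One caveat the transcript above does not cover: the identity hn == output[:, -1, :] only reads off that cleanly for a single-layer, unidirectional LSTM. With num_layers > 1, h_n stacks the final hidden state of every layer, and only its last entry, h_n[-1], matches output[:, -1, :]. A small sketch (the two-layer configuration is my own illustration, not from the original post):

```python
import torch
import torch.nn as nn

# Two-layer unidirectional LSTM: h_n now has one slice per layer.
rnn = nn.LSTM(input_size=2, hidden_size=3, num_layers=2, batch_first=True)
x = torch.randn(5, 4, 2)
output, (h_n, c_n) = rnn(x)

print(h_n.shape)                                  # torch.Size([2, 5, 3])
# output holds only the top layer's hidden states, so its last time step
# equals h_n[-1], the top layer's final hidden state.
print(torch.allclose(h_n[-1], output[:, -1, :]))  # True
```

The same indexing caution applies to bidirectional=True, where h_n additionally interleaves the forward and backward directions.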