
The relationship between output, h_n, and c_n in PyTorch's LSTM

Tags: PyTorch basics

LSTM overview

  • Official documentation: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
  • h_n: the hidden state at the last time step. For a single-layer, unidirectional LSTM with batch_first=True, this is exactly h_n = output[:, -1, :]. It is usually what you feed into a subsequent fully-connected layer; in Keras the equivalent is setting return_sequences=False. A minimal sketch of that pattern follows this list.
  • c_n: the LSTM cell state at the last time step (rarely needed directly).
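The sketch below shows the typical "last hidden state into a linear head" pattern mentioned above. It is only an illustration, not code from this article: the class name, num_classes, and the other sizes are hypothetical, and it assumes a single-layer, unidirectional LSTM with batch_first=True.

>>> import torch
>>> import torch.nn as nn
>>> class LSTMClassifier(nn.Module):
...     def __init__(self, input_size=2, hidden_size=3, num_classes=4):
...         super().__init__()
...         self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
...         self.fc = nn.Linear(hidden_size, num_classes)
...     def forward(self, x):
...         output, (h_n, c_n) = self.lstm(x)   # output: (batch, seq_len, hidden_size)
...         last = h_n[-1]                      # (batch, hidden_size), same as output[:, -1, :] here
...         return self.fc(last)                # logits: (batch, num_classes)
...
>>> model = LSTMClassifier()
>>> logits = model(torch.randn(5, 4, 2))        # batch=5, seq_len=4, input_size=2
>>> logits.shape
torch.Size([5, 4])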

Example

  • Example: h_n is the output at the last time step, i.e. h_n = output[:, -1, :]. If that is still not intuitive, run the code below and compare the printed tensors; the relationship between output, h_n, and c_n is easy to see.

  • Example code:

>>> import torch
>>> import torch.nn as nn
>>> rnn = nn.LSTM(input_size=2, hidden_size=3, batch_first=True)
>>> input = torch.randn(5,4,2)
>>> h0 = torch.randn(1, 5, 3)
>>> c0 = torch.randn(1, 5, 3)
>>> output, (hn, cn) = rnn(input, (h0, c0))
>>> output
tensor([[[-0.1046, -0.0316, -0.2261],
         [ 0.0702,  0.0756, -0.2856],
         [ 0.1146,  0.0666, -0.1841],
         [ 0.1137,  0.0508, -0.3966]],

        [[ 0.3702, -0.1192, -0.3513],
         [ 0.3964, -0.0513, -0.1744],
         [ 0.3144,  0.0564, -0.2114],
         [ 0.3056,  0.1312, -0.1656]],

        [[ 0.1581, -0.3509,  0.0068],
         [ 0.2391, -0.0308,  0.0773],
         [ 0.2420,  0.0607, -0.0652],
         [ 0.2854,  0.0656, -0.0306]],

        [[-0.0562, -0.0229,  0.1600],
         [-0.2156, -0.0006,  0.0898],
         [ 0.0700,  0.2200, -0.0068],
         [ 0.1903,  0.3120,  0.0253]],

        [[ 0.1025, -0.0167,  0.3068],
         [ 0.2028,  0.0652,  0.1738],
         [ 0.3324,  0.1645,  0.1908],
         [ 0.2594,  0.0896, -0.0507]]], grad_fn=<TransposeBackward0>)
>>> hn
tensor([[[ 0.1137,  0.0508, -0.3966],
         [ 0.3056,  0.1312, -0.1656],
         [ 0.2854,  0.0656, -0.0306],
         [ 0.1903,  0.3120,  0.0253],
         [ 0.2594,  0.0896, -0.0507]]], grad_fn=<StackBackward>)
>>> cn
tensor([[[ 0.3811,  0.2079, -0.7427],
         [ 0.9059,  0.2375, -0.3272],
         [ 0.5819,  0.1175, -0.0766],
         [ 0.5059,  0.5022,  0.0446],
         [ 0.7312,  0.2270, -0.0970]]], grad_fn=<StackBackward>)
>>> output[-1]
tensor([[ 0.1025, -0.0167,  0.3068],
        [ 0.2028,  0.0652,  0.1738],
        [ 0.3324,  0.1645,  0.1908],
        [ 0.2594,  0.0896, -0.0507]], grad_fn=<SelectBackward>)
>>> output[:,:,-1]
tensor([[-0.2261, -0.2856, -0.1841, -0.3966],
        [-0.3513, -0.1744, -0.2114, -0.1656],
        [ 0.0068,  0.0773, -0.0652, -0.0306],
        [ 0.1600,  0.0898, -0.0068,  0.0253],
        [ 0.3068,  0.1738,  0.1908, -0.0507]], grad_fn=<SelectBackward>)
>>> output[:,-1,:]
tensor([[ 0.1137,  0.0508, -0.3966],
        [ 0.3056,  0.1312, -0.1656],
        [ 0.2854,  0.0656, -0.0306],
        [ 0.1903,  0.3120,  0.0253],
        [ 0.2594,  0.0896, -0.0507]], grad_fn=<SliceBackward>)
>>> output[:,-1,:].shape
torch.Size([5, 3])
>>> output.shape
torch.Size([5, 4, 3])
>>> hn.shape
torch.Size([1, 5, 3])
>>> cn.shape
torch.Size([1, 5, 3])
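To check the relationship programmatically rather than by eye, a small follow-up sketch (same single-layer, unidirectional, batch_first=True setup as above) can compare h_n[0] with output[:, -1, :]. Note that for multi-layer or bidirectional LSTMs this equality no longer holds as written, because h_n then stacks the final states of every layer and direction while output only contains the top layer.

>>> import torch
>>> import torch.nn as nn
>>> torch.manual_seed(0)
>>> rnn = nn.LSTM(input_size=2, hidden_size=3, batch_first=True)
>>> x = torch.randn(5, 4, 2)
>>> output, (h_n, c_n) = rnn(x)
>>> # The last hidden state is the output at the final time step.
>>> torch.allclose(h_n[0], output[:, -1, :])
True
>>> output.shape, h_n.shape, c_n.shape
(torch.Size([5, 4, 3]), torch.Size([1, 5, 3]), torch.Size([1, 5, 3]))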