1. 程式人生 > >pytorch筆記:09)Attention機制

pytorch筆記:09)Attention機制

剛從影象處理的hole中攀爬出來,剛走一步竟掉到了另一個hole(fire in the hole**)

1.RNN中的attention
pytorch官方教程:https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
首先,RNN的輸入大小都是(1,1,hidden_size),即batch=1,seq_len=1,hidden_size=embed_size,相對於傳統的encoder-decoder模型,attention機制僅在decoder處有所不同。下面具體看看:
1>儲存了rnn每個詞向量對應隱藏層的輸出狀態(encoder_outputs

),用於decoder的attention機制

#train程式碼部分
for ei in range(input_length):
	encoder_output, encoder_hidden = encoder(
		input_tensor[ei], encoder_hidden)
	encoder_outputs[ei] = encoder_output[0, 0]

2>AttnDecoderRNN的forward
1.輸入的input經過embed

embedded = self.embedding(input).view(1, 1, -1)
embedded = self.
dropout(embedded)

2.獲取關於輸入的attention權重,這裡的Q=decoder_rnn的input,K=decoder_rnn的隱藏元
2.1求Q和K相似度的方法有很多,這裡讓全連線層自己來學習,把embedded和hidden連線在一起經過fc層(部分修改了下)

similarity=self.attn(torch.cat((embedded[0], hidden[0]), 1))

2.2 經過softmax獲得歸一化的權重

attn_weights = F.softmax(similarity, dim=1)

3.權重應用於encoder輸出的所有詞對應的詞向量上(對應相乘即可)->獲得attention結果

attn_applied = torch.bmm(attn_weights.unsqueeze(0),encoder_outputs.unsqueeze(0))

4.把attention結果和decoder的輸入cat在一起,使用1個全連線層來融合二者,最終生成帶注意力機制的詞向量

output = torch.cat((embedded[0], attn_applied[0]), 1)
output = self.attn_combine(output).unsqueeze(0)

5.根據decoder的上一個輸出單詞來預測下一個單詞,這裡多插一句,decoder的首個輸入為起始標誌符’sos’,其根據encode最後的隱藏元來預測第一個單詞,後面依次類推。

output = F.relu(output)
output, hidden = self.gru(output, hidden)
output = F.log_softmax(self.out(output[0]), dim=1)
return output, hidden, attn_weights

2.transformer中的attention
“Attention is All You Need”(霸氣標題),pytorch程式碼推薦2篇:
哈佛大學NLP研究組:http://nlp.seas.harvard.edu/2018/04/03/attention.html
臺灣小哥的程式碼(較通俗):https://github.com/jadore801120/attention-is-all-you-need-pytorch:

下面以soft_attention為例(*input和output的attention,僅和self_attention做下區分,第1篇程式碼標記src_attn,第2篇程式碼標記dec_enc_attn),soft_attention的目標:給定序列Q(query,長度記為lq,維度dk),鍵序列K(key,長度記為lk,維度dk),值序列V(value,長度記為lv,維度dv),計算Q和K的相似度權重,最後再乘上V。

下面直接貼上attention-is-all-you-need-pytorch中MultiHeadAttention程式碼

    def forward(self, q, k, v, mask=None):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, _ = q.size()
        sz_b, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()
        residual = q
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
        #這裡把batch和分塊數放在一起,便於使用bmm
        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv

        mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
        output, attn = self.attention(q, k, v, mask=mask)

        output = output.view(n_head, sz_b, len_q, d_v)
        output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)

        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)
        return output, attn

和RNN中的attention的不同,這裡的batch_size和seq_len均不為1,其把序列視為一個整體,求Q和V的相似度可使用點乘(V可以視為上面提及的encoder_outputs),獲得的是一個相似度矩陣,比如Q是一個長度為10的序列,K是一個長度為16的序列,其相似度矩陣就是一個10*16的矩陣,再如矩陣第一行表示Q的第一個單詞和K序列所有單詞的相似度。
s i m i l a r i t y : = ( l q , d k ) ( d k , l k ) = ( l q , l k ) similarity:=(lq,dk)*(dk,lk)=(lq,lk)

然後,生成帶注意力機制的詞向量(通常K和V取相同的值,因而有lv=lk),另外上面整合attn_applied和input使用的是cat操作,而這裡使用的是殘差(類似於unet和resnet),最後使用PositionwiseFeedForward(2個fc層)來融合attn_applied和input,最終生成帶注意力機制的詞向量。
a t t e n t i o n _ a p p l i e d = ( l q , l k ) ( l v , d v ) = ( l q , d v ) attention\_applied=(lq,lk)*(lv,dv)=(lq,dv)


細節部分
在資料預處理部分,對序列s都進行了首尾標記,比如s=’’+ s + ‘’,剛看transform(之前跳過了seq2seq),對下面的程式碼甚是不解

decoder_input=target_seq[:, :-1] #這裡不是去掉終止標記<eos>,去掉的可能是padding_0,只為相容target_ground_y的序列長度?
encoder_input=input_seq[:, 1:]  #encoder的輸入序列去掉了起始標記<sos>
target_ground_y= target_seqtrg[:, 1:] #用於計算模型loss的target,去掉了起始標記<sos>

其實在pytorch官方教程中說的比較清楚,看下圖
seq2seq
encoder的輸入序列和ground_true只需要一個終止符即可,而decoder的輸入序列開始必須指定一個起始符,讓其根據context預測輸出序列的第一個單詞,後面根據前一個單詞再預測下一個單詞,依次類推直到當前預測的單詞為終止標記’eos’,才計算loss.