跑通

代码使用的是 https://github.com/jadore801120/attention-is-all-you-need-pytorch,

commit-id 为: 132907d

各模块粗览

Transformer

主要包括一堆参数,
以及encoder和decoder

forward的时候主要做了如下操作.

  1. 先 pad_mask
  2. 过encoder
  3. 过decoder
  4. 输出logit

从train.py 我们可以看出, 模型的输出直接去做了loss


这里的Loss就是cross_entropy.

然后每个encoder其实是一堆EncoderLayer的,每个decoder其实也是一堆DecoderLayer的, 所以先大致看一下.

Encoder


整体流程如下

  1. 输入 原始的src_seq, 得到word embeding, 叫做 enc_output
  2. 做position_encoder, 即位置编码.
  3. 做LayerNorm
  4. 过各个堆叠的encoder-layer, 每一个encoder-layer的输入都是上一层的输出.
  5. 返回最后一个encoder-layer的输出

什么是PositionEncoding

这个模块其实没有可以学习的参数. 这里的这个buffer的用法可以学习一下.


这里的这个实现还是挺简洁的. 一行就解决了.

Decocer

Decoder的结构和encoder的结构几乎一样,

但是要注意的是, Decnoder的输入.


也就是说, Decoder中的positionEncoding是对groundtruth做的.

这里我有一个疑问, 推理的时候, 没有trg_seq 的时候具体是怎么做的呢” />推理时细节


  1. 输入一个句子, 会先用encoder得到encoder_output.
  2. 同时会有 init_seq 传进decoder里面
    这里的init_seq 是


而 trg_box_idx 是一个常数. 即

  1. 然后从第2个词开始, 循环作为decoder模块的输入传进去.

各模块细节

EncoderLayer


如图每个EncoderLayer包括了, self-attention, 以及 positionFFN.

这里的self-attention是MultiHeadAttention.

MultiHeadAttention

看MultiHeadAttention的操作的话,主要是经历了以下主要的几个操作


把这个图画成下面这个样子来理解:

Attention

这里的Attention其实就是这个公式


代码里面叫做 ScaledDotProductionAttention

这里的temperature 是MultiHeadAttention的一个参数,

这里面需要注意的是, n_head这个参数,
在看知乎(https://zhuanlan.zhihu.com/p/48508221)上面的讲解时, 是这样的流程图

我理解其实是一样的, 一个是流程图解释,而一个是具体的实现方式.

PositionwiseFeedForward


也就是经过了两个全连接层, 然后过一个droupout, 过残差, 然后再过layer_norm

至此,整个encoder 其实就是两个模块, 一个是self-attention, 一个是FFN.

然后这里的self-attention,其实是Multi-Head-Attention.

DecoderLayer

decoder layer的模块其实和encoder的模块差不多,但是多了一个MultiHeadAttention, 这个叫encoder-decoder-attention.

forward的时候, 会把dec_input 分三份输入 self_attion模块中,

然后 encoder的output和上面的dec_output 作为encoder-decoder-attention的输入


最终返回 三个东西. 分别是

dec_output, dec_slf_attn, dec_enc_attn.

其它补充

mask在encoder里如何起作用的

我们先追溯被用在了哪里
首先是这样传入encoder的

又是这样被传入每个encoder-layer的

又是这样在每个encoder-layer被使用的.

即在MultiHeadAttention中是

最终确定是在这里被使用的


先来看一下 masked_fill如何使用, 它输入两个参数, 一个是mask, 一个是value,
也即是说, 在mask为1的那些地方, 把值改成value, 而mask为0的地方的值不改变.

即么上面的意思就是说在(mask0)的这些地方让attention的值变得非常地小, 而又经过了softmax之后, 也就是说这些mask0的地方的响应值为0. 或者说接近于0.
总结一句话就是, 传入的mask的作用. 会让其在值为mask=1的地方几乎不响应.