从代码层面理解Transformer


跑通

代码使用的是 https://github.com/jadore801120/attention-is-all-you-need-pytorch,

commit-id 为: 132907d

各模块粗览

Transformer

主要包括一堆参数,
以及encoder和decoder

图片[1] - 从代码层面理解Transformer - MaxSSL

forward的时候主要做了如下操作.

  1. 先 pad_mask
  2. 过encoder
  3. 过decoder
  4. 输出logit

图片[2] - 从代码层面理解Transformer - MaxSSL

从train.py 我们可以看出, 模型的输出直接去做了loss

图片[3] - 从代码层面理解Transformer - MaxSSL
这里的Loss就是cross_entropy.

然后每个encoder其实是一堆EncoderLayer的,每个decoder其实也是一堆DecoderLayer的, 所以先大致看一下.

Encoder

图片[4] - 从代码层面理解Transformer - MaxSSL
整体流程如下

  1. 输入 原始的src_seq, 得到word embeding, 叫做 enc_output
  2. 做position_encoder, 即位置编码.
  3. 做LayerNorm
  4. 过各个堆叠的encoder-layer, 每一个encoder-layer的输入都是上一层的输出.
  5. 返回最后一个encoder-layer的输出

什么是PositionEncoding

这个模块其实没有可以学习的参数. 这里的这个buffer的用法可以学习一下.
图片[5] - 从代码层面理解Transformer - MaxSSL
图片[6] - 从代码层面理解Transformer - MaxSSL
这里的这个实现还是挺简洁的. 一行就解决了.
图片[7] - 从代码层面理解Transformer - MaxSSL

Decocer

Decoder的结构和encoder的结构几乎一样,
图片[8] - 从代码层面理解Transformer - MaxSSL

但是要注意的是, Decnoder的输入.
图片[9] - 从代码层面理解Transformer - MaxSSL
图片[10] - 从代码层面理解Transformer - MaxSSL
也就是说, Decoder中的positionEncoding是对groundtruth做的.

这里我有一个疑问, 推理的时候, 没有trg_seq 的时候具体是怎么做的呢” />推理时细节

图片[11] - 从代码层面理解Transformer - MaxSSL
图片[12] - 从代码层面理解Transformer - MaxSSL

  1. 输入一个句子, 会先用encoder得到encoder_output.
  2. 同时会有 init_seq 传进decoder里面
    这里的init_seq 是

图片[13] - 从代码层面理解Transformer - MaxSSL

图片[14] - 从代码层面理解Transformer - MaxSSL
而 trg_box_idx 是一个常数. 即
图片[15] - 从代码层面理解Transformer - MaxSSL

  1. 然后从第2个词开始, 循环作为decoder模块的输入传进去.

图片[16] - 从代码层面理解Transformer - MaxSSL

各模块细节

EncoderLayer

图片[17] - 从代码层面理解Transformer - MaxSSL
如图每个EncoderLayer包括了, self-attention, 以及 positionFFN.

这里的self-attention是MultiHeadAttention.

MultiHeadAttention

看MultiHeadAttention的操作的话,主要是经历了以下主要的几个操作

图片[18] - 从代码层面理解Transformer - MaxSSL
把这个图画成下面这个样子来理解:

图片[19] - 从代码层面理解Transformer - MaxSSL

Attention

这里的Attention其实就是这个公式

图片[20] - 从代码层面理解Transformer - MaxSSL
代码里面叫做 ScaledDotProductionAttention

图片[21] - 从代码层面理解Transformer - MaxSSL

这里的temperature 是MultiHeadAttention的一个参数,
图片[22] - 从代码层面理解Transformer - MaxSSL

这里面需要注意的是, n_head这个参数,
在看知乎(https://zhuanlan.zhihu.com/p/48508221)上面的讲解时, 是这样的流程图
图片[23] - 从代码层面理解Transformer - MaxSSL
我理解其实是一样的, 一个是流程图解释,而一个是具体的实现方式.

PositionwiseFeedForward

图片[24] - 从代码层面理解Transformer - MaxSSL
也就是经过了两个全连接层, 然后过一个droupout, 过残差, 然后再过layer_norm

至此,整个encoder 其实就是两个模块, 一个是self-attention, 一个是FFN.
图片[25] - 从代码层面理解Transformer - MaxSSL
然后这里的self-attention,其实是Multi-Head-Attention.

DecoderLayer

decoder layer的模块其实和encoder的模块差不多,但是多了一个MultiHeadAttention, 这个叫encoder-decoder-attention.

图片[26] - 从代码层面理解Transformer - MaxSSL

forward的时候, 会把dec_input 分三份输入 self_attion模块中,

图片[27] - 从代码层面理解Transformer - MaxSSL

然后 encoder的output和上面的dec_output 作为encoder-decoder-attention的输入

图片[28] - 从代码层面理解Transformer - MaxSSL
最终返回 三个东西. 分别是

dec_output, dec_slf_attn, dec_enc_attn.

其它补充

mask在encoder里如何起作用的

我们先追溯被用在了哪里
首先是这样传入encoder的
图片[29] - 从代码层面理解Transformer - MaxSSL
又是这样被传入每个encoder-layer的
图片[30] - 从代码层面理解Transformer - MaxSSL
又是这样在每个encoder-layer被使用的.
图片[31] - 从代码层面理解Transformer - MaxSSL

即在MultiHeadAttention中是
图片[32] - 从代码层面理解Transformer - MaxSSL
最终确定是在这里被使用的

图片[33] - 从代码层面理解Transformer - MaxSSL
先来看一下 masked_fill如何使用, 它输入两个参数, 一个是mask, 一个是value,
也即是说, 在mask为1的那些地方, 把值改成value, 而mask为0的地方的值不改变.

即么上面的意思就是说在(mask0)的这些地方让attention的值变得非常地小, 而又经过了softmax之后, 也就是说这些mask0的地方的响应值为0. 或者说接近于0.
总结一句话就是, 传入的mask的作用. 会让其在值为mask=1的地方几乎不响应.

© 版权声明
THE END
喜欢就支持一下吧
点赞0 分享