从零实现Transformer:深入理解注意力机制

本文是 I made a transformer by hand (no training!) 的复现及个人补充理解。原文通过实现 Transformer 并手动分配权重的方式,预测给定的简单序列,藉此理解 Transformer 结构中各个部分的具体作用。

目标预测任务

我们需要一个不太复杂但也不能太简单的序列作为预测目标,这里选定的序列是 aabaabaabaab...,输入给定的上下文长度,让 Transformer 预测并输出下一个字符是什么。

为什么不能是更简单的 ababab..?因为它甚至不需要 position embedding,只要一个 if-else 就可以正确预测。

为了让序列能进行数学运算,我们对字符进行编码:a0b1

模型

这里给出的 Transformer 结构是原文作者基于 GPT-2 的简化实现:

img

我们将设置其中涉及的每一个权重,让模型能够正确预测我们的序列。

Embedding Sequence

算法进行的第一个计算是:

x = wte[inputs] + wpe[range(len(inputs))]

其中 inputs 是输入的 token IDs(按照我们之前的编码约定,aabaa 对应的是列表 [0, 0, 1, 0, 0]);wte 为 Token 权重,将 token IDs 转换为向量;wpe 为位置权重,将每一个位置生成位置编码向量。

这里 range(len(inputs)) 实际上就是一个位置序列 [0, 1, 2, 3, 4]


在 PyTorch 中,wte[inputs] 这样的形式是标准的数组/张量索引操作,表示从 wte 这个矩阵中,按照 inputs 提供的索引值选取对应的行(或元素)。

例如:

wte = [
    [0.1, 0.2, 0.3], 
    [0.4, 0.5, 0.6]  
]

inputs = [0, 1, 0]

那么:

wte[inputs] = [
    [0.1, 0.2, 0.3],  # inputs[0]=0 → 取 wte 的第 0 行
    [0.4, 0.5, 0.6],  # inputs[1]=1 → 取 wte 的第 1 行
    [0.1, 0.2, 0.3]   # inputs[2]=0 → 取 wte 的第 0 行
]

我们需要以数学的方式来表示以下信息:某个位置上是某个字母。

在这里,我们以 one-hot 编码的方式来表示具体的位置和具体的字母。

由于我们的上下文长度为 5,共有 5 个位置,按照 one-hot,我们以下面这个矩阵来表示“究竟在哪一个位置上”,这就得到了 wpe

"wpe": np.array(
    # one-hot position embeddings
    [
      [1, 0, 0, 0, 0, 0, 0, 0],  # position 0
      [0, 1, 0, 0, 0, 0, 0, 0],  # position 1
      [0, 0, 1, 0, 0, 0, 0, 0],  # position 2
      [0, 0, 0, 1, 0, 0, 0, 0],  # position 3
      [0, 0, 0, 0, 1, 0, 0, 0],  # position 4
    ]
  ),

可以看到我们还有三列空间,用其中两列用来表示是 a 还是 b,这就是 wte

"wte": np.array(
    # one-hot token embeddings
    [
      [0, 0, 0, 0, 0, 1, 0, 0],  # token `a` (id 0)
      [0, 0, 0, 0, 0, 0, 1, 0],  # token `b` (id 1)
    ]
  ),

还有最后一列没用上,这是 Transformer 的临时空间。

以上述方式来表示序列 aabaa

对于位置信息,我们可以得到以下矩阵。这也是 wpe[range(len(inputs))]的运算结果,与 wpe 原矩阵一样(毕竟,我们确实在五个位置上都有具体的字符):

[
    [1, 0, 0, 0, 0, 0, 0, 0],  # position 0
    [0, 1, 0, 0, 0, 0, 0, 0],  # position 1
    [0, 0, 1, 0, 0, 0, 0, 0],  # position 2
    [0, 0, 0, 1, 0, 0, 0, 0],  # position 3
    [0, 0, 0, 0, 1, 0, 0, 0],  # position 4
]

对于字符信息,即 wte[inputs](回忆一下运算规则):

[
    [0, 0, 0, 0, 0, 1, 0, 0],  # inputs[0] = 0 → 取 wte 的第 0 行
    [0, 0, 0, 0, 0, 1, 0, 0],  # inputs[1] = 0 → 取 wte 的第 0 行
    [0, 0, 0, 0, 0, 0, 1, 0],  # inputs[2] = 0 → 取 wte 的第 1 行
    [0, 0, 0, 0, 0, 1, 0, 0],  # inputs[3] = 0 → 取 wte 的第 0 行
    [0, 0, 0, 0, 0, 1, 0, 0],  # inputs[4] = 0 → 取 wte 的第 0 行
]

然后,我们把上面两个矩阵相加,就能以数学的方式表示“哪个位置上有哪个字符”了,也就是 x = wte[inputs] + wpe[range(len(inputs))] 的结果:

[
    [1, 0, 0, 0, 0, 1, 0, 0],  # position 0 字符是 'a'
    [0, 1, 0, 0, 0, 1, 0, 0],  # position 1 字符是 'a'
    [0, 0, 1, 0, 0, 0, 1, 0],  # position 2 字符是 'b'
    [0, 0, 0, 1, 0, 1, 0, 0],  # position 3 字符是 'a'
    [0, 0, 0, 0, 1, 1, 0, 0],  # position 4 字符是 'a'
]

Transformer Block

我们的模型中只使用一个 Transformer 模块,且模块只由一个 attention head + 一个 linear network(c_proj)构成。后者的作用是将 attention 结果矩阵投影回到我们编码时使用的 5 * 8 矩阵。

在通常的实现中,Transformer 模块具有多个并行的 attention head,即多头注意力,但简化起见,我们只使用一个。

首先,我们得先通过 c_attn 得到 Q, K, V 矩阵。

c_attn 是一个全连接层,有偏置值 b 和权重 w。显然,我们也得设置一个合适的权重 w 得到一个合适的 Q, K, V

但在这里我们无法直接设置它,我们需要从可理解的(同时也是正确的) Q, K, V 取值来反推才能理解为什么是这个权重值……所以,我们跳过这个 c_attn,来看看Q, K, V 应该是什么样子。

def attention(q, k, v, mask):
  return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v

image-20250619002715289

首先让我们专注 qk

k 很眼熟——对,这就是我们的位置编码 wpe

q 又是啥呢?在注意力计算公式中,我们计算了 q @ k.T,即 qk 的转置相乘,得到 5*5 的矩阵:

image-20250619003023694

看起来还是一头雾水,但如果我们对这个结果加上 mask,再用 softmax 映射一下,将得到(这里的 softmax 是一个激活函数,可以将 (-∞, +∞) 的值映射到 (0, 1) 中):

image-20250619003232091

为什么要加上 mask?这是一种“屏蔽”机制,避免大模型将注意力放到“未来的token”上,提前剧透答案而给它作弊的可能性。它的具体值是:

image-20250619003321525

回到 softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) 的结果,也就是以下这个矩阵:

image-20250619003639984

从上往下的每一行,都代表模型在预测不同 token 位置时,需要关注的位置以及所要投放的注意力。

为了预测 token 2,模型必须且只能”注意“ token 1,所以考虑它的概率是 1.

为了预测 token 3,模型只能”注意“ token 2 和 token 1,并均分注意力,各为 0.5(这是我们人为设置的 c_attn 计算得到的结果。如果一切都是训练得出的,也可能是 0.4 + 0.6 或其他组合)。

以此类推。但有一点是不变的——模型不能“注意”未来的 token。

然后我们得考虑怎么得到预测结果了——模型需要给出 0 或者 1。我们得有一种赋值方式,让模型能结合这个注意力矩阵来计算得到正确的结果。

所以,我们的最后一步是让这个注意力矩阵与 v 相乘,这个v 就是我们需要的赋值方式:

image-20250619004649954

结合我们的 wte[inputs],它的意义是显然的:将 one-hot token 编码(a = [1, 0], b = [0, 1])转换为了 1/-1 编码。

def attention(q, k, v, mask):
  return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v

这样一来,上面这段代码的计算结果是:

image-20250619004920906

我们得到了正确的 0/1 结果,每一行的最后一列的值都表示模型对该行的预测结果:

  • 第一行,输入:a,只关注 a,输出 b
  • 第二行,输入:aa,分别关注 aa,输出 b
  • 第三行,输入:aab,分别关注ab,输出a
  • 第四行,输入:aaba,分别关注ba,输出a
  • 第五行,输入:aabaa,分别关注aa,输出b

可以看到大部分预测结果都是正确的。只有第一行介于正确和错误之间,这是因为它没有足够的数据——只看到了一个 a,而下一个字符可能是 a 也可能是 b。对于我们的序列 aabaab...来说,这种模糊是确实存在的。

理解完 Q, K, V 及 attention 的计算后,我们再回头看 c_attn,它的权重设置目标总结如下(由于篇幅有限,这里不放出它的具体值,可以直接查看原文):

因此,总结来说,c_attn 权重的作用是:

  1. 将位置编码映射到 q 的“注意力窗口”;
  2. 将位置编码提取到 k
  3. 将 token 嵌入转换为 v 中的 1/-1 编码。

qk 通过 softmax(q @ k.T / ... + mask) 结合时,会得到一个 5 x 5 的矩阵,其中:

  • 第一行仅关注第一个 token;
  • 其他行均等关注最近的两个 token。

最后,通过 softmax(...) @ v 的加法抵消,模型会在:

  • 应预测 "a" 时,使最后一列输出 0
  • 应预测 "b" 时,使最后一列输出 1

Projecting back to embedding space