本文是 I made a transformer by hand (no training!) 的复现及个人补充理解。原文通过实现 Transformer 并手动分配权重的方式,预测给定的简单序列,藉此理解 Transformer 结构中各个部分的具体作用。
目标预测任务
我们需要一个不太复杂但也不能太简单的序列作为预测目标,这里选定的序列是 aabaabaabaab...,输入给定的上下文长度,让 Transformer 预测并输出下一个字符是什么。
为什么不能是更简单的 ababab..?因为它甚至不需要 position embedding,只要一个 if-else 就可以正确预测。
为了让序列能进行数学运算,我们对字符进行编码:a→0,b→1。
模型
这里给出的 Transformer 结构是原文作者基于 GPT-2 的简化实现:
我们将设置其中涉及的每一个权重,让模型能够正确预测我们的序列。
Embedding Sequence
算法进行的第一个计算是:
x = wte[inputs] + wpe[range(len(inputs))]
其中 inputs 是输入的 token IDs(按照我们之前的编码约定,aabaa 对应的是列表 [0, 0, 1, 0, 0]);wte 为 Token 权重,将 token IDs 转换为向量;wpe 为位置权重,将每一个位置生成位置编码向量。
这里 range(len(inputs)) 实际上就是一个位置序列 [0, 1, 2, 3, 4]。
在 PyTorch 中,wte[inputs] 这样的形式是标准的数组/张量索引操作,表示从 wte 这个矩阵中,按照 inputs 提供的索引值选取对应的行(或元素)。
例如:
wte = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6]
]
inputs = [0, 1, 0]
那么:
wte[inputs] = [
[0.1, 0.2, 0.3], # inputs[0]=0 → 取 wte 的第 0 行
[0.4, 0.5, 0.6], # inputs[1]=1 → 取 wte 的第 1 行
[0.1, 0.2, 0.3] # inputs[2]=0 → 取 wte 的第 0 行
]
我们需要以数学的方式来表示以下信息:某个位置上是某个字母。
在这里,我们以 one-hot 编码的方式来表示具体的位置和具体的字母。
由于我们的上下文长度为 5,共有 5 个位置,按照 one-hot,我们以下面这个矩阵来表示“究竟在哪一个位置上”,这就得到了 wpe:
"wpe": np.array(
# one-hot position embeddings
[
[1, 0, 0, 0, 0, 0, 0, 0], # position 0
[0, 1, 0, 0, 0, 0, 0, 0], # position 1
[0, 0, 1, 0, 0, 0, 0, 0], # position 2
[0, 0, 0, 1, 0, 0, 0, 0], # position 3
[0, 0, 0, 0, 1, 0, 0, 0], # position 4
]
),
可以看到我们还有三列空间,用其中两列用来表示是 a 还是 b,这就是 wte:
"wte": np.array(
# one-hot token embeddings
[
[0, 0, 0, 0, 0, 1, 0, 0], # token `a` (id 0)
[0, 0, 0, 0, 0, 0, 1, 0], # token `b` (id 1)
]
),
还有最后一列没用上,这是 Transformer 的临时空间。
以上述方式来表示序列 aabaa。
对于位置信息,我们可以得到以下矩阵。这也是 wpe[range(len(inputs))]的运算结果,与 wpe 原矩阵一样(毕竟,我们确实在五个位置上都有具体的字符):
[
[1, 0, 0, 0, 0, 0, 0, 0], # position 0
[0, 1, 0, 0, 0, 0, 0, 0], # position 1
[0, 0, 1, 0, 0, 0, 0, 0], # position 2
[0, 0, 0, 1, 0, 0, 0, 0], # position 3
[0, 0, 0, 0, 1, 0, 0, 0], # position 4
]
对于字符信息,即 wte[inputs](回忆一下运算规则):
[
[0, 0, 0, 0, 0, 1, 0, 0], # inputs[0] = 0 → 取 wte 的第 0 行
[0, 0, 0, 0, 0, 1, 0, 0], # inputs[1] = 0 → 取 wte 的第 0 行
[0, 0, 0, 0, 0, 0, 1, 0], # inputs[2] = 0 → 取 wte 的第 1 行
[0, 0, 0, 0, 0, 1, 0, 0], # inputs[3] = 0 → 取 wte 的第 0 行
[0, 0, 0, 0, 0, 1, 0, 0], # inputs[4] = 0 → 取 wte 的第 0 行
]
然后,我们把上面两个矩阵相加,就能以数学的方式表示“哪个位置上有哪个字符”了,也就是 x = wte[inputs] + wpe[range(len(inputs))] 的结果:
[
[1, 0, 0, 0, 0, 1, 0, 0], # position 0 字符是 'a'
[0, 1, 0, 0, 0, 1, 0, 0], # position 1 字符是 'a'
[0, 0, 1, 0, 0, 0, 1, 0], # position 2 字符是 'b'
[0, 0, 0, 1, 0, 1, 0, 0], # position 3 字符是 'a'
[0, 0, 0, 0, 1, 1, 0, 0], # position 4 字符是 'a'
]
Transformer Block
我们的模型中只使用一个 Transformer 模块,且模块只由一个 attention head + 一个 linear network(c_proj)构成。后者的作用是将 attention 结果矩阵投影回到我们编码时使用的 5 * 8 矩阵。
在通常的实现中,Transformer 模块具有多个并行的 attention head,即多头注意力,但简化起见,我们只使用一个。
首先,我们得先通过 c_attn 得到 Q, K, V 矩阵。
c_attn 是一个全连接层,有偏置值 b 和权重 w。显然,我们也得设置一个合适的权重 w 得到一个合适的 Q, K, V。
但在这里我们无法直接设置它,我们需要从可理解的(同时也是正确的) Q, K, V 取值来反推才能理解为什么是这个权重值……所以,我们跳过这个 c_attn,来看看Q, K, V 应该是什么样子。
def attention(q, k, v, mask):
return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v

首先让我们专注 q 和 k。
k 很眼熟——对,这就是我们的位置编码 wpe。
q 又是啥呢?在注意力计算公式中,我们计算了 q @ k.T,即 q 与 k 的转置相乘,得到 5*5 的矩阵:

看起来还是一头雾水,但如果我们对这个结果加上 mask,再用 softmax 映射一下,将得到(这里的 softmax 是一个激活函数,可以将 (-∞, +∞) 的值映射到 (0, 1) 中):

为什么要加上 mask?这是一种“屏蔽”机制,避免大模型将注意力放到“未来的token”上,提前剧透答案而给它作弊的可能性。它的具体值是:

回到 softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) 的结果,也就是以下这个矩阵:

从上往下的每一行,都代表模型在预测不同 token 位置时,需要关注的位置以及所要投放的注意力。
为了预测 token 2,模型必须且只能”注意“ token 1,所以考虑它的概率是 1.
为了预测 token 3,模型只能”注意“ token 2 和 token 1,并均分注意力,各为 0.5(这是我们人为设置的 c_attn 计算得到的结果。如果一切都是训练得出的,也可能是 0.4 + 0.6 或其他组合)。
以此类推。但有一点是不变的——模型不能“注意”未来的 token。
然后我们得考虑怎么得到预测结果了——模型需要给出 0 或者 1。我们得有一种赋值方式,让模型能结合这个注意力矩阵来计算得到正确的结果。
所以,我们的最后一步是让这个注意力矩阵与 v 相乘,这个v 就是我们需要的赋值方式:

结合我们的 wte[inputs],它的意义是显然的:将 one-hot token 编码(a = [1, 0], b = [0, 1])转换为了 1/-1 编码。
def attention(q, k, v, mask):
return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v
这样一来,上面这段代码的计算结果是:

我们得到了正确的 0/1 结果,每一行的最后一列的值都表示模型对该行的预测结果:
- 第一行,输入:
a,只关注a,输出b - 第二行,输入:
aa,分别关注aa,输出b - 第三行,输入:
aab,分别关注ab,输出a - 第四行,输入:
aaba,分别关注ba,输出a - 第五行,输入:
aabaa,分别关注aa,输出b
可以看到大部分预测结果都是正确的。只有第一行介于正确和错误之间,这是因为它没有足够的数据——只看到了一个 a,而下一个字符可能是 a 也可能是 b。对于我们的序列 aabaab...来说,这种模糊是确实存在的。
理解完 Q, K, V 及 attention 的计算后,我们再回头看 c_attn,它的权重设置目标总结如下(由于篇幅有限,这里不放出它的具体值,可以直接查看原文):
因此,总结来说,
c_attn权重的作用是:
- 将位置编码映射到
q的“注意力窗口”;- 将位置编码提取到
k;- 将 token 嵌入转换为
v中的1/-1编码。当
q和k通过softmax(q @ k.T / ... + mask)结合时,会得到一个5 x 5的矩阵,其中:
- 第一行仅关注第一个 token;
- 其他行均等关注最近的两个 token。
最后,通过
softmax(...) @ v的加法抵消,模型会在:
- 应预测
"a"时,使最后一列输出0;- 应预测
"b"时,使最后一列输出1。