Transformer架构拆解:从数学公式到代码实现
引言:从理论到实践的桥梁
Transformer架构的诞生彻底改变了自然语言处理(NLP)领域,其核心设计——自注意力机制(Self-Attention)与并行计算能力——使其在处理长序列任务时远超传统模型(如RNN、CNN)。本文将从数学公式出发,逐步拆解Transformer的关键模块,并通过PyTorch代码片段展示其实现逻辑,揭示“千亿参数”背后的工程智慧。
一、自注意力机制:从数学公式到矩阵运算
数学原理
• 输入定义:对于输入序列 ( X \in \mathbb{R}^{n \times d} )(n为序列长度,d为特征维度),通过线性变换生成查询(Query)、键(Key)、值(Value)矩阵:$$ Q = XW_Q, \quad K = XW_K, \quad V = XW_V $$
其中 ( W_Q, W_K, W_V \in \mathbb{R}^{d \times d_k} ) 为可学习参数。
• 注意力权重计算:$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$
缩放因子 ( \frac{1}{\sqrt{d_k}} ) 用于防止点积结果过大导致梯度消失。代码实现
import torch import torch.nn as nn import torch.nn.functional as F class SelfAttention(nn.Module): def __init__(self, d_model, d_k): super().__init__() self.W_Q = nn.Linear(d_model, d_k) self.W_K = nn.Linear(d_model, d_k) self.W_V = nn.Linear(d_model, d_k) def forward(self, x): Q = self.W_Q(x) # (batch_size, seq_len, d_k) K = self.W_K(x) # (batch_size, seq_len, d_k) V = self.W_V(x) # (batch_size, seq_len, d_k) scores = torch.matmul(Q, K.transpose(-2, -1)) / (K.size(-1) ** 0.5) attn_weights = F.softmax(scores, dim=-1) output = torch.matmul(attn_weights, V) return output关键点:
• 所有输入token并行计算,时间复杂度为 ( O(n^2 d) )。• 使用矩阵乘法实现高效并行化(GPU加速)。
二、多头注意力:分而治之的信息融合策略
数学原理
• 将Q、K、V拆分为h个“头”(Head),每个头独立学习不同子空间的语义关系:$$ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W_O $$
其中 ( \text{head}_i = \text{Attention}(QW_Q^{(i)}, KW_K^{(i)}, VW_V^{(i)}) ),( W_O \in \mathbb{R}^{h d_v \times d} ) 为输出投影矩阵。代码实现
class MultiHeadAttention(nn.Module): def __init__(self, d_model, num_heads): super().__init__() assert d_model % num_heads == 0 self.d_k = d_model // num_heads self.num_heads = num_heads self.W_Q = nn.Linear(d_model, d_model) self.W_K = nn.Linear(d_model, d_model) self.W_V = nn.Linear(d_model, d_model) self.W_O = nn.Linear(d_model, d_model) def forward(self, x): batch_size = x.size(0) # 线性变换并分头 Q = self.W_Q(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) K = self.W_K(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) V = self.W_V(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) # 计算注意力并合并 scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5) attn_weights = F.softmax(scores, dim=-1) context = torch.matmul(attn_weights, V) context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k) return self.W_O(context)设计优势:
• 多头机制增强模型捕捉不同语义关系的能力(如语法结构、语义关联)。• 分头计算降低单个注意力矩阵的维度,减少内存占用。
三、位置编码:让序列“记住”顺序
数学原理
• 正弦位置编码:$$ PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right) $$
$$ PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d}}\right) $$
其中 ( pos ) 为位置,( i ) 为维度索引。代码实现
class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=5000): super().__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) self.register_buffer('pe', pe.unsqueeze(0)) def forward(self, x): return x + self.pe[:, :x.size(1)]关键细节:
• 位置编码与输入嵌入相加,而非拼接,避免维度变化。• 正弦函数的周期性使模型能外推到训练时未见的序列长度。
