Transformer架构拆解:从数学公式到代码实现

引言:从理论到实践的桥梁
Transformer架构的诞生彻底改变了自然语言处理(NLP)领域,其核心设计——自注意力机制(Self-Attention)与并行计算能力——使其在处理长序列任务时远超传统模型(如RNN、CNN)。本文将从数学公式出发,逐步拆解Transformer的关键模块,并通过PyTorch代码片段展示其实现逻辑,揭示“千亿参数”背后的工程智慧。


一、自注意力机制:从数学公式到矩阵运算

  1. 数学原理
    • 输入定义:对于输入序列 ( X \in \mathbb{R}^{n \times d} )(n为序列长度,d为特征维度),通过线性变换生成查询(Query)、键(Key)、值(Value)矩阵:

    $$ Q = XW_Q, \quad K = XW_K, \quad V = XW_V $$
    其中 ( W_Q, W_K, W_V \in \mathbb{R}^{d \times d_k} ) 为可学习参数。
    • 注意力权重计算:

    $$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$
    缩放因子 ( \frac{1}{\sqrt{d_k}} ) 用于防止点积结果过大导致梯度消失。

  2. 代码实现

    import torch  
    import torch.nn as nn  
    import torch.nn.functional as F  
    
    class SelfAttention(nn.Module):  
        def __init__(self, d_model, d_k):  
            super().__init__()  
            self.W_Q = nn.Linear(d_model, d_k)  
            self.W_K = nn.Linear(d_model, d_k)  
            self.W_V = nn.Linear(d_model, d_k)  
    
        def forward(self, x):  
            Q = self.W_Q(x)  # (batch_size, seq_len, d_k)  
            K = self.W_K(x)  # (batch_size, seq_len, d_k)  
            V = self.W_V(x)  # (batch_size, seq_len, d_k)  
    
            scores = torch.matmul(Q, K.transpose(-2, -1)) / (K.size(-1) ** 0.5)  
            attn_weights = F.softmax(scores, dim=-1)  
            output = torch.matmul(attn_weights, V)  
            return output  

    关键点:
    • 所有输入token并行计算,时间复杂度为 ( O(n^2 d) )。

    • 使用矩阵乘法实现高效并行化(GPU加速)。


二、多头注意力:分而治之的信息融合策略

  1. 数学原理
    • 将Q、K、V拆分为h个“头”(Head),每个头独立学习不同子空间的语义关系:

    $$ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W_O $$
    其中 ( \text{head}_i = \text{Attention}(QW_Q^{(i)}, KW_K^{(i)}, VW_V^{(i)}) ),( W_O \in \mathbb{R}^{h d_v \times d} ) 为输出投影矩阵。

  2. 代码实现

    class MultiHeadAttention(nn.Module):  
        def __init__(self, d_model, num_heads):  
            super().__init__()  
            assert d_model % num_heads == 0  
            self.d_k = d_model // num_heads  
            self.num_heads = num_heads  
    
            self.W_Q = nn.Linear(d_model, d_model)  
            self.W_K = nn.Linear(d_model, d_model)  
            self.W_V = nn.Linear(d_model, d_model)  
            self.W_O = nn.Linear(d_model, d_model)  
    
        def forward(self, x):  
            batch_size = x.size(0)  
    
            # 线性变换并分头  
            Q = self.W_Q(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  
            K = self.W_K(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  
            V = self.W_V(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  
    
            # 计算注意力并合并  
            scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)  
            attn_weights = F.softmax(scores, dim=-1)  
            context = torch.matmul(attn_weights, V)  
            context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)  
    
            return self.W_O(context)  

    设计优势:
    • 多头机制增强模型捕捉不同语义关系的能力(如语法结构、语义关联)。

    • 分头计算降低单个注意力矩阵的维度,减少内存占用。


三、位置编码:让序列“记住”顺序

  1. 数学原理
    • 正弦位置编码:

    $$ PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right) $$
    $$ PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d}}\right) $$
    其中 ( pos ) 为位置,( i ) 为维度索引。

  2. 代码实现

    class PositionalEncoding(nn.Module):  
        def __init__(self, d_model, max_len=5000):  
            super().__init__()  
            pe = torch.zeros(max_len, d_model)  
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  
            div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  
    
            pe[:, 0::2] = torch.sin(position * div_term)  
            pe[:, 1::2] = torch.cos(position * div_term)  
            self.register_buffer('pe', pe.unsqueeze(0))  
    
        def forward(self, x):  
            return x + self.pe[:, :x.size(1)]  

    关键细节:
    • 位置编码与输入嵌入相加,而非拼接,避免维度变化。

    • 正弦函数的周期性使模型能外推到训练时未见的序列长度。







扫描下方二维码,关注公众号:程序进阶之路,实时获取更多优质文章推送。


扫码关注

评论