解析transformer——4: Add&Norm,残差连接与layernorm
本文介绍了Transformer中的Add&Norm模块,详细解析了归一化、LayerNorm与BatchNorm的区别以及残差连接的作用。归一化通过缩放数据解决梯度下降问题;LayerNorm按样本计算均值和方差,适合序列模型;残差连接则用于构建深层网络避免梯度消失。文章还提供了代码实现,展示如何在PyTorch中应用这些技术。
文章目录
transformer中的Add&Norm
- Add&Norm模块位于子模块Mutil-Head Attention与Feed Forward之后
- 在这个模块要理解3个问题:
- 归一化是什么? (norm)
- layernorm与batchnorm区别? (norm)
- 残差连接residual connection是什么? (add)
1.1 归一化
归一化(Normalization) 是指将数据按照比例缩放到一个特定的范围的过程。 可以用于协变量偏移(covariate shift)问题。
除了归一化,还常常提到标准化(Standardization),我个人认为可以直接将标准化视为归一化的一个具体方法。
好处:一定程度上可以解决数据中某些特征取值过大导致梯度下降过慢的问题。
一种归一化方法z-score标准化:
x ∗ = x − μ σ x^*=\frac{x-\mu}{\sigma} x∗=σx−μ
[4, 8, 3] (u = 5, σ = 2.16 \sigma=2.16 σ=2.16) --z-score–> [ 4 − 5 2.16 \frac{4-5}{2.16} 2.164−5, 1.39, -0.93]
参考资料:
如何理解归一化(normalization)?
归一化 (Normalization)、标准化 (Standardization)
1.2 layernorm与batchnorm
小例子
按照一个一个样本计算mean( μ \mu μ)与std( σ \sigma σ)即为layernorm,图的上面
按照一个一个特征计算mean( μ \mu μ)与std( σ \sigma σ)即为batchnorm,图的下面
layernorm具体实现
y = x − μ σ 2 + ϵ ∗ W + b y = \frac{x-\mu}{\sqrt{\sigma^2+\epsilon}} * W + b y=σ2+ϵx−μ∗W+b
公式如上,即进行归一化后,加上一个线性变换,线性变换的参数W,b通过学习获得。
ϵ \epsilon ϵ是为了防止分母为0设置的极小常数 。
为什么transformer中使用layernorm而不是batchnorm?
对于这种序列模型,包括rnn,transformer,采用layernorm大概率会优于使用batchnorm,可能原因如下:
对于transformer比如输入两个句子:
输入1: I love AI
输入2: Transformers are powerful models
经过embedding后:
输入1
[0.1, 0.2, 0.3, 0.4, 0, 0, 0, 0] # I
[0.5, 0.6, 0.7, 0.8, 0, 0, 0, 0] # love
[0.9, 1.0, 1.1, 1.2, 0, 0, 0, 0] # AI
[0.0, 0.0, 0.0, 0.0, 0, 0, 0, 0] # PAD (无意义)
输入2
[0.3, 0.2, 0.1, 0.4, 0.5, 0.6, 0.7, 0.8] # Transformers
[0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6] # are
[0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4] # powerful
[0.5, 0.4, 0.3, 0.2, 0.1, 0.0, 0.0, 0.0] #models
显然进行batchnorm时,models会与PAD进行计算作为第4个特征。对于这种序列模型,输入的位置1,位置2,位置3并不代表特征,位置2的love同样会出现在位置1或者位置3。
1.3 残差连接
残差连接目的是为了构建更深的网络,避免梯度消失。
参考资料:
残差详解, BatchNorm详解,LayerNorm详解-视频
什么是层归一化LayerNorm,为什么Transformer使用层归一化-视频
Layer Normalization-论文
BatchNorm and LayerNorm
Transformer学习笔记四:ResNet(残差网络)
代码实现
layernorm直接使用pytorch中实现的类。
前面实现的embedding,positional encoding与MultiHeadAttention代码。
import torch.nn as nn
import torch.nn.functional as F
import torch
import math
vocab_size = 10
d_model = 4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class MyEmbedding(nn.Embedding):
def __init__(self, num_embeddings, embedding_dim, device):
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.device = device
super().__init__(self.num_embeddings, self.embedding_dim, device=device)
def forward(self, input_ids):
return super().forward(input_ids) * torch.sqrt(torch.tensor(self.embedding_dim).to(device))
class MyPositonalEncoding(nn.Module):
def __init__(self, seq_length, d_model, device):
if d_model % 2:
raise ValueError("embedding_dim must be an even number for positional encoding.")
super().__init__()
self.seq_length = seq_length
self.d_model = d_model
self.device = device
pe = torch.zeros(self.seq_length, self.d_model)
pos = torch.arange(0, self.seq_length, dtype=torch.float).unsqueeze(1)
# div_term = 1 / (1000 ** (torch.arange(0, d_model, 2).float() / d_model)), 存在精度与性能问题,参考pytorch使用下面方式
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(pos * div_term)
pe[:, 1::2] = torch.cos(pos * div_term)
# (seq_length, d_model) -> (seq_length, 1, d_model)
pe = pe.unsqueeze(0).transpose(0, 1)
# 与self.pe = pe不同在于:会被持久化保存,不参与梯度学习
self.register_buffer("pe", pe)
def forward(self, x):
# x的形状为(seq_length, batch_size, d_model)
seq_length = x.shape[0]
return x + self.pe[:seq_length, :, :].to(device)
class MyAttention(nn.Module):
def __init__(self):
super().__init__()
def forward(self, q, k, v, mask=None):
seq_len, batch_size, num_head, d_k = q.shape
_q = q.permute(1, 2, 0, 3)
_k = k.permute(1, 2, 3, 0)
score = torch.matmul(_q, _k) / math.sqrt(d_k) # score形状为 (batch_size, num_head, seq_len, seq_len)
if mask is not None:
score = score.masked_fill(mask == 0, float('-inf'))
attention_score = F.softmax(score, -1)
_v = v.permute(1, 2, 0, 3)
output = torch.matmul(attention_score, _v) #output形状为(batch_size, num_head, seq_len, d_v)
output = output.permute(2, 0, 1, 3) #output形状为(seq_len, batch_size, num_head, d_v)
return output
class MyMultiHeadAttention(nn.Module):
def __init__(self, num_head, d_model, device):
super().__init__()
self.num_head = num_head
self.d_model = d_model
self.device = device
assert d_model % num_head == 0
self.d_k = d_model // num_head
self.d_v = self.d_k
# 这里采用策略为先统一进行线性变换,再切分给不同头
self.q_proj = nn.Linear(d_model, d_model, bias=False, device=device)
self.k_proj = nn.Linear(d_model, d_model, bias=False, device=device)
self.v_proj = nn.Linear(d_model, d_model, bias=False, device=device)
self.attention = MyAttention()
self.o_proj = nn.Linear(d_model, d_model, bias=False, device=device)
def forward(self, x, mask=None):
seq_len, batch_size, d_model = x.shape
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
# (seq_len, batch_size, d_model) -> (seq_len, batch_size, num_head, d_v)
q = q.contiguous().view(seq_len, batch_size, self.num_head, self.d_k)
k = k.contiguous().view(seq_len, batch_size, self.num_head, self.d_k)
v = v.contiguous().view(seq_len, batch_size, self.num_head, self.d_v)
atten_out = self.attention(q, k, v, mask)
atten_out = atten_out.contiguous().view(seq_len, batch_size, -1)
output = self.o_proj(atten_out)
return output
Add&Norm模块的类
class MyAddAndNorm(nn.Module):
def __init__(self, d_model, device):
super().__init__()
self.d_model = d_model
self.device = device
self.layer_norm = nn.LayerNorm(d_model, device=device)
def forward(self, x, change_x):
return self.layer_norm(x + change_x)
测试代码
embedding_layer = MyEmbedding(vocab_size, d_model, device)
positional_ecoding_layer = MyPositonalEncoding(vocab_size, d_model, device)
multi_head_attention = MyMultiHeadAttention(num_head=2, d_model=d_model, device=device)
# 输入(seq_len, batch_size, d_model)
if_test = True
add_and_norm = MyAddAndNorm(d_model=d_model, device=device)
if if_test:
token_ids = torch.tensor([[0, 1, 2], [2, 3, 4]], dtype=torch.long).to(device)
embedding_ids = embedding_layer(token_ids)
embedding_ids = embedding_ids.transpose(0, 1)
pos_ids = positional_ecoding_layer(embedding_ids)
print(pos_ids.shape)
mask = torch.tril(torch.ones((3, 3)), diagonal = 0).unsqueeze(0).to(device) # 3为seq_len
attention_ids = multi_head_attention(pos_ids)
print(attention_ids)
print(attention_ids.shape)
output_ids = add_and_norm(pos_ids, attention_ids)
print(output_ids)
更多推荐

所有评论(0)