代码

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return self.weight * (
            x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        ).type_as(x)

代码解释

这段代码定义了一个自定义的PyTorch模块 RMSNorm,用于实现Root Mean Square Normalization (RMSNorm)。RMSNorm是一种归一化技术,类似于Layer Normalization,但它只对输入进行缩放,而不进行平移(即没有偏置项)。下面是代码的详细解释:

1. 初始化方法 __init__

def __init__(self, dim: int, eps: float):
    super().__init__()
    self.eps = eps
    self.weight = nn.Parameter(torch.ones(dim))
  • dim: int: 输入特征的维度。
  • eps: float: 一个小常数,用于数值稳定性,避免除以零的情况。
  • self.weight: 一个可学习的参数,形状为 (dim,),初始化为全1的张量。这个参数用于对归一化后的输入进行缩放。

2. 前向传播方法 forward

def forward(self, x):
    return self.weight * (
        x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
    ).type_as(x)
  • x: 输入张量,形状通常为 (batch_size, ..., dim)
  • x.pow(2): 对输入 x 的每个元素求平方。
  • x.pow(2).mean(-1, keepdim=True): 沿着最后一个维度(即特征维度 dim)计算平方的均值,并保持维度不变。结果形状为 (batch_size, ..., 1)
  • torch.rsqrt(...): 计算均方根的倒数(即1除以平方根),用于归一化。
  • x.float() * torch.rsqrt(...): 将输入 x 转换为浮点数后,乘以均方根的倒数,得到归一化后的结果。
  • .type_as(x): 将结果转换回与输入 x 相同的数据类型。
  • self.weight * (...): 最后,将归一化后的结果乘以可学习的权重 self.weight,进行缩放。

3. 总结

  • RMSNorm 通过对输入进行归一化,使得每个特征的均方根值为1,然后通过可学习的权重进行缩放。
  • 与LayerNorm不同,RMSNorm没有偏置项,只进行缩放操作。
  • eps 用于防止除以零的情况,增加数值稳定性。

4. 使用场景

RMSNorm通常用于深度学习模型中,特别是在Transformer架构中,作为LayerNorm的替代方案。它可以加速训练并提高模型的稳定性。

可视化

dim = 64
eps = 1e-5
m = RMSNorm(dim, eps)
x = torch.randn(32, 10, dim)  # 示例输入 (batch_size, seq_len, dim)


f = "rms_norm.onnx"  # 导出的 ONNX 文件名
torch.onnx.export(m, x, f)  # 模型  # 示例输入

https://netron.app/ 上打开 rms_norm.onnx

在这里插入图片描述

Logo

开源鸿蒙跨平台开发社区汇聚开发者与厂商,共建“一次开发,多端部署”的开源生态,致力于降低跨端开发门槛,推动万物智联创新。

更多推荐