当将 return_dict 设置为 True 时,BERT 模型的输出格式会从元组(tuple) 变为字典风格的对象(BaseModelOutput 及其子类),这会导致后续处理代码的写法有明显差异。

具体来说,return_dict=True 时的代码写法与 return_dict=False 相比,主要区别在于如何访问 BERT 的输出结果

1. return_dict=False(元组格式,原代码)

输出是元组,通过索引访问隐藏状态:

python

运行

# 模型定义
self.bert = BertModel.from_pretrained(..., return_dict=False)

# 后续使用(前向传播)
def forward(self, input_ids, labels=None):
    # BERT输出是元组:(last_hidden_state, pooler_output, ...)
    outputs = self.bert(input_ids)
    # 取第0个元素作为最后一层隐藏状态
    last_hidden_state = outputs[0]
    # 传入线性层计算logits
    logits = self.classify(last_hidden_state)
    
    if labels is not None:
        # 计算损失(假设labels是下一个token的索引)
        loss = self.loss(logits.transpose(1, 2), labels)
        return loss, logits
    return logits

2. return_dict=True(字典对象格式)

输出是带有属性的对象,通过属性名访问隐藏状态(更直观):

python

运行

# 模型定义
self.bert = BertModel.from_pretrained(..., return_dict=True)  # 默认就是True

# 后续使用(前向传播)
def forward(self, input_ids, labels=None):
    # BERT输出是字典风格对象,包含last_hidden_state等属性
    outputs = self.bert(input_ids)
    # 直接通过属性名获取最后一层隐藏状态(无需记索引)
    last_hidden_state = outputs.last_hidden_state
    # 传入线性层计算logits(与之前一致)
    logits = self.classify(last_hidden_state)
    
    if labels is not None:
        loss = self.loss(logits.transpose(1, 2), labels)
        return loss, logits
    return logits

核心差异总结

特性 return_dict=False(元组) return_dict=True(字典对象)
输出类型 元组(tuple) BaseModelOutput 类对象
访问隐藏状态方式 outputs[0](依赖索引,需记忆位置) outputs.last_hidden_state(属性名访问,直观)
代码可读性 较低(需记住索引对应含义) 较高(属性名自解释)
兼容性 兼容旧代码(早期 Transformers 版本) 适合新代码(v4.0 + 版本默认)

其他可访问的属性

当 return_dict=True 时,BERT 输出对象还包含其他常用属性(根据模型配置可能有所不同):

  • outputs.pooler_output:[CLS] token 的池化输出(常用于分类任务)
  • outputs.hidden_states:所有层的隐藏状态(需在模型加载时设置 output_hidden_states=True
  • outputs.attentions:所有层的注意力权重(需设置 output_attentions=True

例如,获取所有层的隐藏状态:

python

运行

# 加载模型时指定输出隐藏状态
self.bert = BertModel.from_pretrained(..., return_dict=True, output_hidden_states=True)

# 前向传播中访问
outputs = self.bert(input_ids)
all_hidden_states = outputs.hidden_states  # 列表,包含每一层的隐藏状态

总结

return_dict=True 时,代码会更直观易读(通过属性名访问,无需记忆索引),这也是当前 Transformers 库的默认设置,推荐在新代码中使用。如果你的代码需要兼容旧项目或依赖元组索引的处理逻辑,才需要显式设置 return_dict=False。两种方式在功能上完全等价,仅输出格式不同。

Logo

开源鸿蒙跨平台开发社区汇聚开发者与厂商,共建“一次开发,多端部署”的开源生态,致力于降低跨端开发门槛,推动万物智联创新。

更多推荐