使用BGE Reranker模型计算文本对相关性:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# 加载预训练模型与分词器(使用BAAI官方发布的reranker模型)
model = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-large')
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-large')
model.eval()  # 设置为推理模式

def calculate_rerank_score(query, documents):
    """
    计算查询与多个文档的相关性分数
    :param query: 查询文本,如:"熊猫是什么"
    :param documents: 候选文档列表,如:["熊猫是熊科动物", "企鹅生活在南极"]
    :return: 包含分数和文档的排序列表
    """
    # 构造文本对(格式:[[query, doc1], [query, doc2], ...])
    pairs = [[query, doc] for doc in documents]
    
    with torch.no_grad():
        # 批量编码文本对
        inputs = tokenizer(
            pairs,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=512  
        )
        
        # 获取模型输出
        outputs = model(**inputs)
        scores = torch.sigmoid(outputs.logits).squeeze().tolist()  # 将logits转换为0-1概率值

    # 组合结果并按分数降序排序
    sorted_results = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
    return sorted_results

# 使用示例
if __name__ == "__main__":
    query = "What is the capital of France?"
    documents = [
        "Paris is the most populous city in France",
        "Lyon is a major city in eastern France",
        "The Eiffel Tower is located in Paris"
    ]
    
    results = calculate_rerank_score(query, documents)
    
    # 打印结果
    print("Query:", query)
    for rank, (doc, score) in enumerate(results, 1):
        print(f"Rank {rank} (Score: {score:.4f}): {doc}")
  • 输出示例:
Query: What is the capital of France?
Rank 1 (Score: 0.9872): Paris is the most populous city in France
Rank 2 (Score: 0.8531): The Eiffel Tower is located in Paris
Rank 3 (Score: 0.1023): Lyon is a major city in eastern France
  • 关键实现细节说明:
    • 模型选择:使用BAAI/bge-reranker-large模型,该模型专门针对查询-文档相关性任务训练,支持中英文混合场景

    • 输入构造:将查询与每个文档组成二维列表,形成(query, doc)对,这种交叉编码方式能捕捉细粒度语义交互

    • 分数计算:通过sigmoid函数将logits转换为0-1的概率值,分数越高表示相关性越强,0.5为判定阈值

    • 批处理优化:通过padding=True和return_tensors='pt’实现批量推理,提升计算效率

Logo

开源鸿蒙跨平台开发社区汇聚开发者与厂商,共建“一次开发,多端部署”的开源生态,致力于降低跨端开发门槛,推动万物智联创新。

更多推荐