RankingPrompter
RankingPrompter是由人工智能与数字经济广东省实验室(深圳光明实验室)开发的一个开源的重排/精排模型。
- 在1500万中文句对数据集上进行训练。
- 在多项中文测试集上均取得最好的效果。
如果希望使用RankingPrompter更加丰富的功能(如完整的文档编码-召回-精排链路),我们推荐使用配套代码库(To be released)。
如何使用
You can use this model simply as a re-ranker, note now the model is only available for Chinese. 本模型可简单用作一个强力的重排/精排模型,现阶段仅支持中文。
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("howard-hou/RankingPrompterForPreTraining-small")
# trust_remote_code=True 很重要,否则不会读取到正确的模型
model = AutoModel.from_pretrained("howard-hou/RankingPrompterForPreTraining-small",
trust_remote_code=True)
#
documents = [
'水库诱发地震的震中多在库底和水库边缘。',
'双标紫斑蝶广泛分布于南亚、东南亚、澳洲、新几内亚等地。台湾地区于本岛中海拔地区可见,多以特有亚种归类。',
'月经停止是怀孕最显著也是最早的一个信号,如果在无避孕措施下进行了性生活而出现月经停止的话,很可能就是怀孕了。'
]
question = "什么是怀孕最显著也是最早的信号?"
question_input = tokenizer(question, padding=True, return_tensors="pt")
docs_input = tokenizer(documents, padding=True, return_tensors="pt")
# document input shape should be [batch_size, num_docs, seq_len]
# so if only input one sample of documents, add one dim by unsqueeze(0)
output = model(
document_input_ids=docs_input.input_ids.unsqueeze(0),
document_attention_mask=docs_input.attention_mask.unsqueeze(0),
question_input_ids=question_input.input_ids,
question_attention_mask=question_input.attention_mask
)
print("reranking scores: ", output.logits)