Project repository: LLMPruner: a pruning tool for large language models

LLMPruner is a pruning tool for large language models. It prunes the redundant parts of a model's vocabulary, which reduces the parameter count, lowers GPU memory usage, and speeds up training, while retaining the knowledge learned during pretraining.
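Conceptually, vocabulary pruning keeps only the token ids you care about and slices the corresponding rows out of the input embedding matrix (and the tied lm_head). Below is a minimal, illustrative sketch of that idea in plain transformers/PyTorch; it is not the LLMPruner API, and keep_ids (the ids of the tokens to retain) is a placeholder you would select from a corpus. The matching tokenizer also has to be rebuilt with the remapped ids, which the actual tool takes care of.

import torch
from transformers import BloomForCausalLM

# Placeholder: ids of the tokens to keep (in practice, Chinese plus common
# English tokens chosen from the original 250880-entry vocabulary).
keep_ids = torch.arange(1000)

model = BloomForCausalLM.from_pretrained('bigscience/bloom-560m')

with torch.no_grad():
    # Slice out the embedding rows that correspond to the kept tokens.
    old_embeddings = model.get_input_embeddings().weight
    pruned_embeddings = old_embeddings[keep_ids].clone()

    # Shrink the embedding (and the tied lm_head) to the new vocabulary size,
    # then copy the selected rows back in.
    model.resize_token_embeddings(len(keep_ids))
    model.get_input_embeddings().weight.copy_(pruned_embeddings)
    model.tie_weights()  # Bloom ties lm_head to the word embeddings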

This project prunes the Bloom vocabulary, keeping Chinese tokens and commonly used English tokens. The vocabulary shrinks from 250880 to 46145 entries, i.e. 18.39% of its original size. The resulting pruned Bloom models are listed in the table below:

Pruned model    Original model    Parameter ratio (pruned / original)
YeungNLP/bloom-396m-zh bigscience/bloom-560m 70.96%
YeungNLP/bloom-820m-zh bigscience/bloom-1b1 77.13%
YeungNLP/bloom-1b4-zh bigscience/bloom-1b7 81.14%
YeungNLP/bloom-2b6-zh bigscience/bloom-3b 86.48%
YeungNLP/bloom-6b4-zh bigscience/bloom-7b1 90.81%
YeungNLP/bloomz-396m-zh bigscience/bloomz-560m 70.96%
YeungNLP/bloomz-820m-zh bigscience/bloomz-1b1 77.13%
YeungNLP/bloomz-1b4-zh bigscience/bloomz-1b7 81.14%
YeungNLP/bloomz-2b6-zh bigscience/bloomz-3b 86.48%
YeungNLP/bloomz-6b4-zh bigscience/bloomz-7b1 90.81%
YeungNLP/bloomz-6b4-mt-zh bigscience/bloomz-7b1-mt 90.81%

Usage:

from transformers import BloomTokenizerFast, BloomForCausalLM

# Load the pruned tokenizer and model, then generate a continuation for a prompt
tokenizer = BloomTokenizerFast.from_pretrained('YeungNLP/bloom-1b4-zh')
model = BloomForCausalLM.from_pretrained('YeungNLP/bloom-1b4-zh')

input_ids = tokenizer.encode('长风破浪会有时', return_tensors='pt')
print(tokenizer.batch_decode(model.generate(input_ids)))
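With the default settings, generate returns a short greedy continuation. Continuing from the example above, you can pass the standard transformers generation arguments for longer or more varied output; the values below are only an illustration, not recommended settings:

output_ids = model.generate(
    input_ids,
    max_new_tokens=100,
    do_sample=True,
    top_p=0.85,
    temperature=0.5,
    repetition_penalty=1.2,
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))

Any other pruned checkpoint from the table above can be loaded the same way by swapping in its model name.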