gau alpha torch

pytorch 代码

https://github.com/JunnYu/GAU-alpha-pytorch

bert4keras代码

https://github.com/ZhuiyiTechnology/GAU-alpha

Install

pip install git+https://github.com/JunnYu/GAU-alpha-pytorch.git
or
pip install gau_alpha

评测对比

CLUE-dev榜单分类任务结果,base版本。

iflytek tnews afqmc cmnli ocnli wsc csl
BERT 60.06 56.80 72.41 79.56 73.93 78.62 83.93
RoBERTa 60.64 58.06 74.05 81.24 76.00 87.50 84.50
RoFormer 60.91 57.54 73.52 80.92 76.07 86.84 84.63
RoFormerV2<sup>*</sup> 60.87 56.54 72.75 80.34 75.36 80.92 84.67
GAU-α 61.41 57.76 74.17 81.82 75.86 79.93 85.67
RoFormerV2-pytorch 62.87 59.03 76.20 80.85 79.73 87.82 91.87
GAU-α-pytorch(Adafactor) 61.18 57.52 73.42 80.91 75.69 80.59 85.5
GAU-α-pytorch(AdamW wd0.01 warmup0.1) 60.68 57.95 73.08 81.02 75.36 81.25 83.93

CLUE-test榜单分类任务结果,base版本。

iflytek tnews afqmc cmnli ocnli wsc csl
RoFormerV2-pytorch 63.15 58.24 75.42 80.59 74.17 83.79 83.73
GAU-α-pytorch(Adafactor) 61.38 57.08 74.05 80.37 73.53 74.83 85.6
GAU-α-pytorch(AdamW wd0.01 warmup0.1) 60.54 57.67 72.44 80.32 72.97 76.55 84.13

CLUE-dev集榜单阅读理解和NER结果

cmrc2018 c3 chid cluener
BERT 56.17 60.54 85.69 79.45
RoBERTa 56.54 67.66 86.71 79.47
RoFormer 56.26 67.24 86.57 79.72
RoFormerV2<sup>*</sup> 57.91 64.62 85.09 81.08
GAU-α 58.09 68.24 87.91 80.01

注:

Usage

import torch

from gau_alpha import GAUAlphaForMaskedLM, GAUAlphaTokenizer

text = "今天[MASK]很好,我[MASK]去公园玩。"
tokenizer = GAUAlphaTokenizer.from_pretrained(
    "junnyu/chinese_GAU-alpha-char_L-24_H-768"
)
pt_model = GAUAlphaForMaskedLM.from_pretrained(
    "junnyu/chinese_GAU-alpha-char_L-24_H-768"
)

pt_inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    pt_outputs = pt_model(**pt_inputs).logits[0]
pt_outputs_sentence = "pytorch: "
for i, id in enumerate(tokenizer.encode(text)):
    if id == tokenizer.mask_token_id:
        val, idx = pt_outputs[i].softmax(-1).topk(k=5)
        tokens = tokenizer.convert_ids_to_tokens(idx)
        new_tokens = []
        for v, t in zip(val.cpu(), tokens):
            new_tokens.append(f"{t}+{round(v.item(),4)}")
        pt_outputs_sentence += "[" + "||".join(new_tokens) + "]"
    else:
        pt_outputs_sentence += "".join(
            tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True)
        )
print(pt_outputs_sentence)
# pytorch: 今天[天+0.8657||气+0.0535||阳+0.0165||,+0.0126||晴+0.0111]很好,我[要+0.4619||想+0.4352||又+0.0252||就+0.0157||跑+0.0064]去公园玩。

Reference

Bibtex:

@techreport{gau-alpha,
  title={GAU-α: GAU-based Transformers for NLP - ZhuiyiAI},
  author={Jianlin Su, Shengfeng Pan, Bo Wen, Yunfeng Liu},
  year={2022},
  url="https://github.com/ZhuiyiTechnology/GAU-alpha",
}