rwkv KoRWKV

KoRWKV

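RWKV-Runner에서 사용하기 위해 변환한 모델 파일 (model file converted for use with RWKV-Runner). The conversion script used to produce it is shown below.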

import re

import torch

from transformers import RwkvForCausalLM

def convert_state_dict(state_dict):
    """Rename original RWKV checkpoint keys to Hugging Face ``RwkvForCausalLM`` keys.

    The renames applied to each key, in order, are:

    * ``emb.*`` -> ``embeddings.*``
    * ``blocks.0.ln0*`` -> ``blocks.0.pre_ln*`` (the extra LayerNorm handled
      here only on block 0)
    * ``blocks.N.att.*`` -> ``blocks.N.attention.*``
    * ``blocks.N.ffn.*`` -> ``blocks.N.feed_forward.*``
    * a trailing ``.time_mix_k`` / ``.time_mix_v`` / ``.time_mix_r`` becomes
      ``.time_mix_key`` / ``.time_mix_value`` / ``.time_mix_receptance``
    * every key except ``head.weight`` gets an ``rwkv.`` prefix

    Only key names change; the stored values are passed through untouched
    (no reshaping or dtype conversion happens here).

    Args:
        state_dict: Mapping from parameter name to tensor. It is mutated in
            place: every key is popped and re-inserted under its new name.

    Returns:
        The same ``state_dict`` object, with renamed keys in the original order.
    """
    # (original suffix, Hugging Face suffix); at most one can match a key.
    suffix_renames = (
        (".time_mix_k", ".time_mix_key"),
        (".time_mix_v", ".time_mix_value"),
        (".time_mix_r", ".time_mix_receptance"),
    )
    # Snapshot the keys first: the dict is modified while we walk them.
    for old_name in list(state_dict):
        weight = state_dict.pop(old_name)
        name = old_name

        # emb.* -> embeddings.*
        if name.startswith("emb."):
            name = "embeddings." + name[len("emb."):]
        # ln0 -> pre_ln (only present at block 0)
        if name.startswith("blocks.0.ln0"):
            name = "blocks.0.pre_ln" + name[len("blocks.0.ln0"):]
        # att -> attention, ffn -> feed_forward. The trailing "\." restricts the
        # match to the whole path segment, so e.g. an existing "attention"
        # segment is not rewritten to "attentionention".
        name = re.sub(r"blocks\.(\d+)\.att\.", r"blocks.\1.attention.", name)
        name = re.sub(r"blocks\.(\d+)\.ffn\.", r"blocks.\1.feed_forward.", name)
        # time_mix_{k,v,r} -> time_mix_{key,value,receptance} (rename only).
        for short_suffix, hf_suffix in suffix_renames:
            if name.endswith(short_suffix):
                name = name[: -len(short_suffix)] + hf_suffix
                break

        # The LM head sits outside the "rwkv." base-model namespace.
        if name != "head.weight":
            name = "rwkv." + name

        state_dict[name] = weight
    return state_dict


def revert_state_dict(state_dict):
    """Rename Hugging Face ``RwkvForCausalLM`` keys back to the original RWKV layout.

    This mirrors ``convert_state_dict``. The renames applied to each key, in
    order, are:

    * a leading ``rwkv.`` prefix is dropped (``head.weight`` has none)
    * ``embeddings.*`` -> ``emb.*``
    * ``blocks.0.pre_ln*`` -> ``blocks.0.ln0*`` (the extra LayerNorm handled
      here only on block 0)
    * ``blocks.N.attention.*`` -> ``blocks.N.att.*``
    * ``blocks.N.feed_forward.*`` -> ``blocks.N.ffn.*``
    * a trailing ``.time_mix_key`` / ``.time_mix_value`` /
      ``.time_mix_receptance`` becomes ``.time_mix_k`` / ``.time_mix_v`` /
      ``.time_mix_r``

    Only key names change; the stored values are passed through untouched
    (no reshaping or dtype conversion happens here).

    Args:
        state_dict: Mapping from parameter name to tensor, e.g. the result of
            ``model.state_dict()``. It is mutated in place: every key is popped
            and re-inserted under its new name.

    Returns:
        The same ``state_dict`` object, with renamed keys in the original order.
    """
    # (Hugging Face suffix, original suffix); at most one can match a key.
    suffix_renames = (
        (".time_mix_key", ".time_mix_k"),
        (".time_mix_value", ".time_mix_v"),
        (".time_mix_receptance", ".time_mix_r"),
    )
    # Snapshot the keys first: the dict is modified while we walk them.
    for hf_name in list(state_dict):
        weight = state_dict.pop(hf_name)
        # Strip the base-model namespace ("head.weight" is left as-is).
        name = hf_name.removeprefix("rwkv.")

        # embeddings.* -> emb.*
        if name.startswith("embeddings."):
            name = "emb." + name[len("embeddings."):]
        # pre_ln -> ln0 (only present at block 0)
        if name.startswith("blocks.0.pre_ln"):
            name = "blocks.0.ln0" + name[len("blocks.0.pre_ln"):]
        # attention -> att, feed_forward -> ffn. The trailing "\." restricts the
        # match to the whole path segment.
        name = re.sub(r"blocks\.(\d+)\.attention\.", r"blocks.\1.att.", name)
        name = re.sub(r"blocks\.(\d+)\.feed_forward\.", r"blocks.\1.ffn.", name)
        # time_mix_{key,value,receptance} -> time_mix_{k,v,r} (rename only).
        for hf_suffix, short_suffix in suffix_renames:
            if name.endswith(hf_suffix):
                name = name[: -len(hf_suffix)] + short_suffix
                break

        state_dict[name] = weight
    return state_dict


if __name__ == "__main__":
    # Export a Hugging Face RWKV checkpoint as a bf16 .pth state dict whose
    # keys follow the original RWKV naming (via revert_state_dict), for use
    # with RWKV-Runner.
    import argparse

    parser = argparse.ArgumentParser(
        description="Export a Hugging Face RWKV model to an RWKV-Runner .pth file."
    )
    parser.add_argument(
        "repo",
        nargs="?",
        default="beomi/KoAlpaca-KoRWKV-6B",
        help='Hub repo id or local path to load (e.g. "beomi/KoRWKV-6B"); '
        'default: "beomi/KoAlpaca-KoRWKV-6B".',
    )
    parser.add_argument(
        "-o",
        "--output",
        default=None,
        help='Output file path; default: "<repo basename>.bf16.pth".',
    )
    args = parser.parse_args()

    model = RwkvForCausalLM.from_pretrained(args.repo, torch_dtype=torch.bfloat16)

    state_dict = model.state_dict()
    converted = revert_state_dict(state_dict)
    # rstrip("/") so a local directory given with a trailing slash still
    # yields a non-empty basename.
    name = args.output or (args.repo.rstrip("/").split("/")[-1] + ".bf16.pth")

    torch.save(converted, name)