8bit sharded open_llama

open_llama_13b-sharded-8bit

<a href="https://colab.research.google.com/gist/pszemraj/166ad661c6af1e024d4e2897621fc886/open_llama_13b-sharded-8bit-example.ipynb"> <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> </a>

This is open_llama_13b, split into 2 GB shards and stored in 8-bit precision using bitsandbytes==0.38.0. Please refer to the original model card for details.
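As a rough sanity check on what 8-bit storage and 2 GB shards mean for this checkpoint, the arithmetic below estimates the weight size and a lower bound on the shard count. It is a back-of-the-envelope sketch under stated assumptions, not the actual repository layout. It assumes a nominal 13e9 parameters (taken from the "13b" name) at 1 byte each in int8 versus 2 bytes each in fp16. It ignores modules that stay in fp16 (for example the embeddings and `lm_head`) and the quantization metadata, so real files are somewhat larger. The helper name `weights_gb` is hypothetical.

```python
import math

# Assumption: nominal parameter count implied by the "13b" in the model name
N_PARAMS = 13e9
# Shard size stated on the model card (decimal GB, i.e. 10**9 bytes)
SHARD_GB = 2


def weights_gb(n_params, bytes_per_param):
    """Hypothetical helper: raw weight size in decimal GB for a given storage width."""
    return n_params * bytes_per_param / 1e9


int8_gb = weights_gb(N_PARAMS, 1)  # 8-bit: 1 byte per weight
fp16_gb = weights_gb(N_PARAMS, 2)  # half precision: 2 bytes per weight
# Lower bound only: fp16 modules and quantization stats add extra bytes
min_shards = math.ceil(int8_gb / SHARD_GB)

print(f"int8 ~{int8_gb:.0f} GB, fp16 ~{fp16_gb:.0f} GB, at least {min_shards} shards")
```

Under these assumptions the 8-bit weights take roughly half the space of an fp16 copy, which is the main reason to use this checkpoint on memory-limited GPUs.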

loading

```bash
pip install -U -q sentencepiece transformers accelerate bitsandbytes
```

Load the model and tokenizer:

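```python
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM

model_name = "ethzanalytics/open_llama_13b-sharded-8bit"
# slow SentencePiece-based tokenizer (requires the sentencepiece package)
tokenizer = LlamaTokenizer.from_pretrained(model_name, use_fast=False)
# load the 8-bit weights via bitsandbytes; accelerate places layers on the available devices
# (newer transformers versions prefer quantization_config=BitsAndBytesConfig(load_in_8bit=True))
model = LlamaForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True,
    device_map="auto",
)
```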
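Once the model and tokenizer are loaded, text can be generated with the standard `generate()` API. The helper below is a minimal sketch and is not part of this repository. The name `generate_text` and the default `max_new_tokens=64` are illustrative assumptions. It moves the tokenized prompt to `model.device`, which with `device_map="auto"` is the device holding the model's first parameters, and it decodes the full sequence, prompt included.

```python
def generate_text(model, tokenizer, prompt, max_new_tokens=64):
    """Hypothetical helper: tokenize `prompt`, generate a continuation, decode it."""
    # Tokenize to PyTorch tensors and move them to the device of the model's first layers
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # generate() runs without gradient tracking; max_new_tokens caps the continuation length
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Decode the first (only) sequence; the returned text includes the prompt
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


# Usage with the `model` and `tokenizer` loaded above:
# print(generate_text(model, tokenizer, "Q: What is the largest animal?\nA:"))
```

Greedy decoding is used by default. Sampling parameters such as `do_sample=True` or `temperature` could be passed through `generate()` if more varied output is wanted.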