
<i> Note: Documentation is under development </i>

Introduction

Madina (Medical Dialog Intelligent Assistant) is a large language model built specifically for medical use in Bahasa Indonesia. We are open-sourcing 7-billion- and 13-billion-parameter models based on meta-llama/Llama-2-7b-chat-hf and meta-llama/Llama-2-13b-chat-hf. They use RoPE scaling to extend the context length to 8k tokens. The models are fine-tuned with QLoRA (QLoRA: Efficient Finetuning of Quantized LLMs), so they can run on a single consumer GPU with 24 GB of VRAM or less.
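Llama 2 was pre-trained with a 4,096-token context, so reaching 8k tokens means stretching its rotary position embeddings (RoPE). The exact RoPE scaling variant used for Madina is not stated here. The following is a minimal pure-Python sketch that *assumes* linear scaling (position interpolation) with a factor of 2. Under that assumption, positions are divided by the factor before the rotary angles are computed, so every position in the 8k window maps back into the range the base model saw during pre-training. The helper `rope_angles` is hypothetical. In Hugging Face Transformers, the corresponding model setting is the `rope_scaling` config entry, for example `{"type": "linear", "factor": 2.0}`.

```python
def rope_angles(position, dim=128, base=10000.0, scaling_factor=1.0):
    """Rotary angles (radians) for each channel pair at `position`.

    Linear RoPE scaling (position interpolation) divides the position by
    `scaling_factor` before the usual RoPE frequencies are applied.
    dim=128 is the attention head size of Llama 2 7B and 13B.
    """
    scaled_pos = position / scaling_factor
    return [scaled_pos * base ** (-2 * i / dim) for i in range(dim // 2)]

native_ctx = 4096                 # Llama 2 pre-training context length
target_ctx = 8192                 # extended context length (8k)
factor = target_ctx / native_ctx  # 2.0 under the linear-scaling assumption

# With factor 2, position 8190 in the extended window receives exactly the
# angles that position 4095 received during pre-training.
print(rope_angles(8190, scaling_factor=factor) == rope_angles(native_ctx - 1))  # True
```

The trade-off of linear interpolation is that neighbouring positions become closer in angle. This is why models extended this way are usually fine-tuned briefly at the longer length.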

Dataset

We collected a dataset of more than 400k publicly available medical question-and-answer records from the internet. The dataset is entirely in Bahasa Indonesia and can be accessed through the links below.

| Dataset | Download Link | Status |
|---|---|---|
| Madina Llama 2 400k | 🤗 firqaaa/madina-llama2-400k | Available (Private) |
| Madina Llama 2 15k | 🤗 firqaaa/madina-llama2-15k | Coming Soon (Public) |

If you intend to replicate the model, note that the full dataset is large enough to make training on consumer GPUs inefficient. We have therefore scaled it down to 15k records.
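How the 15k records were selected is not described. If you want to build a comparable subset from your own data, one simple option is a reproducible uniform random sample. The sketch below uses that approach, but it is only an *assumption* and may not be the authors' method. The helper `downsample`, its seed, and the placeholder records are all hypothetical.

```python
import random

def downsample(records, k=15_000, seed=42):
    """Hypothetical helper: reproducible uniform random subset of `records`."""
    if k >= len(records):
        return list(records)
    rng = random.Random(seed)  # fixed seed so the subset can be recreated
    return rng.sample(records, k)

# Placeholder records standing in for the 400k question-answer pairs
full_dataset = [{"id": i} for i in range(400_000)]
subset = downsample(full_dataset)
print(len(subset))  # 15000
```

Uniform sampling keeps the topic mix of the full dataset roughly intact. Stratifying by medical specialty would be an alternative if the records carried such a label.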

If you want to fine-tune on your own data, follow the prompt format that was used to fine-tune the Llama 2 chat models. A single-turn example is structured like this:

<s>[INST] <<SYS>>
{{ system_prompt }}
<</SYS>>

{{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s>
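The template above can be filled programmatically. Below is a minimal sketch that formats one single-turn record. The helper name `build_llama2_prompt` is hypothetical. When `model_answer` is given, the result is a complete training example ending in `</s>`. Without it, the string stops at `[/INST]` so that the model generates the answer.

```python
from typing import Optional

def build_llama2_prompt(system_prompt: str, user_msg: str,
                        model_answer: Optional[str] = None) -> str:
    """Format one single-turn example in the Llama 2 chat template."""
    prompt = (
        f"<s>[INST] <<SYS>>\n{system_prompt.strip()}\n<</SYS>>\n\n"
        f"{user_msg.strip()} [/INST]"
    )
    if model_answer is not None:
        # Training example: append the answer and the end-of-sequence marker
        prompt += f" {model_answer.strip()} </s>"
    return prompt

example = build_llama2_prompt(
    "Jika Anda seorang dokter, tolong jawab pertanyaan medis berdasarkan deskripsi pasien.",
    "Selamat pagi dok, gejala covid-19 apa saja ya?",
)
print(example)
```

The literal `<s>` mirrors the template. If your tokenizer already prepends a BOS token, as the Llama tokenizer does by default, remove it from the string to avoid a duplicated BOS.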

Model Zoo

Madina 7B Chat and Madina 13B Chat, fine-tuned from Llama 2, are being released publicly. The weights are provided in a format that can be used with Hugging Face Transformers.

| Model | Transformers Weight Download Link | Status |
|---|---|---|
| Madina Chat 7B 8k | 🤗 firqaaa/Madina-7b-chat-QLoRA-8k | Coming Soon |
| Madina Chat 13B 8k | 🤗 firqaaa/Madina-13b-chat-QLoRA-8k | Coming Soon |
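A back-of-the-envelope estimate shows why 4-bit loading matters for the 24 GB target mentioned in the introduction. The sketch below counts only the memory for the base weights, using nominal parameter counts (the real models differ slightly in size). Activations, the KV cache (which grows with the 8k context), adapter weights, and quantization constants all add to this. Treat the numbers as a lower bound, not a measurement. The helper `weight_memory_gb` is hypothetical.

```python
def weight_memory_gb(n_params, bits_per_param):
    """Approximate memory for the model weights alone, in GB (10**9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9

for name, n_params in [("7B", 7e9), ("13B", 13e9)]:
    sixteen_bit = weight_memory_gb(n_params, 16)
    four_bit = weight_memory_gb(n_params, 4)
    print(f"{name}: ~{sixteen_bit:.1f} GB in 16-bit, ~{four_bit:.1f} GB in 4-bit")
```

This prints about 14.0 GB vs 3.5 GB for 7B and 26.0 GB vs 6.5 GB for 13B. The 13B weights alone exceed 24 GB in 16-bit precision, but they fit comfortably once quantized to 4 bits.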

Training procedure

The following bitsandbytes quantization config was used during training:

Usage

import sys
import re
import time

from transformers import (BitsAndBytesConfig, LlamaTokenizer, LlamaForCausalLM,
                          GenerationConfig, TextStreamer)
from peft import PeftModel

# QLoRA adapter weights on the Hugging Face Hub
peft_model_id = "firqaaa/Madina-7B-QLoRA-8k"

# Base model that the adapter is applied to
tokenizer = LlamaTokenizer.from_pretrained("Ichsan2895/Merak-7B-v1")
model = LlamaForCausalLM.from_pretrained(
    "Ichsan2895/Merak-7B-v1",
    # Load the base weights in 4-bit via bitsandbytes (requires a CUDA GPU)
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
)
# Attach the QLoRA adapter on top of the quantized base model
model = PeftModel.from_pretrained(model, peft_model_id)
model.eval()

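# System prompt. English: "If you are a doctor, please answer the medical
# question based on the patient's description."
system_prompt = "Jika Anda seorang dokter, tolong jawab pertanyaan medis berdasarkan deskripsi pasien."
# User message. English: "Good morning, doctor. What are the symptoms of
# COVID-19? And how can it be prevented?"
user_msg = "Selamat pagi dok, gejala covid-19 apa saja ya? Dan bagaimana cara mencegahnya?"

# Llama 2 chat prompt. The leading <s> is omitted because the tokenizer
# prepends the BOS token itself.
text = f"""[INST] <<SYS>>
{system_prompt}
<</SYS>>

{user_msg} [/INST]"""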

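def end_of_sentences(input_string):
    # Truncate the answer after a closing phrase: "Semoga bermanfaat." ("Hope this
    # is useful."), "Salam." ("Regards.") or "Semoga membantu." ("Hope this helps.")
    phrases = ["Semoga bermanfaat.",
               "semoga bermanfaat.",
               "Salam.", "salam.",
               "Semoga membantu.",
               "semoga membantu."]
    for phrase in phrases:
        # re.escape matches "." literally; DOTALL lets ".*" also drop later lines
        input_string = re.sub(re.escape(phrase) + r".*", phrase, input_string,
                              flags=re.DOTALL)
    return input_string.strip()

def postproc(text):
    # Replace mentions of the source site "Alodokter" with "Madina"
    pattern_1 = r'\b(?:alodokter|Alodokter)\b'
    # Remove whole sentences that mention "artikel" ("article")
    pattern_2 = r'[^.]*\b(?:artikel)\b[^.]*\.'
    modified_text = re.sub(pattern_1, 'Madina', text)
    cleaned_text = re.sub(pattern_2, '', modified_text, flags=re.IGNORECASE)
    return cleaned_text

def text_streamer(sentence, delay=0.4):
    # Clean the answer, then print it word by word to mimic streaming output
    sentence = postproc(sentence)
    sentence = end_of_sentences(sentence)
    for word in sentence.split():
        sys.stdout.write(' ' + word)
        sys.stdout.flush()
        time.sleep(delay)
    sys.stdout.write('\n')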

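# Optional: stream raw tokens as they are generated instead of using text_streamer
# streamer = TextStreamer(tokenizer, skip_prompt=True)

# Tokenize the prompt and move it to the device holding the model
inputs = tokenizer(text, return_tensors="pt").to(model.device)
input_ids = inputs["input_ids"]

generation_config = GenerationConfig(
    do_sample=True,  # required for temperature/top_p/top_k to take effect
    temperature=0.5,
    top_p=0.95,
    top_k=4,
    num_beams=1,
    repetition_penalty=1.15,
)

generation_output = model.generate(
    input_ids=input_ids,
    attention_mask=inputs["attention_mask"],
    # streamer=streamer,
    generation_config=generation_config,
    return_dict_in_generate=True,
    max_new_tokens=1024,
)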

print("Generating ...\n")
print(f"Patient : {user_msg}\n")
print("Response : ")
# Decode only the newly generated tokens (everything after the prompt),
# dropping special tokens such as <s> and </s>
new_tokens = generation_output.sequences[0][input_ids.shape[1]:]
text = tokenizer.decode(new_tokens, skip_special_tokens=True)
text_streamer(text)
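The generation settings used above (temperature 0.5, top_k 4, top_p 0.95) strongly narrow which tokens can be sampled. In Hugging Face Transformers they only take effect when `do_sample=True`. Below is a toy pure-Python sketch of the same idea: temperature scaling, then top-k, then top-p (nucleus) filtering on a single made-up vector of logits. This is a simplified illustration, not the library's implementation. `repetition_penalty` is not modeled, and the function name `sampling_distribution` is hypothetical.

```python
import math

def sampling_distribution(logits, temperature=0.5, top_k=4, top_p=0.95):
    """Toy next-token distribution after temperature, top-k and top-p filtering.

    Returns {token_index: probability} for the tokens that remain eligible.
    """
    # 1. Temperature: dividing by a value < 1 sharpens the distribution
    scaled = [x / temperature for x in logits]
    # 2. Top-k: keep the k highest-scoring tokens, sorted best first
    ranked = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:top_k]
    # 3. Softmax over the survivors (subtract the max for numerical stability)
    m = max(scaled[i] for i in ranked)
    exps = {i: math.exp(scaled[i] - m) for i in ranked}
    total = sum(exps.values())
    probs = {i: e / total for i, e in exps.items()}
    # 4. Top-p: keep the smallest prefix whose cumulative probability reaches top_p
    kept, cumulative = {}, 0.0
    for i in ranked:
        kept[i] = probs[i]
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    # 5. Renormalize so the remaining probabilities sum to 1
    norm = sum(kept.values())
    return {i: p / norm for i, p in kept.items()}

dist = sampling_distribution([2.0, 1.0, 0.5, 0.0, -1.0, -3.0])
print(sorted(dist))  # [0, 1, 2]
```

With these settings, at most four tokens are ever eligible, and the low temperature concentrates most of the probability on the top one. This keeps answers conservative at the cost of some diversity.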

Framework versions

Authors

This repository was created with great dedication and care by the following engineer(s):

Citation

Please cite this repository if you use its data or code.

@misc{madina,
  author = {Arasyi Firqa},
  title = {Madina : Medical Instruction-Following in Bahasa Indonesia using LlaMA 2},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/firqaaa/Madina}},
}