<i> Note: Documentation is under development </i>
Introduction
Madina (Medical Dialog Intelligent Assistant) is a Large Language Model designed specifically for medical purposes and tailored to Bahasa Indonesia. We are currently open-sourcing models with 7 billion and 13 billion parameters, based on meta-llama/Llama-2-7b-chat-hf
and meta-llama/Llama-2-13b-chat-hf.
Both utilize RoPE scaling to acquire an 8k context length, and are fine-tuned with QLoRA (QLoRA: Efficient Finetuning of Quantized LLMs), so they can run on a single consumer GPU with 24 GB of VRAM or less.
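The 8k window comes from applying RoPE scaling on top of Llama 2's native 4k positions. A minimal sketch of enabling this at load time with HuggingFace Transformers follows; the exact scaling type and factor used for Madina are not stated in this repository, so `linear` and `2.0` (4k × 2 = 8k) are assumptions:

```python
from transformers import LlamaForCausalLM

# Hypothetical settings: linear scaling with factor 2.0 doubles
# Llama 2's 4096 positions to 8192. The repo does not document
# the actual values used for Madina.
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    rope_scaling={"type": "linear", "factor": 2.0},
)
```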
Dataset
We collected a dataset of more than 400k publicly available medical question-and-answer records from the internet. The dataset is entirely in Bahasa Indonesia and can be conveniently accessed through this link.
| Dataset | Download Link | Status |
|---|---|---|
| Madina LlaMA 2 400k | 🤗 firqaaa/madina-llama2-400k | Available (Private) |
| Madina LlaMA 2 15k | 🤗 firqaaa/madina-llama2-15k | Coming Soon (Public) |
If you intend to replicate the model, note that the full dataset is large enough to make training on consumer GPUs inefficient, so we decided to scale it down to 15k records.
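The repository does not document how the 15k subset was drawn from the 400k set; a seeded random sample is one plausible approach. The helper below is a hypothetical sketch for reproducing such a downsampling step, not the script actually used:

```python
import random

def downsample(records, k=15_000, seed=42):
    """Draw a reproducible random subset of the dataset.

    Hypothetical helper: the repo does not specify how the 15k
    split was selected from the 400k records.
    """
    rng = random.Random(seed)
    return rng.sample(records, k) if len(records) > k else list(records)
```

Because the generator is seeded, repeated calls return the same subset, which keeps fine-tuning runs comparable.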
If you're interested in fine-tuning with your own data, it's essential to adhere to the default prompt format that LlaMA 2 used during its pre-training phase. The prompt is structured like this:
```
<s>[INST] <<SYS>>
{{ system_prompt }}
<</SYS>>

{{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s>
```
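In code, the template above can be produced with a small string helper; `build_prompt` is a hypothetical name, and the blank line after `<</SYS>>` follows Llama 2's published chat template:

```python
def build_prompt(system_prompt: str, user_msg: str) -> str:
    """Format one user turn in the Llama 2 chat template."""
    return (
        f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
        f"{user_msg} [/INST]"
    )

print(build_prompt("Anda adalah seorang dokter.", "Apa gejala flu?"))
```

The model's answer is then appended after `[/INST]` followed by `</s>` when assembling a training record.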
Model Zoo
Madina 7B Chat and Madina 13B Chat, both trained from LlaMA 2, have been made publicly available. We offer model weights in a format suitable for use with HuggingFace Transformers.
| Model | Transformers Weight Download Link | Status |
|---|---|---|
| Madina Chat 7B 8k | 🤗 firqaaa/Madina-7b-chat-QLoRA-8k | Coming Soon |
| Madina Chat 13B 8k | 🤗 firqaaa/Madina-13b-chat-QLoRA-8k | Coming Soon |
Training procedure
The following bitsandbytes
quantization config was used during training:
- load_in_8bit: False
- load_in_4bit: True
- llm_int8_threshold: 6.0
- llm_int8_skip_modules: None
- llm_int8_enable_fp32_cpu_offload: False
- llm_int8_has_fp16_weight: False
- bnb_4bit_quant_type: nf4
- bnb_4bit_use_double_quant: False
- bnb_4bit_compute_dtype: bfloat16
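With recent versions of transformers, the same settings can be expressed as a `BitsAndBytesConfig` passed to `from_pretrained`. This is a sketch mirroring the list above, not a verbatim excerpt from the training script:

```python
import torch
from transformers import BitsAndBytesConfig

# Mirrors the training-time quantization settings listed above:
# 4-bit NF4 quantization, no double quantization, bfloat16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# Pass as quantization_config=bnb_config to from_pretrained
# when loading the base model.
```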
Usage
```python
import sys
import re
import time

import torch
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig, TextStreamer
from peft import PeftModel

peft_model_id = "firqaaa/Madina-7B-QLoRA-8k"

# Load the base model in 4-bit and attach the QLoRA adapter
tokenizer = LlamaTokenizer.from_pretrained("Ichsan2895/Merak-7B-v1")
model = LlamaForCausalLM.from_pretrained("Ichsan2895/Merak-7B-v1",
                                         load_in_4bit=True,
                                         device_map="auto")
model = PeftModel.from_pretrained(model, peft_model_id)

system_prompt = "Jika Anda seorang dokter, tolong jawab pertanyaan medis berdasarkan deskripsi pasien. "
user_msg = "Selamat pagi dok, gejala covid-19 apa saja ya? Dan bagaimana cara mencegahnya?"

# Wrap the question in the Llama 2 chat template
text = f"""<s>[INST] <<SYS>>
{system_prompt}
<</SYS>>

{user_msg} [/INST]"""

def end_of_sentences(input_string):
    """Truncate the response after a common closing phrase."""
    phrases = ["Semoga bermanfaat.", "semoga bermanfaat.",
               "Salam.", "salam.",
               "Semoga membantu.", "semoga membantu."]
    for phrase in phrases:
        input_string = re.sub(rf'{re.escape(phrase)}.*', phrase, input_string)
    return input_string.strip()

def postproc(text):
    """Rename brand mentions and drop sentences that reference articles."""
    pattern_1 = r'\b(?:alodokter|Alodokter)\b'
    pattern_2 = r'[^.]*\b(?:artikel)\b[^.]*\.'
    modified_text = re.sub(pattern_1, 'Madina', text)
    cleaned_text = re.sub(pattern_2, '', modified_text, flags=re.IGNORECASE)
    return cleaned_text

def text_streamer(sentence, delay=0.4):
    """Print the cleaned response word by word with a short delay."""
    sentence = end_of_sentences(postproc(sentence))
    for word in sentence.split():
        sys.stdout.write(' ' + word)
        sys.stdout.flush()
        time.sleep(delay)
    sys.stdout.write('\n')

# streamer = TextStreamer(tokenizer)
inputs = tokenizer(text, return_tensors="pt")
input_ids = inputs["input_ids"].cuda()

generation_config = GenerationConfig(
    temperature=0.5,
    top_p=0.95,
    top_k=4,
    num_beams=1,
    repetition_penalty=1.15,
)

with torch.no_grad():
    generation_output = model.generate(
        input_ids=input_ids,
        # streamer=streamer,
        generation_config=generation_config,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=1024,
    )

print("Generating ...\n")
print(f"Patient : {user_msg}\n")
print("Response : ")

for s in generation_output.sequences:
    decoded = tokenizer.decode(s)
    # Keep only the model answer after [/INST]; fall back to the
    # whole text if the delimiter is missing.
    parts = decoded.split("[/INST]")
    text = (parts[1] if len(parts) > 1 else parts[0]).replace('<s>', '')

text_streamer(text)
```
Framework versions
- PEFT 0.4.0
Authors
This repository was created with dedication and love by the following engineer(s):
Citation
Please cite this repository if you use its data or code.
```
@misc{madina,
  author = {Arasyi Firqa},
  title = {Madina: Medical Instruction-Following in Bahasa Indonesia using LlaMA 2},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/firqaaa/Madina}},
}
```