import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
import gradio as gr
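
# Assumed dependencies for this Space (an added note; pin versions as needed):
#   torch, transformers, peft, gradio, sentencepiece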
# =========================
# CONFIG
# =========================
# Base NLLB model
BASE_MODEL = "facebook/nllb-200-distilled-600M"
# Your LoRA repo on HF Hub
# 👉 CHANGE THIS to your actual repo if different
LORA_REPO = "flt7007/nllb-mizo-bible-lora"
# e.g. "frankiethiak/nllb-mizo-bible-lora"
# NLLB language codes
SRC_LANG = "eng_Latn" # English
TGT_LANG = "lus_Latn" # Mizo (Lushai / Mizo)
# =========================
# LOAD TOKENIZER + MODEL
# =========================
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
print("Using device:", device)
# 🔴 IMPORTANT:
# Load the tokenizer from the LoRA repo, not the base model.
# The fine-tuned repo carries the tokenizer state the adapter was trained
# with; loading the base tokenizer instead can produce mojibake (garbled
# characters) in the decoded output.
tokenizer = AutoTokenizer.from_pretrained(
    LORA_REPO,
    src_lang=SRC_LANG,
)
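
# Sanity check (an added guard, not in the original commit): NLLB language
# codes are part of the tokenizer vocabulary, so an <unk> id here would mean
# the wrong tokenizer was loaded and the mojibake issue above could return.
if tokenizer.convert_tokens_to_ids(TGT_LANG) == tokenizer.unk_token_id:
    print(f"WARNING: tokenizer from {LORA_REPO} does not recognise {TGT_LANG}")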
# Load base NLLB model
base_model = AutoModelForSeq2SeqLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=dtype,
)
# Attach LoRA
model = PeftModel.from_pretrained(
    base_model,
    LORA_REPO,
)
model.to(device)
model.eval()
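
# Optional speed-up (a sketch beyond the original app): PeftModel.merge_and_unload()
# folds the LoRA weights into the base model, removing the adapter indirection
# at inference time. Gated behind a flag so the default path is unchanged.
MERGE_LORA = False
if MERGE_LORA:
    model = model.merge_and_unload()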
# Resolve the forced BOS token id for Mizo. Older transformers versions
# expose lang_code_to_id on the NLLB tokenizer; newer ones dropped it, in
# which case the language code can be looked up as an ordinary token.
forced_bos_token_id = None
if hasattr(tokenizer, "lang_code_to_id"):
    forced_bos_token_id = tokenizer.lang_code_to_id.get(TGT_LANG)
else:
    candidate = tokenizer.convert_tokens_to_ids(TGT_LANG)
    if candidate != tokenizer.unk_token_id:
        forced_bos_token_id = candidate
print("forced_bos_token_id:", forced_bos_token_id)
if forced_bos_token_id is None:
    print("No forced BOS id found; generation may not target Mizo reliably.")
# =========================
# TRANSLATION FUNCTION
# =========================
def translate_en_to_mizo(text, max_new_tokens, num_beams):
    text = text.strip()
    if not text:
        return ""
    inputs = tokenizer(text, return_tensors="pt").to(device)
    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "num_beams": int(num_beams),
    }
    # Only pass forced_bos_token_id if we actually have it
    if forced_bos_token_id is not None:
        gen_kwargs["forced_bos_token_id"] = forced_bos_token_id
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
    decoded = tokenizer.batch_decode(
        outputs,
        skip_special_tokens=True,
    )[0]
    return decoded.strip()
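
# Batched variant (a sketch beyond the original app): translating several
# sentences in one generate() call amortises the beam-search overhead.
# Reuses the same tokenizer, model, and forced-BOS handling as above.
def translate_batch(texts, max_new_tokens=80, num_beams=4):
    texts = [t.strip() for t in texts if t.strip()]
    if not texts:
        return []
    inputs = tokenizer(texts, return_tensors="pt", padding=True).to(device)
    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "num_beams": int(num_beams),
    }
    if forced_bos_token_id is not None:
        gen_kwargs["forced_bos_token_id"] = forced_bos_token_id
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
    return [t.strip() for t in tokenizer.batch_decode(outputs, skip_special_tokens=True)]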
# =========================
# GRADIO UI
# =========================
TITLE = "English → Mizo (NLLB-200 + Bible+Dict LoRA)"
DESC = """
Low-resource MT demo for **English → Mizo** using:
- Base model: `facebook/nllb-200-distilled-600M`
- LoRA: fine-tuned on dictionary + Bible parallel data
Model is more Bible/education style and still in-progress.
"""
with gr.Blocks() as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESC)
    with gr.Row():
        with gr.Column():
            en_input = gr.Textbox(
                label="English input",
                lines=4,
                placeholder="Type an English sentence here…",
            )
            max_new_tokens = gr.Slider(
                minimum=10,
                maximum=200,
                value=80,
                step=5,
                label="Max new tokens",
            )
            num_beams = gr.Slider(
                minimum=1,
                maximum=8,
                value=4,
                step=1,
                label="Beam size",
            )
            translate_btn = gr.Button("Translate → Mizo")
        with gr.Column():
            mz_output = gr.Textbox(
                label="Mizo output",
                lines=6,
            )
    translate_btn.click(
        fn=translate_en_to_mizo,
        inputs=[en_input, max_new_tokens, num_beams],
        outputs=mz_output,
    )
demo.queue()
if __name__ == "__main__":
    demo.launch()