|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
from peft import PeftModel |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Configuration: base checkpoint, LoRA adapter repo, and NLLB language codes.
# ---------------------------------------------------------------------------
BASE_MODEL = "facebook/nllb-200-distilled-600M"
LORA_REPO = "flt7007/nllb-mizo-bible-lora"

# FLORES-200 language codes used by NLLB: English and Mizo (Lushai),
# both written in Latin script.
SRC_LANG = "eng_Latn"
TGT_LANG = "lus_Latn"

# Select the runtime device; half precision is only worthwhile on GPU.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32

print("Using device:", device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the tokenizer from the LoRA repo (it carries the tokenizer config
# saved at fine-tuning time), pinning the source language so inputs are
# tagged as English.
tokenizer = AutoTokenizer.from_pretrained(
    LORA_REPO,
    src_lang=SRC_LANG,
)

# Load the frozen NLLB base model, then attach the LoRA adapter weights
# on top of it.
base_model = AutoModelForSeq2SeqLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=dtype,
)

model = PeftModel.from_pretrained(
    base_model,
    LORA_REPO,
)

model.to(device)
model.eval()

# Resolve the token id that forces generation to start in the target
# language. `lang_code_to_id` was removed from the NLLB tokenizer in newer
# transformers releases, so fall back to `convert_tokens_to_ids`, which
# works across versions (it returns the unk id for an unknown code, which
# we treat as "not found").
forced_bos_token_id = None
if hasattr(tokenizer, "lang_code_to_id"):
    forced_bos_token_id = tokenizer.lang_code_to_id.get(TGT_LANG, None)
else:
    candidate = tokenizer.convert_tokens_to_ids(TGT_LANG)
    unk_id = getattr(tokenizer, "unk_token_id", None)
    if candidate is not None and candidate != unk_id:
        forced_bos_token_id = candidate

if forced_bos_token_id is not None:
    print("forced_bos_token_id:", forced_bos_token_id)
else:
    print("Tokenizer has no lang_code_to_id; continuing without forced BOS.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def translate_en_to_mizo(text, max_new_tokens, num_beams):
    """Translate an English string to Mizo with the LoRA-adapted NLLB model.

    Args:
        text: English source text; surrounding whitespace is ignored.
        max_new_tokens: Generation-length budget (coerced to int — Gradio
            sliders deliver floats).
        num_beams: Beam-search width (coerced to int for the same reason).

    Returns:
        The decoded Mizo translation, stripped, or "" for empty input.
    """
    text = text.strip()
    if not text:
        return ""

    # Truncate overly long inputs so a single huge paste cannot exceed the
    # model's maximum source length and crash generation.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
    ).to(device)

    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "num_beams": int(num_beams),
    }

    # Force decoding to begin with the Mizo language token when it was
    # resolved at startup; otherwise generate without the constraint.
    if forced_bos_token_id is not None:
        gen_kwargs["forced_bos_token_id"] = forced_bos_token_id

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    decoded = tokenizer.batch_decode(
        outputs,
        skip_special_tokens=True,
    )[0]

    return decoded.strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# UI copy: page title and markdown description rendered above the demo.
# NOTE: the blank lines inside DESC are part of the string literal.
TITLE = "English → Mizo (NLLB-200 + Bible+Dict LoRA)"


DESC = """


Low-resource MT demo for **English → Mizo** using:


- Base model: `facebook/nllb-200-distilled-600M`


- LoRA: fine-tuned on dictionary + Bible parallel data


Model is more Bible/education style and still in-progress.


"""
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: two-column layout — English input plus generation controls on
# the left, Mizo output on the right.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESC)

    with gr.Row():
        with gr.Column():
            # Source-text box for the English sentence to translate.
            en_input = gr.Textbox(
                label="English input",
                lines=4,
                placeholder="Type an English sentence here…"
            )
            # Generation-length budget forwarded to model.generate().
            max_new_tokens = gr.Slider(
                minimum=10,
                maximum=200,
                value=80,
                step=5,
                label="Max new tokens"
            )
            # Beam-search width; 1 means greedy decoding.
            num_beams = gr.Slider(
                minimum=1,
                maximum=8,
                value=4,
                step=1,
                label="Beam size"
            )
            translate_btn = gr.Button("Translate → Mizo")

        with gr.Column():
            # Read-only result box for the translation.
            mz_output = gr.Textbox(
                label="Mizo output",
                lines=6
            )

    # Wire the button click to the translation function.
    translate_btn.click(
        fn=translate_en_to_mizo,
        inputs=[en_input, max_new_tokens, num_beams],
        outputs=mz_output
    )

# Enable request queuing so concurrent users are serialized instead of
# contending for the model.
demo.queue()

# Start the web server only when run as a script (not when imported,
# e.g. by a Spaces runner that launches `demo` itself).
if __name__ == "__main__":
    demo.launch()
|
|
|