# beltout/run.py
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import librosa
import time
import os
from beltout import BeltOutTTM
import torchaudio
import torchcrepe
import soundfile
import questionary
def select_file_from_menu(folder_path: str, prefix: str) -> str | None:
"""
Scans a folder for files with a specific prefix, displays an interactive
menu, and returns the user's selection.
Args:
folder_path: The path to the folder to search.
        prefix: The file prefix to match (e.g., 'cfm_').
Returns:
The full path of the selected file, or None if no file was selected
or no matching files were found.
"""
try:
# Find all files matching the prefix.
all_files = os.listdir(folder_path)
matching_files = [
f for f in all_files
if f.startswith(prefix) and os.path.isfile(os.path.join(folder_path, f))
]
if not matching_files:
print(f"No checkpoints found in '{folder_path}' with prefix '{prefix}'! Make sure to have at least one checkpoint downloaded for each model.")
return None
# Sort to find the "greatest name" and make it the default. `questionary` automatically places the cursor on the first item in the list.
# Holding Enter during selection will quickly choose the latest checkpoints available for every model.
def key(name):
try:
return int(os.path.splitext(name)[0].replace(prefix, "").split("_")[-1])
except ValueError:
return 0
matching_files.sort(reverse=True, key=key)
# Show the CLI UI and get the user's choice.
selected_file_name = questionary.select(
"Checkpoints:",
choices=matching_files,
use_indicator=True # Adds a nice '>' indicator.
).ask() # .ask() returns the selection or None if the user cancels (e.g., Ctrl+C)
if selected_file_name:
return os.path.join(folder_path, selected_file_name)
else:
# User cancelled the selection.
return None
except FileNotFoundError:
print(f"Error: The folder '{folder_path}' does not exist.")
return None
def select_checkpoint_from_menu(model_name: str) -> str:
print("Searching inside path './checkpoints' for available checkpoints...")
ckpt_path = select_file_from_menu("./checkpoints", model_name + "_")
    if ckpt_path is None:
        print(f"No checkpoints found for model '{model_name}'! Make sure at least one checkpoint is downloaded for each model.")
        raise SystemExit(1)
return ckpt_path
# --- Model and Checkpoint Loading ---
print("Loading model...")
device = 'cpu'
if torch.cuda.is_available():
device = 'cuda'
if torch.backends.mps.is_available():
device = 'mps'
# Load the pre-trained model
# try:
# model = BeltOutTTM.from_pretrained_hf(local_dir="./checkpoints", device=device)
# print(f"Model loaded on {device}.")
# except Exception as e:
# print(f"Could not load pretrained model from HF: {e}.")
try:
decoder_ckpt_path = select_checkpoint_from_menu("cfm")
pitchmvmt_ckpt_path = select_checkpoint_from_menu("pitchmvmt")
encoder_ckpt_path = select_checkpoint_from_menu("encoder")
flow_ckpt_path = select_checkpoint_from_menu("flow")
mel2wav_ckpt_path = select_checkpoint_from_menu("mel2wav")
speaker_encoder_ckpt_path = select_checkpoint_from_menu("speaker_encoder")
tokenizer_ckpt_path = select_checkpoint_from_menu("tokenizer")
model = BeltOutTTM.from_local(decoder_ckpt_path,
pitchmvmt_ckpt_path,
encoder_ckpt_path,
flow_ckpt_path,
mel2wav_ckpt_path,
speaker_encoder_ckpt_path,
tokenizer_ckpt_path, device=device)
# Set model to evaluation mode
model.eval()
print(f"Model loaded from local './checkpoints' directory on {device}.")
except Exception as e_local:
print(f"Could not load model from local directory: {e_local}")
exit()
# --- Utility Functions ---
def get_x_vector_from_wav_chunk(wav_chunk):
"""Helper to get a single x-vector from a raw audio chunk."""
ref_tensor = torch.from_numpy(wav_chunk).float().to(device).unsqueeze(0)
with torch.inference_mode():
return model.embed_ref_x_vector(ref_tensor, model.sr, device=device).detach().cpu().numpy().flatten()
def smart_split(wav, sr, min_chunk_duration_s=40.0, top_db=35):
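    """
    Split audio at silences (via librosa.effects.split), then merge the pieces so that
    each returned chunk lasts at least `min_chunk_duration_s` seconds; a very short
    trailing remainder is appended to the previous chunk instead of being returned on its own.
    """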
split_indices = librosa.effects.split(wav, top_db=top_db, frame_length=1024, hop_length=256)
if len(split_indices) < 1: return [wav]
min_chunk_samples = int(min_chunk_duration_s * sr)
merged_chunks = []
current_chunk_start = 0
for i in range(len(split_indices)):
next_chunk_start = split_indices[i+1][0] if i + 1 < len(split_indices) else len(wav)
if (next_chunk_start - current_chunk_start) >= min_chunk_samples and current_chunk_start != next_chunk_start:
merged_chunks.append(wav[current_chunk_start:next_chunk_start])
current_chunk_start = next_chunk_start
if current_chunk_start < len(wav):
last_piece = wav[current_chunk_start:]
if merged_chunks and len(last_piece) < min_chunk_samples / 2:
merged_chunks[-1] = np.concatenate([merged_chunks[-1], last_piece])
elif len(last_piece) > 0:
merged_chunks.append(last_piece)
return merged_chunks
def get_vector_stats(vector):
if vector is None: return "Vector Stats: N/A"
return f"Vector Norm: {np.linalg.norm(vector):.4f} | Mean: {np.mean(vector):.4f}"
# --- Gradio Backend Logic ---
def update_chunk_slider(audio_file):
"""
When a user uploads an audio file to a blender slot, this function
updates the corresponding chunk size slider to match the audio's duration.
"""
if audio_file is None:
# If audio is cleared, reset and hide the slider
return gr.update(value=0, maximum=120, visible=False)
try:
# gr.Audio(type="filepath") provides the path to a temporary file
duration = librosa.get_duration(path=audio_file)
# Make the slider visible and set its max value to the audio duration
return gr.update(maximum=int(duration), visible=True)
except Exception as e:
print(f"Could not get audio duration: {e}")
# If there's an error, just return a default state
return gr.update(value=0, maximum=120, visible=False)
def set_source_audio(state, source_audio):
if state.get("current_x_vector") is None:
gr.Warning("Please set a vector first!")
return state, None
if source_audio is None: return state, None
state["source_audio_path"] = source_audio
state["is_source_set"] = True
print("Source audio loaded. Ready for conversion.")
return state, (model.sr, np.zeros(1))
# --- Main Inference Function ---
def run_conversion(state, mode, progress=gr.Progress(track_tqdm=True)):
if not state.get("is_source_set") or "source_audio_path" not in state:
gr.Warning("Please upload a source audio file first.")
        yield (model.sr, np.zeros(1)), None
        return
source_path = state["source_audio_path"]
active_x_vector = state.get("current_x_vector")
if active_x_vector is None:
gr.Warning("No active x-vector. Please set or create one before running the conversion.")
        yield (model.sr, np.zeros(1)), None
        return
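    # run_conversion is a generator: yield once up front to reset both outputs
    # before the (potentially long) conversion starts.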
yield None, None
# --- HELPER FUNCTION FOR A SINGLE CHUNK ---
# This avoids code duplication between HQ and Streaming modes
def process_chunk(wav_chunk):
with torch.inference_mode():
# --- STEP 1: PREPARE ALL RAW INPUTS FIRST ---
waveform_24k_tensor = torch.from_numpy(wav_chunk).float().to(device).unsqueeze(0)
waveform_16k_tensor = torchaudio.transforms.Resample(model.sr, 16000).to(device)(waveform_24k_tensor)
            # --- STEP 2: GET S3 SPEECH TOKENS AND THE SPEAKER EMBEDDING ---
s3_tokens, _ = model.tokenizer(waveform_16k_tensor)
x_vector_tensor = torch.from_numpy(active_x_vector).float().to(device).unsqueeze(0)
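            # Project the 192-dim x-vector into the flow model's speaker-embedding space.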
speaker_embedding = model.flow.spk_embed_affine_layer(x_vector_tensor)
# --- STEP 3: PREPARE CONDITIONING SIGNALS TO MATCH THE TARGET MEL LENGTH ---
# 3a. Prepare token embeddings ('mu')
token_embeddings = model.flow.input_embedding(s3_tokens)
token_len = torch.tensor([token_embeddings.shape[1]], device=device)
h, _ = model.encoder(token_embeddings, token_len)
encoded_tokens = model.flow.encoder_proj(h)
mu = encoded_tokens.transpose(1, 2)
mel_len = mu.shape[2]
# 3b. Prepare pitch embeddings ('pitchmvmt')
pitch_mvmt_encode = None
crepe_sr = 16000
            crepe_frames_per_second = 100  # CREPE frame rate (10 ms hop); 2 CREPE frames cover each mel frame
crepe_hop_length = int(crepe_sr / float(crepe_frames_per_second)) # 10ms hop
n_crepe_frame_in_mel_frame = 2
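            # Each mel frame is covered by 2 CREPE frames at a 10 ms hop, so the 16 kHz
            # waveform must span mel_len * 2 * crepe_hop_length samples.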
crepe_samples_needed = mel_len * n_crepe_frame_in_mel_frame * crepe_hop_length
padded_waveform_16k = waveform_16k_tensor
pad_amount = crepe_samples_needed - padded_waveform_16k.shape[1]
if pad_amount > 0:
padded_waveform_16k = F.pad(torch.clone(padded_waveform_16k), (0, pad_amount))
crepe_embedding = torchcrepe.embed(
padded_waveform_16k,
crepe_sr,
hop_length=crepe_hop_length,
model='tiny',
device=device,
)
crepe_embedding = crepe_embedding[:, :mel_len*2, :, :]
projector_input = crepe_embedding.view(-1, n_crepe_frame_in_mel_frame, 256)
pitch_mvmt_encode_flat = model.pitchmvmt(projector_input)
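            # Reshape to (1, 80, mel_len) so the pitch conditioning matches the decoder's mel-channel layout.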
pitch_mvmt_encode = pitch_mvmt_encode_flat.view(1, -1, 80).transpose(1, 2)
# --- STEP 4: GENERATE THE MEL-SPECTROGRAM ---
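            # Every frame in the chunk is valid, so the mask is all ones.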
mask = torch.ones(1, 1, mu.shape[2], device=device, dtype=torch.bool)
output_mels, _ = model.decoder(
mu=mu, mask=mask, spks=speaker_embedding, cond=pitch_mvmt_encode, n_timesteps=10
)
# --- STEP 5: VOCODE ---
output_wav_tensor, _ = model.mel2wav.inference(speech_feat=output_mels)
return output_wav_tensor.squeeze(0).cpu().numpy()
# --- MODE SWITCH LOGIC ---
if mode == "⭐ High Quality (Single Pass)":
progress(0, desc="Starting high-quality conversion...")
source_wav, sr = librosa.load(source_path, sr=model.sr, mono=True)
output_wav_np = process_chunk(source_wav)
progress(1, desc="Conversion complete!")
filename = f"audio_{int(time.time())}.wav"
soundfile.write(filename, output_wav_np, model.sr)
yield (model.sr, output_wav_np), gr.File(value=filename, label="Saved Audio File")
else: # "⚡ Fast Preview (Streaming)"
wav, sr = librosa.load(source_path, sr=None, mono=True)
if sr != model.sr:
wav = librosa.resample(wav, orig_sr=sr, target_sr=model.sr)
source_chunks = smart_split(wav, sr=model.sr)
full_np = np.zeros((0,))
for i, chunk_wav in enumerate(source_chunks):
print(f"Streaming chunk {i+1}/{len(source_chunks)}...")
output_chunk_np = process_chunk(chunk_wav)
full_np = np.concatenate([full_np, output_chunk_np], axis=0)
if i+1 == len(source_chunks):
filename = f"audio_{int(time.time())}.wav"
soundfile.write(filename, full_np, model.sr)
yield (model.sr, output_chunk_np), gr.File(value=filename, label="Saved Audio File")
else:
yield (model.sr, output_chunk_np), None
def synth_style_blender(state, *all_inputs, progress=gr.Progress()):
audio_tasks, npy_tasks = [], []
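    # `all_inputs` is flat: 8 voice slots of (audio, weight, chunk_size) followed by
    # 8 vector slots of (npy_file, weight), matching the registration order of the UI components below.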
for i in range(0, 8 * 3, 3):
audio, weight, chunk_size = all_inputs[i:i+3]
if audio is not None:
audio_tasks.append({'audio_path': audio, 'weight': weight, 'chunk_size_s': chunk_size, 'label': f"Voice {chr(65 + i//3)}"})
npy_start_index = 8 * 3
for i in range(0, 8 * 2, 2):
npy_file, weight = all_inputs[npy_start_index + i : npy_start_index + i + 2]
if npy_file is not None:
npy_tasks.append({'npy_path': npy_file.name, 'weight': weight, 'label': f"Vector {i//2 + 1}"})
if not audio_tasks and not npy_tasks:
gr.Warning("Please upload at least one voice/vector.")
active_vector = state.get("current_x_vector")
return state, get_vector_stats(active_vector)
# --- Phase 1 & 2: Combined Processing with Progress ---
all_vectors, all_weights = [], []
for task in audio_tasks:
wav, sr = librosa.load(task['audio_path'], sr=None, mono=True)
if sr != model.sr: wav = librosa.resample(wav, orig_sr=sr, target_sr=model.sr)
chunk_size_s = task['chunk_size_s']
partial_vectors = []
if chunk_size_s > 0:
chunk_samples = int(chunk_size_s * model.sr)
if len(wav) < chunk_samples:
progress(0, desc=f"Processing {task['label']} (clip shorter than chunk size)")
partial_vectors.append(get_x_vector_from_wav_chunk(wav))
else:
num_chunks = (len(wav) - chunk_samples) // chunk_samples + 1
if len(wav) % chunk_samples != 0: num_chunks += 1
for i, start_idx in enumerate(range(0, len(wav) - chunk_samples + 1, chunk_samples)):
progress(i / num_chunks, desc=f"Processing {task['label']}, Chunk {i+1}/{num_chunks}")
chunk = wav[start_idx:start_idx+chunk_samples]
partial_vectors.append(get_x_vector_from_wav_chunk(chunk))
if len(wav) % chunk_samples != 0:
progress((num_chunks-1) / num_chunks, desc=f"Processing {task['label']}, Chunk {num_chunks}/{num_chunks} (final)")
last_chunk = wav[-chunk_samples:]
partial_vectors.append(get_x_vector_from_wav_chunk(last_chunk))
else:
progress(0, desc=f"Processing {task['label']} (full clip)")
partial_vectors.append(get_x_vector_from_wav_chunk(wav))
if partial_vectors:
avg_vector = np.mean(partial_vectors, axis=0)
all_vectors.append(avg_vector)
all_weights.append(task['weight'])
for task in npy_tasks:
progress(0, desc=f"Loading {task['label']}...")
try:
loaded_vector = np.load(task['npy_path'])
if loaded_vector.shape == (192,):
all_vectors.append(loaded_vector); all_weights.append(task['weight'])
else: gr.Warning(f"Skipping {task['label']}: invalid shape {loaded_vector.shape}")
except Exception as e: gr.Warning(f"Skipping {task['label']}: could not load file. Error: {e}")
# --- Phase 3: Final Blending ---
if not all_vectors:
gr.Warning("Failed to process any voices/vectors.");
active_vector = state.get("current_x_vector")
return state, get_vector_stats(active_vector)
progress(0.99, desc="Blending final vectors...")
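    # Plain weighted sum; weights are deliberately not normalized (see the note in the blender UI).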
all_vectors, all_weights = np.array(all_vectors), np.array(all_weights).reshape(-1, 1)
blended_vec = np.sum(all_vectors * all_weights, axis=0)
state["current_x_vector"] = blended_vec
progress(1.0, desc="Blend complete!")
gr.Info("Blended successfully!")
return state, get_vector_stats(blended_vec)
def randomize_vector(state, strength):
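    # Draw a random 192-dim direction and scale it so its L2 norm equals `strength`.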
random_vector = np.random.randn(192).astype(np.float32)
final_vector = (random_vector / np.linalg.norm(random_vector)) * strength
state["current_x_vector"] = final_vector
gr.Info("Random vector generated! Check the stats for information on it.")
return state, get_vector_stats(final_vector)
def reset_vector(state):
state["current_x_vector"] = None
gr.Info("Cleared.")
return state, get_vector_stats(state.get("current_x_vector"))
def save_vector(state):
active_vector = state.get("current_x_vector")
if active_vector is None:
gr.Warning("No active x-vector to save."); return None
filename = f"vec_{int(time.time())}.npy"
np.save(filename, active_vector)
gr.Info(f"Saved as {filename}.")
return gr.File(value=filename, label="Saved `.npy` File")
def load_vector(state, vector_file):
if vector_file is None:
gr.Warning("Please upload a vector file.")
return state, get_vector_stats(state.get("current_x_vector"))
try:
loaded_vector = np.load(vector_file.name)
except Exception as e:
gr.Warning(f"Failed to load vector file: {e}"); return state, get_vector_stats(state.get("current_x_vector"))
if loaded_vector.shape != (192,):
gr.Warning(f"Invalid vector file. Shape must be (192,), got {loaded_vector.shape}"); return state, get_vector_stats(state.get("current_x_vector"))
active_vector = loaded_vector
state["current_x_vector"] = loaded_vector
gr.Info("Vector loaded!")
return state, get_vector_stats(active_vector)
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
state = gr.State({})
gr.Markdown("# BeltOut Timbre Workshop!!")
with gr.Tabs():
with gr.TabItem("Vectors"):
gr.Markdown("""
            The model represents timbre as a list of 192 numbers, called the **x-vector** in the originating literature. Given this vector along with your voice recording, the model produces a new audio file with that timbre applied.
You can:
- Load these numbers directly from pre-saved **npy (numpy)** files,
- Or calculate an average timbre vector from example audio files,
- Or mix multiple vectors into a new one,
- Or generate one randomly.
""")
vector_stats_display = gr.Textbox(label="Active Vector Stats", interactive=False)
with gr.Tabs():
with gr.TabItem("🧬 Synth-Style Blender"):
gr.Markdown("""
Blend up to 8 voices and 8 pre-saved vector files. Only slots with uploaded files will be used.
                    Weights are not normalized: a weight of 2 mixes that vector in at twice its original magnitude.
                    When a sample audio file is uploaded, a new field called "Chunk Size (s)" appears. While the vector-statistics model can theoretically handle audio of any length, the amount it can look at at once is limited by available VRAM. In that case, set the chunk size to something shorter than the full clip; the audio will be split into chunks of that size, each chunk processed one by one, and the resulting vectors averaged at the end.
""")
all_blend_inputs = []
with gr.Row():
for i in range(2):
with gr.Column():
audio = gr.Audio(type="filepath", label=f"Voice {chr(65+i)}")
weight = gr.Slider(-5, 5, value=1.0, label=f"Weight {chr(65+i)}")
chunk_size = gr.Slider(0, 120, value=0, label="Chunk Size (s)", step=1, visible=False)
audio.upload(fn=update_chunk_slider, inputs=[audio], outputs=[chunk_size])
audio.clear(fn=update_chunk_slider, inputs=[audio], outputs=[chunk_size])
all_blend_inputs.extend([audio, weight, chunk_size])
with gr.Accordion("➕ More Voices", open=False):
gr.Markdown("### Audio Voices (C-H)")
for i in range(2, 8, 2):
with gr.Row():
for j in range(2):
with gr.Column():
audio = gr.Audio(type="filepath", label=f"Voice {chr(65+i+j)}")
weight = gr.Slider(-5, 5, value=0.0, label=f"Weight {chr(65+i+j)}")
                                        chunk_size = gr.Slider(0, 120, value=0, label="Chunk Size (s)", step=1, visible=False)
audio.upload(fn=update_chunk_slider, inputs=[audio], outputs=[chunk_size])
audio.clear(fn=update_chunk_slider, inputs=[audio], outputs=[chunk_size])
all_blend_inputs.extend([audio, weight, chunk_size])
with gr.Accordion("➕ Vectors", open=False):
gr.Markdown("### Saved Vector Files (.npy)")
for i in range(0, 8, 2):
with gr.Row():
for j in range(2):
with gr.Column():
npy = gr.File(label=f"Vector {i+j+1}", file_types=[".npy"])
weight = gr.Slider(-5, 5, value=0.0, label=f"Weight {i+j+1}")
all_blend_inputs.extend([npy, weight])
blend_button = gr.Button("Blend Voices")
with gr.TabItem("🔀 Voice Randomizer"):
gr.Markdown("Generate a new, random timbre vector. 'Strength' controls the magnitude (norm) of the random x-vector.")
random_strength_input = gr.Number(value=12.0, label="Randomization Strength (Vector Norm)", minimum=0.01)
randomize_button = gr.Button("✨ Generate Random Voice")
with gr.TabItem("💾 Save / Load Voice"):
gr.Markdown("Save the current active x-vector as a `.npy` file or load a previously saved one.")
with gr.Row():
save_button = gr.Button("Save Active Vector")
load_vector_file = gr.File(label="Load Vector File (.npy)")
saved_file_output = gr.File(label="Saved Vector File", interactive=False)
with gr.TabItem("Main Conversion"):
source_audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Upload or Record Source Audio")
with gr.Accordion("⚙️ Generation Settings", open=False):
mode_switch = gr.Radio(["⭐ High Quality (Single Pass)", "⚡ Fast Preview (Streaming)"], value="⭐ High Quality (Single Pass)", label="Conversion Mode")
start_button = gr.Button("Run", variant="primary")
gr.Markdown("### Output")
output_audio = gr.Audio(label="Conversion Result", streaming=True, autoplay=False) # streaming=True works with generators
saved_audio_file_output = gr.File(label="Saved Audio File", interactive=False)
# --- Event Handlers ---
source_audio_input.upload(fn=set_source_audio, inputs=[state, source_audio_input], outputs=[state, output_audio])
source_audio_input.stop_recording(fn=set_source_audio, inputs=[state, source_audio_input], outputs=[state, output_audio])
start_button.click(fn=run_conversion, inputs=[state, mode_switch], outputs=[output_audio, saved_audio_file_output])
randomize_button.click(fn=randomize_vector, inputs=[state, random_strength_input], outputs=[state, vector_stats_display])
blend_button.click(fn=synth_style_blender, inputs=[state, *all_blend_inputs], outputs=[state, vector_stats_display])
save_button.click(fn=save_vector, inputs=[state], outputs=[saved_file_output])
load_vector_file.upload(fn=load_vector, inputs=[state, load_vector_file], outputs=[state, vector_stats_display])
demo.queue().launch(debug=True, share=False)