mirror of
https://github.com/gpt-omni/mini-omni
synced 2025-06-26 18:16:26 +00:00
[edit] whispermodel to embed audio method
This commit is contained in:
38
inference.py
38
inference.py
@@ -360,6 +360,10 @@ def load_model(ckpt_dir, device):
|
||||
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
|
||||
|
||||
def embed_audio(mel):
|
||||
"""
|
||||
This function now mimics the behavior of the Whisper encoder's `embed_audio` function, where it processes
|
||||
the mel spectrogram and passes it through the encoder to return the encoded audio features.
|
||||
"""
|
||||
# Convert mel spectrogram to a format suitable for Whisper
|
||||
mel_cpu = mel.cpu().numpy()
|
||||
|
||||
@@ -369,37 +373,18 @@ def load_model(ckpt_dir, device):
|
||||
if mel_cpu.ndim != 1:
|
||||
raise ValueError(f"Audio is not mono! Shape: {mel_cpu.shape}")
|
||||
|
||||
# Process mel with Whisper processor
|
||||
# Process mel with Whisper processor to get input features
|
||||
inputs = whisper_processor(mel_cpu, sampling_rate=16000, return_tensors="pt").input_features.to(device)
|
||||
|
||||
# Debugging: Log input feature shapes
|
||||
print(f"Input features shape: {inputs.shape}")
|
||||
print(f"First 5 values of input features: {inputs[0, :, :5]}")
|
||||
# Pass input features through Whisper's encoder
|
||||
encoder_outputs = whispermodel.model.encoder(inputs)
|
||||
|
||||
# Set beam search and max length for transcription
|
||||
generated_ids = whispermodel.generate(
|
||||
inputs,
|
||||
max_length=256, # Increased length to allow for longer sentences
|
||||
num_beams=5, # Higher quality output through beam search
|
||||
early_stopping=True
|
||||
)
|
||||
|
||||
# Decode the generated tokens
|
||||
transcription = whisper_processor.decode(generated_ids[0], skip_special_tokens=True)
|
||||
|
||||
# Log the transcription and token IDs
|
||||
print(f"Generated token IDs: {generated_ids[0].tolist()}")
|
||||
print(f"Transcription (for log purposes): {transcription}")
|
||||
|
||||
# Return logits as per original flow
|
||||
start_token = whisper_processor.tokenizer.pad_token_id
|
||||
decoder_input_ids = torch.tensor([[start_token]], device=device)
|
||||
|
||||
outputs = whispermodel(inputs, decoder_input_ids=decoder_input_ids)
|
||||
|
||||
return outputs.logits
|
||||
# Return the encoded audio features (logits or embeddings from the encoder)
|
||||
return encoder_outputs
|
||||
|
||||
# Assign the custom `embed_audio` method to `whispermodel`
|
||||
whispermodel.embed_audio = embed_audio
|
||||
|
||||
else:
|
||||
# If not MPS, use the regular Whisper model loading
|
||||
whispermodel = whisper.load_model("small").to(device)
|
||||
@@ -418,6 +403,7 @@ def load_model(ckpt_dir, device):
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
model.to(device).eval()
|
||||
|
||||
# Return everything as before, ensuring whispermodel is properly handled
|
||||
return fabric, model, text_tokenizer, snacmodel, whispermodel
|
||||
|
||||
|
||||
|
||||
@@ -104,7 +104,24 @@ class GPT(nn.Module):
|
||||
for j in range(len(T)):
|
||||
if task[j] != "T1T2" and task[j] != "T1A2":
|
||||
for i in range(7):
|
||||
input_ids[i][j, 1 : T[j] + 1, :] = audio_feature[j][: T[j]].clone()
|
||||
# Ensure audio_feature has the correct dimensions
|
||||
audio_feat = audio_feature[j][: T[j]].clone()
|
||||
|
||||
# Debug: print shapes for comparison
|
||||
print(f"audio_feature shape: {audio_feat.shape}")
|
||||
print(f"input_ids shape: {input_ids[i][j, 1: T[j] + 1, :].shape}")
|
||||
|
||||
# If the audio_feature is 1D, expand it to match the embedding dimensions
|
||||
if audio_feat.ndim == 1:
|
||||
audio_feat = audio_feat.unsqueeze(-1) # Add an extra dimension
|
||||
|
||||
# If audio_feature has fewer dimensions than required, repeat the values across the embedding dimension
|
||||
if audio_feat.size(-1) != input_ids[i][j, 1: T[j] + 1, :].size(-1):
|
||||
audio_feat = audio_feat.expand(-1, input_ids[i][j, 1: T[j] + 1, :].size(-1))
|
||||
|
||||
# Now, assign the expanded or reshaped audio_feature to input_ids
|
||||
input_ids[i][j, 1: T[j] + 1, :] = audio_feat
|
||||
|
||||
else:
|
||||
continue
|
||||
return input_ids
|
||||
|
||||
Reference in New Issue
Block a user