[edit] whispermodel to embed audio method

This commit is contained in:
rino
2024-09-21 21:58:09 +07:00
parent 3b963692db
commit bce7d68e9c
2 changed files with 30 additions and 27 deletions

View File

@@ -360,6 +360,10 @@ def load_model(ckpt_dir, device):
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
def embed_audio(mel):
"""
This function now mimics the behavior of the Whisper encoder's `embed_audio` function, where it processes
the mel spectrogram and passes it through the encoder to return the encoded audio features.
"""
# Convert mel spectrogram to a format suitable for Whisper
mel_cpu = mel.cpu().numpy()
@@ -369,37 +373,18 @@ def load_model(ckpt_dir, device):
if mel_cpu.ndim != 1:
raise ValueError(f"Audio is not mono! Shape: {mel_cpu.shape}")
# Process mel with Whisper processor
# Process mel with Whisper processor to get input features
inputs = whisper_processor(mel_cpu, sampling_rate=16000, return_tensors="pt").input_features.to(device)
# Debugging: Log input feature shapes
print(f"Input features shape: {inputs.shape}")
print(f"First 5 values of input features: {inputs[0, :, :5]}")
# Pass input features through Whisper's encoder
encoder_outputs = whispermodel.model.encoder(inputs)
# Set beam search and max length for transcription
generated_ids = whispermodel.generate(
inputs,
max_length=256, # Increased length to allow for longer sentences
num_beams=5, # Higher quality output through beam search
early_stopping=True
)
# Decode the generated tokens
transcription = whisper_processor.decode(generated_ids[0], skip_special_tokens=True)
# Log the transcription and token IDs
print(f"Generated token IDs: {generated_ids[0].tolist()}")
print(f"Transcription (for log purposes): {transcription}")
# Return logits as per original flow
start_token = whisper_processor.tokenizer.pad_token_id
decoder_input_ids = torch.tensor([[start_token]], device=device)
outputs = whispermodel(inputs, decoder_input_ids=decoder_input_ids)
return outputs.logits
# Return the encoded audio features (logits or embeddings from the encoder)
return encoder_outputs
# Assign the custom `embed_audio` method to `whispermodel`
whispermodel.embed_audio = embed_audio
else:
# If not MPS, use the regular Whisper model loading
whispermodel = whisper.load_model("small").to(device)
@@ -418,6 +403,7 @@ def load_model(ckpt_dir, device):
model.load_state_dict(state_dict, strict=True)
model.to(device).eval()
# Return everything as before, ensuring whispermodel is properly handled
return fabric, model, text_tokenizer, snacmodel, whispermodel

View File

@@ -104,7 +104,24 @@ class GPT(nn.Module):
for j in range(len(T)):
if task[j] != "T1T2" and task[j] != "T1A2":
for i in range(7):
input_ids[i][j, 1 : T[j] + 1, :] = audio_feature[j][: T[j]].clone()
# Ensure audio_feature has the correct dimensions
audio_feat = audio_feature[j][: T[j]].clone()
# Debug: print shapes for comparison
print(f"audio_feature shape: {audio_feat.shape}")
print(f"input_ids shape: {input_ids[i][j, 1: T[j] + 1, :].shape}")
# If the audio_feature is 1D, expand it to match the embedding dimensions
if audio_feat.ndim == 1:
audio_feat = audio_feat.unsqueeze(-1) # Add an extra dimension
# If audio_feature has fewer dimensions than required, repeat the values across the embedding dimension
if audio_feat.size(-1) != input_ids[i][j, 1: T[j] + 1, :].size(-1):
audio_feat = audio_feat.expand(-1, input_ids[i][j, 1: T[j] + 1, :].size(-1))
# Now, assign the expanded or reshaped audio_feature to input_ids
input_ids[i][j, 1: T[j] + 1, :] = audio_feat
else:
continue
return input_ids