diff --git a/inference.py b/inference.py
index 94ce345..6bfa3e1 100644
--- a/inference.py
+++ b/inference.py
@@ -360,6 +360,10 @@ def load_model(ckpt_dir, device):
         whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
 
         def embed_audio(mel):
+            """
+            This function now mimics the behavior of the Whisper encoder's `embed_audio` function, where it processes
+            the mel spectrogram and passes it through the encoder to return the encoded audio features.
+            """
             # Convert mel spectrogram to a format suitable for Whisper
             mel_cpu = mel.cpu().numpy()
 
@@ -369,37 +373,18 @@ def load_model(ckpt_dir, device):
             if mel_cpu.ndim != 1:
                 raise ValueError(f"Audio is not mono! Shape: {mel_cpu.shape}")
 
-            # Process mel with Whisper processor
+            # Process mel with Whisper processor to get input features
            inputs = whisper_processor(mel_cpu, sampling_rate=16000, return_tensors="pt").input_features.to(device)
 
-            # Debugging: Log input feature shapes
-            print(f"Input features shape: {inputs.shape}")
-            print(f"First 5 values of input features: {inputs[0, :, :5]}")
+            # Pass input features through Whisper's encoder
+            encoder_outputs = whispermodel.model.encoder(inputs)
 
-            # Set beam search and max length for transcription
-            generated_ids = whispermodel.generate(
-                inputs,
-                max_length=256,  # Increased length to allow for longer sentences
-                num_beams=5,  # Higher quality output through beam search
-                early_stopping=True
-            )
-
-            # Decode the generated tokens
-            transcription = whisper_processor.decode(generated_ids[0], skip_special_tokens=True)
-
-            # Log the transcription and token IDs
-            print(f"Generated token IDs: {generated_ids[0].tolist()}")
-            print(f"Transcription (for log purposes): {transcription}")
-
-            # Return logits as per original flow
-            start_token = whisper_processor.tokenizer.pad_token_id
-            decoder_input_ids = torch.tensor([[start_token]], device=device)
-
-            outputs = whispermodel(inputs, decoder_input_ids=decoder_input_ids)
-
-            return outputs.logits
+            # Return the encoded audio features (logits or embeddings from the encoder)
+            return encoder_outputs
 
         # Assign the custom `embed_audio` method to `whispermodel`
         whispermodel.embed_audio = embed_audio
+
     else:
         # If not MPS, use the regular Whisper model loading
         whispermodel = whisper.load_model("small").to(device)
@@ -418,6 +403,7 @@ def load_model(ckpt_dir, device):
     model.load_state_dict(state_dict, strict=True)
     model.to(device).eval()
 
+    # Return everything as before, ensuring whispermodel is properly handled
     return fabric, model, text_tokenizer, snacmodel, whispermodel
 
diff --git a/litgpt/model.py b/litgpt/model.py
index 85837d3..7870bd4 100644
--- a/litgpt/model.py
+++ b/litgpt/model.py
@@ -104,7 +104,24 @@ class GPT(nn.Module):
         for j in range(len(T)):
             if task[j] != "T1T2" and task[j] != "T1A2":
                 for i in range(7):
-                    input_ids[i][j, 1 : T[j] + 1, :] = audio_feature[j][: T[j]].clone()
+                    # Ensure audio_feature has the correct dimensions
+                    audio_feat = audio_feature[j][: T[j]].clone()
+
+                    # Debug: print shapes for comparison
+                    print(f"audio_feature shape: {audio_feat.shape}")
+                    print(f"input_ids shape: {input_ids[i][j, 1: T[j] + 1, :].shape}")
+
+                    # If the audio_feature is 1D, expand it to match the embedding dimensions
+                    if audio_feat.ndim == 1:
+                        audio_feat = audio_feat.unsqueeze(-1)  # Add an extra dimension
+
+                    # If audio_feature has fewer dimensions than required, repeat the values across the embedding dimension
+                    if audio_feat.size(-1) != input_ids[i][j, 1: T[j] + 1, :].size(-1):
+                        audio_feat = audio_feat.expand(-1, input_ids[i][j, 1: T[j] + 1, :].size(-1))
+
+                    # Now, assign the expanded or reshaped audio_feature to input_ids
+                    input_ids[i][j, 1: T[j] + 1, :] = audio_feat
+
             else:
                 continue
         return input_ids