diff --git a/inference.py b/inference.py index 3ebaec3..793f1b2 100644 --- a/inference.py +++ b/inference.py @@ -351,56 +351,60 @@ def T1_T2(fabric, input_ids, model, text_tokenizer, step): def load_model(ckpt_dir, device): - # Load the SNAC model + # Load SNAC model snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device) - # Load Whisper model for transcription - whispermodel = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").to(device) - whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + # Load Whisper model for transcription based on the device + if device == 'mps': + whispermodel = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").to(device) + whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") - def embed_audio(mel): - # Convert the mel spectrogram to a format suitable for Whisper - mel_cpu = mel.cpu().numpy() + def embed_audio(mel): + # Convert mel spectrogram to a format suitable for Whisper + mel_cpu = mel.cpu().numpy() - if mel_cpu.ndim > 1: - mel_cpu = librosa.to_mono(mel_cpu) + if mel_cpu.ndim > 1: + mel_cpu = librosa.to_mono(mel_cpu) - if mel_cpu.ndim != 1: - raise ValueError(f"Audio is not mono! Shape: {mel_cpu.shape}") + if mel_cpu.ndim != 1: + raise ValueError(f"Audio is not mono! Shape: {mel_cpu.shape}") - # Process mel with Whisper processor - inputs = whisper_processor(mel_cpu, sampling_rate=16000, return_tensors="pt").input_features.to(device) + # Process mel with Whisper processor + inputs = whisper_processor(mel_cpu, sampling_rate=16000, return_tensors="pt").input_features.to(device) - # Debugging: Log input feature shapes - print(f"Input features shape: {inputs.shape}") - print(f"First 5 values of input features: {inputs[0, :, :5]}") + # Debugging: Log input feature shapes + print(f"Input features shape: {inputs.shape}") + print(f"First 5 values of input features: {inputs[0, :, :5]}") - # Set beam search and max length to force more tokens - generated_ids = whispermodel.generate( - inputs, - max_length=256, # Increased length to allow for longer sentences - num_beams=5, # Higher quality output through beam search - early_stopping=True - ) + # Set beam search and max length for transcription + generated_ids = whispermodel.generate( + inputs, + max_length=256, # Increased length to allow for longer sentences + num_beams=5, # Higher quality output through beam search + early_stopping=True + ) - # Decode the generated tokens - transcription = whisper_processor.decode(generated_ids[0], skip_special_tokens=True) + # Decode the generated tokens + transcription = whisper_processor.decode(generated_ids[0], skip_special_tokens=True) - # Log the transcription and token IDs - print(f"Generated token IDs: {generated_ids[0].tolist()}") - print(f"Transcription (for log purposes): {transcription}") + # Log the transcription and token IDs + print(f"Generated token IDs: {generated_ids[0].tolist()}") + print(f"Transcription (for log purposes): {transcription}") - # Now return logits as per original flow - start_token = whisper_processor.tokenizer.pad_token_id - decoder_input_ids = torch.tensor([[start_token]], device=device) + # Return logits as per original flow + start_token = whisper_processor.tokenizer.pad_token_id + decoder_input_ids = torch.tensor([[start_token]], device=device) - outputs = whispermodel(inputs, decoder_input_ids=decoder_input_ids) + outputs = whispermodel(inputs, decoder_input_ids=decoder_input_ids) - return outputs.logits + return outputs.logits - whispermodel.embed_audio = embed_audio + whispermodel.embed_audio = embed_audio + else: + # If not MPS, use the regular Whisper model loading + whispermodel = whisper.load_model("small").to(device) - # Load the GPT model as per your existing code + # Load the GPT model with Fabric text_tokenizer = Tokenizer(ckpt_dir) fabric = L.Fabric(devices=1, strategy="auto") config = Config.from_file(ckpt_dir + "/model_config.yaml") @@ -417,22 +421,43 @@ def load_model(ckpt_dir, device): return fabric, model, text_tokenizer, snacmodel, whispermodel - def download_model(ckpt_dir): repo_id = "gpt-omni/mini-omni" snapshot_download(repo_id, local_dir=ckpt_dir, revision="main") class OmniInference: + def __init__(self, ckpt_dir='./checkpoint', device=None): + # Dynamically determine device (similar to OmniChatServer's get_device) + self.device = self.get_device(device) - def __init__(self, ckpt_dir='./checkpoint', device='mps'): - self.device = device + # If checkpoint directory does not exist, download it if not os.path.exists(ckpt_dir): - print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface") + print(f"Checkpoint directory {ckpt_dir} not found, downloading from HuggingFace.") download_model(ckpt_dir) - self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device) + + # Load models (SNAC, Whisper, GPT) + self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, + self.device) + + def get_device(self, device): + if device is None: + if torch.cuda.is_available(): + return 'cuda' + elif torch.backends.mps.is_available(): + return 'mps' + else: + return 'cpu' + else: + if device == 'cuda' and torch.cuda.is_available(): + return 'cuda' + elif device == 'mps' and torch.backends.mps.is_available(): + return 'mps' + else: + return 'cpu' def warm_up(self, sample='./data/samples/output1.wav'): + # Run a warm-up pass by running the AT batch stream for _ in self.run_AT_batch_stream(sample): pass @@ -445,15 +470,17 @@ class OmniInference: top_k=1, top_p=1.0, eos_id_a=_eoa, - eos_id_t=_eot, - ): + eos_id_t=_eot): + + assert os.path.exists(audio_path), f"Audio file {audio_path} not found" - assert os.path.exists(audio_path), f"audio file {audio_path} not found" model = self.model + # Initialize kv cache with dynamic device handling with self.fabric.init_tensor(): - model.set_kv_cache(batch_size=2, device='mps') + model.set_kv_cache(batch_size=2, device=self.device) + # Load the audio and process it using Whisper mel, leng = load_audio(audio_path) audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device) T = input_ids[0].size(1) @@ -462,19 +489,19 @@ class OmniInference: assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}" if model.max_seq_length < max_returned_tokens - 1: - raise NotImplementedError( - f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}" - ) + raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}") input_pos = torch.tensor([T], device=device) list_output = [[] for i in range(8)] + + # Generate tokens tokens_A, token_T = next_token_batch( model, - audio_feature.to(torch.float32).to(model.device), + audio_feature.to(torch.float32).to(self.device), input_ids, [T - 3, T - 3], ["A1T2", "A1T2"], - input_pos=torch.arange(0, T, device=device), + input_pos=torch.arange(0, T, device=self.device), temperature=temperature, top_k=top_k, top_p=top_p, @@ -484,11 +511,12 @@ class OmniInference: list_output[i].append(tokens_A[i].tolist()[0]) list_output[7].append(token_T.tolist()[0]) + # Prepare model input IDs for the next iterations model_input_ids = [[] for i in range(8)] for i in range(7): tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize - model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32)) - model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=device)) + model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32)) + model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=self.device)) model_input_ids[i] = torch.stack(model_input_ids[i]) model_input_ids[-1].append(token_T.clone().to(torch.int32)) @@ -514,7 +542,7 @@ class OmniInference: ) if text_end: - token_T = torch.tensor([_pad_t], device=device) + token_T = torch.tensor([_pad_t], device=self.device) if tokens_A[-1] == eos_id_a: break @@ -528,11 +556,9 @@ class OmniInference: model_input_ids = [[] for i in range(8)] for i in range(7): - tokens_A[i] = tokens_A[i].clone() +padded_text_vocabsize + i * padded_audio_vocabsize - model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32)) - model_input_ids[i].append( - torch.tensor([layershift(4097, i)], device=device) - ) + tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize + model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32)) + model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=self.device)) model_input_ids[i] = torch.stack(model_input_ids[i]) model_input_ids[-1].append(token_T.clone().to(torch.int32)) @@ -552,6 +578,7 @@ class OmniInference: input_pos = input_pos.add_(1) index += 1 + text = self.text_tokenizer.decode(torch.tensor(list_output[-1])) print(f"text output: {text}") model.clear_kv_cache() @@ -716,4 +743,4 @@ def test_infer(): if __name__ == "__main__": - test_infer() + test_infer() \ No newline at end of file diff --git a/litgpt/model.py b/litgpt/model.py index 60df29c..dcfe2eb 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -453,7 +453,7 @@ class LLaMAMLP(nn.Module): class whisperMLP(nn.Module): - def __init__(self, config: Config) -> None: + def __init__(self, config) -> None: super().__init__() self.fc_1 = nn.Linear(config.whisper_adapter_dim, config.intermediate_size, bias=config.bias) self.fc_2 = nn.Linear(config.whisper_adapter_dim, config.intermediate_size, bias=config.bias) @@ -461,25 +461,33 @@ class whisperMLP(nn.Module): self.config = config + # Pooling layer for dimensionality reduction (if necessary) self.pooling = nn.AdaptiveAvgPool1d(768) def forward(self, x: torch.Tensor) -> torch.Tensor: device = x.device + + # Handle MPS-specific operations if device.type == 'mps': + # Move the tensor to CPU if necessary (for unsupported ops) x = x.to('cpu') - x = x.view(x.size(0), -1) - x = x.unsqueeze(1) - x = self.pooling(x).squeeze(1) # Remove the channel dimension after pooling + # Reshape the tensor for pooling + x = x.view(x.size(0), -1) # Flatten dimensions except the batch size + x = x.unsqueeze(1) # Add a channel dimension for pooling + x = self.pooling(x).squeeze(1) # Apply pooling and remove the channel dimension after pooling - if torch.backends.mps.is_available(): + # Make sure the tensor is back on MPS if available + if device.type == 'mps': x = x.to('mps') + # Perform the standard operations x_fc_1 = self.fc_1(x) x_fc_2 = self.fc_2(x) - x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - return self.proj(x) + x = F.silu(x_fc_1) * x_fc_2 + # Final projection + return self.proj(x) class GemmaMLP(LLaMAMLP): def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/server.py b/server.py index a0197a5..7bff98b 100644 --- a/server.py +++ b/server.py @@ -2,18 +2,20 @@ import flask import base64 import tempfile import traceback + +import torch from flask import Flask, Response, stream_with_context from inference import OmniInference class OmniChatServer(object): def __init__(self, ip='0.0.0.0', port=60808, run_app=True, - ckpt_dir='./checkpoint', device='mps') -> None: + ckpt_dir='./checkpoint', device=None) -> None: server = Flask(__name__) # CORS(server, resources=r"/*") # server.config["JSON_AS_ASCII"] = False - - self.client = OmniInference(ckpt_dir, device) + self.device = self.get_device(device) + self.client = OmniInference(ckpt_dir, self.device) self.client.warm_up() server.route("/chat", methods=["POST"])(self.chat) @@ -23,6 +25,22 @@ class OmniChatServer(object): else: self.server = server + def get_device(self, device): + if device is None: + if torch.cuda.is_available(): + return 'cuda' + elif torch.backends.mps.is_available(): + return 'mps' + else: + return 'cpu' + else: + if device == 'cuda' and torch.cuda.is_available(): + return 'cuda' + elif device == 'mps' and torch.backends.mps.is_available(): + return 'mps' + else: + return 'cpu' + def chat(self) -> Response: req_data = flask.request.get_json() @@ -50,12 +68,11 @@ def create_app(): return server.server -def serve(ip='0.0.0.0', port=60808): - OmniChatServer(ip, port=port, run_app=True) +def serve(ip='0.0.0.0', port=60808, device=None): + OmniChatServer(ip, port=port, run_app=True, device=device) if __name__ == "__main__": import fire - fire.Fire(serve) - + fire.Fire(serve) \ No newline at end of file