mirror of
https://github.com/gpt-omni/mini-omni
synced 2025-06-26 18:16:26 +00:00
[edit] Detect the available device automatically (GPU via MPS/CUDA, otherwise CPU) so the code works on both Windows and Mac
This commit is contained in:
143
inference.py
143
inference.py
@@ -351,56 +351,60 @@ def T1_T2(fabric, input_ids, model, text_tokenizer, step):
|
||||
|
||||
|
||||
def load_model(ckpt_dir, device):
|
||||
# Load the SNAC model
|
||||
# Load SNAC model
|
||||
snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
|
||||
|
||||
# Load Whisper model for transcription
|
||||
whispermodel = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").to(device)
|
||||
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
|
||||
# Load Whisper model for transcription based on the device
|
||||
if device == 'mps':
|
||||
whispermodel = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").to(device)
|
||||
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
|
||||
|
||||
def embed_audio(mel):
|
||||
# Convert the mel spectrogram to a format suitable for Whisper
|
||||
mel_cpu = mel.cpu().numpy()
|
||||
def embed_audio(mel):
|
||||
# Convert mel spectrogram to a format suitable for Whisper
|
||||
mel_cpu = mel.cpu().numpy()
|
||||
|
||||
if mel_cpu.ndim > 1:
|
||||
mel_cpu = librosa.to_mono(mel_cpu)
|
||||
if mel_cpu.ndim > 1:
|
||||
mel_cpu = librosa.to_mono(mel_cpu)
|
||||
|
||||
if mel_cpu.ndim != 1:
|
||||
raise ValueError(f"Audio is not mono! Shape: {mel_cpu.shape}")
|
||||
if mel_cpu.ndim != 1:
|
||||
raise ValueError(f"Audio is not mono! Shape: {mel_cpu.shape}")
|
||||
|
||||
# Process mel with Whisper processor
|
||||
inputs = whisper_processor(mel_cpu, sampling_rate=16000, return_tensors="pt").input_features.to(device)
|
||||
# Process mel with Whisper processor
|
||||
inputs = whisper_processor(mel_cpu, sampling_rate=16000, return_tensors="pt").input_features.to(device)
|
||||
|
||||
# Debugging: Log input feature shapes
|
||||
print(f"Input features shape: {inputs.shape}")
|
||||
print(f"First 5 values of input features: {inputs[0, :, :5]}")
|
||||
# Debugging: Log input feature shapes
|
||||
print(f"Input features shape: {inputs.shape}")
|
||||
print(f"First 5 values of input features: {inputs[0, :, :5]}")
|
||||
|
||||
# Set beam search and max length to force more tokens
|
||||
generated_ids = whispermodel.generate(
|
||||
inputs,
|
||||
max_length=256, # Increased length to allow for longer sentences
|
||||
num_beams=5, # Higher quality output through beam search
|
||||
early_stopping=True
|
||||
)
|
||||
# Set beam search and max length for transcription
|
||||
generated_ids = whispermodel.generate(
|
||||
inputs,
|
||||
max_length=256, # Increased length to allow for longer sentences
|
||||
num_beams=5, # Higher quality output through beam search
|
||||
early_stopping=True
|
||||
)
|
||||
|
||||
# Decode the generated tokens
|
||||
transcription = whisper_processor.decode(generated_ids[0], skip_special_tokens=True)
|
||||
# Decode the generated tokens
|
||||
transcription = whisper_processor.decode(generated_ids[0], skip_special_tokens=True)
|
||||
|
||||
# Log the transcription and token IDs
|
||||
print(f"Generated token IDs: {generated_ids[0].tolist()}")
|
||||
print(f"Transcription (for log purposes): {transcription}")
|
||||
# Log the transcription and token IDs
|
||||
print(f"Generated token IDs: {generated_ids[0].tolist()}")
|
||||
print(f"Transcription (for log purposes): {transcription}")
|
||||
|
||||
# Now return logits as per original flow
|
||||
start_token = whisper_processor.tokenizer.pad_token_id
|
||||
decoder_input_ids = torch.tensor([[start_token]], device=device)
|
||||
# Return logits as per original flow
|
||||
start_token = whisper_processor.tokenizer.pad_token_id
|
||||
decoder_input_ids = torch.tensor([[start_token]], device=device)
|
||||
|
||||
outputs = whispermodel(inputs, decoder_input_ids=decoder_input_ids)
|
||||
outputs = whispermodel(inputs, decoder_input_ids=decoder_input_ids)
|
||||
|
||||
return outputs.logits
|
||||
return outputs.logits
|
||||
|
||||
whispermodel.embed_audio = embed_audio
|
||||
whispermodel.embed_audio = embed_audio
|
||||
else:
|
||||
# If not MPS, use the regular Whisper model loading
|
||||
whispermodel = whisper.load_model("small").to(device)
|
||||
|
||||
# Load the GPT model as per your existing code
|
||||
# Load the GPT model with Fabric
|
||||
text_tokenizer = Tokenizer(ckpt_dir)
|
||||
fabric = L.Fabric(devices=1, strategy="auto")
|
||||
config = Config.from_file(ckpt_dir + "/model_config.yaml")
|
||||
@@ -417,22 +421,43 @@ def load_model(ckpt_dir, device):
|
||||
return fabric, model, text_tokenizer, snacmodel, whispermodel
|
||||
|
||||
|
||||
|
||||
def download_model(ckpt_dir):
    """Fetch the ``main`` revision of the ``gpt-omni/mini-omni`` repo into ``ckpt_dir``."""
    snapshot_download("gpt-omni/mini-omni", local_dir=ckpt_dir, revision="main")
|
||||
|
||||
|
||||
class OmniInference:
|
||||
def __init__(self, ckpt_dir='./checkpoint', device=None):
|
||||
# Dynamically determine device (similar to OmniChatServer's get_device)
|
||||
self.device = self.get_device(device)
|
||||
|
||||
def __init__(self, ckpt_dir='./checkpoint', device='mps'):
|
||||
self.device = device
|
||||
# If checkpoint directory does not exist, download it
|
||||
if not os.path.exists(ckpt_dir):
|
||||
print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
|
||||
print(f"Checkpoint directory {ckpt_dir} not found, downloading from HuggingFace.")
|
||||
download_model(ckpt_dir)
|
||||
self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)
|
||||
|
||||
# Load models (SNAC, Whisper, GPT)
|
||||
self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir,
|
||||
self.device)
|
||||
|
||||
def get_device(self, device):
    """Resolve the torch device string to run on.

    Args:
        device: ``None`` to auto-detect, or a requested device such as
            ``'cuda'``, ``'cuda:0'``, ``'mps'`` or ``'cpu'`` (a
            ``torch.device`` is accepted as well; it is converted with
            ``str``).

    Returns:
        The requested device when its backend is available, otherwise
        ``'cpu'``. With ``device=None`` the preference order is CUDA, then
        MPS, then CPU.
    """
    has_cuda = torch.cuda.is_available()
    # torch.backends.mps does not exist on older PyTorch builds; treat a
    # missing module as "MPS not available" instead of raising.
    mps_backend = getattr(torch.backends, 'mps', None)
    has_mps = mps_backend is not None and mps_backend.is_available()

    if device is None:
        if has_cuda:
            return 'cuda'
        return 'mps' if has_mps else 'cpu'

    # Match on the device type only, so an indexed request such as 'cuda:0'
    # is honoured rather than silently downgraded to CPU.
    requested = str(device)
    kind = requested.split(':', 1)[0]
    if kind == 'cuda' and has_cuda:
        return requested
    if kind == 'mps' and has_mps:
        # Always the bare name: load_model compares the device against 'mps'.
        return 'mps'
    return 'cpu'
|
||||
|
||||
def warm_up(self, sample='./data/samples/output1.wav'):
    """Exhaust one ``run_AT_batch_stream`` pass over ``sample``, discarding its output."""
    stream = self.run_AT_batch_stream(sample)
    for _chunk in stream:
        continue
|
||||
|
||||
@@ -445,15 +470,17 @@ class OmniInference:
|
||||
top_k=1,
|
||||
top_p=1.0,
|
||||
eos_id_a=_eoa,
|
||||
eos_id_t=_eot,
|
||||
):
|
||||
eos_id_t=_eot):
|
||||
|
||||
assert os.path.exists(audio_path), f"Audio file {audio_path} not found"
|
||||
|
||||
assert os.path.exists(audio_path), f"audio file {audio_path} not found"
|
||||
model = self.model
|
||||
|
||||
# Initialize kv cache with dynamic device handling
|
||||
with self.fabric.init_tensor():
|
||||
model.set_kv_cache(batch_size=2, device='mps')
|
||||
model.set_kv_cache(batch_size=2, device=self.device)
|
||||
|
||||
# Load the audio and process it using Whisper
|
||||
mel, leng = load_audio(audio_path)
|
||||
audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
|
||||
T = input_ids[0].size(1)
|
||||
@@ -462,19 +489,19 @@ class OmniInference:
|
||||
assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"
|
||||
|
||||
if model.max_seq_length < max_returned_tokens - 1:
|
||||
raise NotImplementedError(
|
||||
f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
|
||||
)
|
||||
raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")
|
||||
|
||||
input_pos = torch.tensor([T], device=device)
|
||||
list_output = [[] for i in range(8)]
|
||||
|
||||
# Generate tokens
|
||||
tokens_A, token_T = next_token_batch(
|
||||
model,
|
||||
audio_feature.to(torch.float32).to(model.device),
|
||||
audio_feature.to(torch.float32).to(self.device),
|
||||
input_ids,
|
||||
[T - 3, T - 3],
|
||||
["A1T2", "A1T2"],
|
||||
input_pos=torch.arange(0, T, device=device),
|
||||
input_pos=torch.arange(0, T, device=self.device),
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
@@ -484,11 +511,12 @@ class OmniInference:
|
||||
list_output[i].append(tokens_A[i].tolist()[0])
|
||||
list_output[7].append(token_T.tolist()[0])
|
||||
|
||||
# Prepare model input IDs for the next iterations
|
||||
model_input_ids = [[] for i in range(8)]
|
||||
for i in range(7):
|
||||
tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
|
||||
model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
|
||||
model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=device))
|
||||
model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
|
||||
model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=self.device))
|
||||
model_input_ids[i] = torch.stack(model_input_ids[i])
|
||||
|
||||
model_input_ids[-1].append(token_T.clone().to(torch.int32))
|
||||
@@ -514,7 +542,7 @@ class OmniInference:
|
||||
)
|
||||
|
||||
if text_end:
|
||||
token_T = torch.tensor([_pad_t], device=device)
|
||||
token_T = torch.tensor([_pad_t], device=self.device)
|
||||
|
||||
if tokens_A[-1] == eos_id_a:
|
||||
break
|
||||
@@ -528,11 +556,9 @@ class OmniInference:
|
||||
|
||||
model_input_ids = [[] for i in range(8)]
|
||||
for i in range(7):
|
||||
tokens_A[i] = tokens_A[i].clone() +padded_text_vocabsize + i * padded_audio_vocabsize
|
||||
model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
|
||||
model_input_ids[i].append(
|
||||
torch.tensor([layershift(4097, i)], device=device)
|
||||
)
|
||||
tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
|
||||
model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
|
||||
model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=self.device))
|
||||
model_input_ids[i] = torch.stack(model_input_ids[i])
|
||||
|
||||
model_input_ids[-1].append(token_T.clone().to(torch.int32))
|
||||
@@ -552,6 +578,7 @@ class OmniInference:
|
||||
|
||||
input_pos = input_pos.add_(1)
|
||||
index += 1
|
||||
|
||||
text = self.text_tokenizer.decode(torch.tensor(list_output[-1]))
|
||||
print(f"text output: {text}")
|
||||
model.clear_kv_cache()
|
||||
@@ -716,4 +743,4 @@ def test_infer():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_infer()
|
||||
test_infer()
|
||||
@@ -453,7 +453,7 @@ class LLaMAMLP(nn.Module):
|
||||
|
||||
|
||||
class whisperMLP(nn.Module):
|
||||
def __init__(self, config: Config) -> None:
|
||||
def __init__(self, config) -> None:
|
||||
super().__init__()
|
||||
self.fc_1 = nn.Linear(config.whisper_adapter_dim, config.intermediate_size, bias=config.bias)
|
||||
self.fc_2 = nn.Linear(config.whisper_adapter_dim, config.intermediate_size, bias=config.bias)
|
||||
@@ -461,25 +461,33 @@ class whisperMLP(nn.Module):
|
||||
|
||||
self.config = config
|
||||
|
||||
# Pooling layer for dimensionality reduction (if necessary)
|
||||
self.pooling = nn.AdaptiveAvgPool1d(768)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
device = x.device
|
||||
|
||||
# Handle MPS-specific operations
|
||||
if device.type == 'mps':
|
||||
# Move the tensor to CPU if necessary (for unsupported ops)
|
||||
x = x.to('cpu')
|
||||
|
||||
x = x.view(x.size(0), -1)
|
||||
x = x.unsqueeze(1)
|
||||
x = self.pooling(x).squeeze(1) # Remove the channel dimension after pooling
|
||||
# Reshape the tensor for pooling
|
||||
x = x.view(x.size(0), -1) # Flatten dimensions except the batch size
|
||||
x = x.unsqueeze(1) # Add a channel dimension for pooling
|
||||
x = self.pooling(x).squeeze(1) # Apply pooling and remove the channel dimension after pooling
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
# Make sure the tensor is back on MPS if available
|
||||
if device.type == 'mps':
|
||||
x = x.to('mps')
|
||||
|
||||
# Perform the standard operations
|
||||
x_fc_1 = self.fc_1(x)
|
||||
x_fc_2 = self.fc_2(x)
|
||||
x = torch.nn.functional.silu(x_fc_1) * x_fc_2
|
||||
return self.proj(x)
|
||||
x = F.silu(x_fc_1) * x_fc_2
|
||||
|
||||
# Final projection
|
||||
return self.proj(x)
|
||||
|
||||
class GemmaMLP(LLaMAMLP):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
31
server.py
31
server.py
@@ -2,18 +2,20 @@ import flask
|
||||
import base64
|
||||
import tempfile
|
||||
import traceback
|
||||
|
||||
import torch
|
||||
from flask import Flask, Response, stream_with_context
|
||||
from inference import OmniInference
|
||||
|
||||
|
||||
class OmniChatServer(object):
|
||||
def __init__(self, ip='0.0.0.0', port=60808, run_app=True,
|
||||
ckpt_dir='./checkpoint', device='mps') -> None:
|
||||
ckpt_dir='./checkpoint', device=None) -> None:
|
||||
server = Flask(__name__)
|
||||
# CORS(server, resources=r"/*")
|
||||
# server.config["JSON_AS_ASCII"] = False
|
||||
|
||||
self.client = OmniInference(ckpt_dir, device)
|
||||
self.device = self.get_device(device)
|
||||
self.client = OmniInference(ckpt_dir, self.device)
|
||||
self.client.warm_up()
|
||||
|
||||
server.route("/chat", methods=["POST"])(self.chat)
|
||||
@@ -23,6 +25,22 @@ class OmniChatServer(object):
|
||||
else:
|
||||
self.server = server
|
||||
|
||||
def get_device(self, device):
    """Resolve the torch device string the server should run inference on.

    Args:
        device: ``None`` to auto-detect, or a requested device such as
            ``'cuda'``, ``'cuda:0'``, ``'mps'`` or ``'cpu'`` (a
            ``torch.device`` is accepted as well; it is converted with
            ``str``).

    Returns:
        The requested device when its backend is available, otherwise
        ``'cpu'``. With ``device=None`` the preference order is CUDA, then
        MPS, then CPU.
    """
    has_cuda = torch.cuda.is_available()
    # torch.backends.mps does not exist on older PyTorch builds; treat a
    # missing module as "MPS not available" instead of raising.
    mps_backend = getattr(torch.backends, 'mps', None)
    has_mps = mps_backend is not None and mps_backend.is_available()

    if device is None:
        if has_cuda:
            return 'cuda'
        return 'mps' if has_mps else 'cpu'

    # Match on the device type only, so an indexed request such as 'cuda:0'
    # is honoured rather than silently downgraded to CPU.
    requested = str(device)
    kind = requested.split(':', 1)[0]
    if kind == 'cuda' and has_cuda:
        return requested
    if kind == 'mps' and has_mps:
        # Return the bare backend name, dropping any index suffix.
        return 'mps'
    return 'cpu'
|
||||
|
||||
def chat(self) -> Response:
|
||||
|
||||
req_data = flask.request.get_json()
|
||||
@@ -50,12 +68,11 @@ def create_app():
|
||||
return server.server
|
||||
|
||||
|
||||
def serve(ip='0.0.0.0', port=60808):
|
||||
OmniChatServer(ip, port=port, run_app=True)
|
||||
def serve(ip='0.0.0.0', port=60808, device=None):
    """Construct an ``OmniChatServer`` with ``run_app=True``.

    Args:
        ip: Address forwarded to the server as its first argument.
        port: Port forwarded to the server.
        device: Device forwarded to the server; ``None`` is passed through
            unchanged for the server to resolve.
    """
    server_options = {'port': port, 'run_app': True, 'device': device}
    OmniChatServer(ip, **server_options)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
fire.Fire(serve)
|
||||
|
||||
fire.Fire(serve)
|
||||
Reference in New Issue
Block a user