[edit] adjustment that will detect using gpu(mps/cuda) and cpu, so both windows and mac will be work

This commit is contained in:
kunci115
2024-09-07 14:53:02 +07:00
parent 01c4e5e133
commit 962e71218c
3 changed files with 124 additions and 72 deletions

View File

@@ -351,56 +351,60 @@ def T1_T2(fabric, input_ids, model, text_tokenizer, step):
def load_model(ckpt_dir, device):
# Load the SNAC model
# Load SNAC model
snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
# Load Whisper model for transcription
whispermodel = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").to(device)
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
# Load Whisper model for transcription based on the device
if device == 'mps':
whispermodel = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").to(device)
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
def embed_audio(mel):
# Convert the mel spectrogram to a format suitable for Whisper
mel_cpu = mel.cpu().numpy()
def embed_audio(mel):
# Convert mel spectrogram to a format suitable for Whisper
mel_cpu = mel.cpu().numpy()
if mel_cpu.ndim > 1:
mel_cpu = librosa.to_mono(mel_cpu)
if mel_cpu.ndim > 1:
mel_cpu = librosa.to_mono(mel_cpu)
if mel_cpu.ndim != 1:
raise ValueError(f"Audio is not mono! Shape: {mel_cpu.shape}")
if mel_cpu.ndim != 1:
raise ValueError(f"Audio is not mono! Shape: {mel_cpu.shape}")
# Process mel with Whisper processor
inputs = whisper_processor(mel_cpu, sampling_rate=16000, return_tensors="pt").input_features.to(device)
# Process mel with Whisper processor
inputs = whisper_processor(mel_cpu, sampling_rate=16000, return_tensors="pt").input_features.to(device)
# Debugging: Log input feature shapes
print(f"Input features shape: {inputs.shape}")
print(f"First 5 values of input features: {inputs[0, :, :5]}")
# Debugging: Log input feature shapes
print(f"Input features shape: {inputs.shape}")
print(f"First 5 values of input features: {inputs[0, :, :5]}")
# Set beam search and max length to force more tokens
generated_ids = whispermodel.generate(
inputs,
max_length=256, # Increased length to allow for longer sentences
num_beams=5, # Higher quality output through beam search
early_stopping=True
)
# Set beam search and max length for transcription
generated_ids = whispermodel.generate(
inputs,
max_length=256, # Increased length to allow for longer sentences
num_beams=5, # Higher quality output through beam search
early_stopping=True
)
# Decode the generated tokens
transcription = whisper_processor.decode(generated_ids[0], skip_special_tokens=True)
# Decode the generated tokens
transcription = whisper_processor.decode(generated_ids[0], skip_special_tokens=True)
# Log the transcription and token IDs
print(f"Generated token IDs: {generated_ids[0].tolist()}")
print(f"Transcription (for log purposes): {transcription}")
# Log the transcription and token IDs
print(f"Generated token IDs: {generated_ids[0].tolist()}")
print(f"Transcription (for log purposes): {transcription}")
# Now return logits as per original flow
start_token = whisper_processor.tokenizer.pad_token_id
decoder_input_ids = torch.tensor([[start_token]], device=device)
# Return logits as per original flow
start_token = whisper_processor.tokenizer.pad_token_id
decoder_input_ids = torch.tensor([[start_token]], device=device)
outputs = whispermodel(inputs, decoder_input_ids=decoder_input_ids)
outputs = whispermodel(inputs, decoder_input_ids=decoder_input_ids)
return outputs.logits
return outputs.logits
whispermodel.embed_audio = embed_audio
whispermodel.embed_audio = embed_audio
else:
# If not MPS, use the regular Whisper model loading
whispermodel = whisper.load_model("small").to(device)
# Load the GPT model as per your existing code
# Load the GPT model with Fabric
text_tokenizer = Tokenizer(ckpt_dir)
fabric = L.Fabric(devices=1, strategy="auto")
config = Config.from_file(ckpt_dir + "/model_config.yaml")
@@ -417,22 +421,43 @@ def load_model(ckpt_dir, device):
return fabric, model, text_tokenizer, snacmodel, whispermodel
def download_model(ckpt_dir):
repo_id = "gpt-omni/mini-omni"
snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
class OmniInference:
def __init__(self, ckpt_dir='./checkpoint', device=None):
# Dynamically determine device (similar to OmniChatServer's get_device)
self.device = self.get_device(device)
def __init__(self, ckpt_dir='./checkpoint', device='mps'):
self.device = device
# If checkpoint directory does not exist, download it
if not os.path.exists(ckpt_dir):
print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
print(f"Checkpoint directory {ckpt_dir} not found, downloading from HuggingFace.")
download_model(ckpt_dir)
self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)
# Load models (SNAC, Whisper, GPT)
self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir,
self.device)
def get_device(self, device):
if device is None:
if torch.cuda.is_available():
return 'cuda'
elif torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
else:
if device == 'cuda' and torch.cuda.is_available():
return 'cuda'
elif device == 'mps' and torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
def warm_up(self, sample='./data/samples/output1.wav'):
# Run a warm-up pass by running the AT batch stream
for _ in self.run_AT_batch_stream(sample):
pass
@@ -445,15 +470,17 @@ class OmniInference:
top_k=1,
top_p=1.0,
eos_id_a=_eoa,
eos_id_t=_eot,
):
eos_id_t=_eot):
assert os.path.exists(audio_path), f"Audio file {audio_path} not found"
assert os.path.exists(audio_path), f"audio file {audio_path} not found"
model = self.model
# Initialize kv cache with dynamic device handling
with self.fabric.init_tensor():
model.set_kv_cache(batch_size=2, device='mps')
model.set_kv_cache(batch_size=2, device=self.device)
# Load the audio and process it using Whisper
mel, leng = load_audio(audio_path)
audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
T = input_ids[0].size(1)
@@ -462,19 +489,19 @@ class OmniInference:
assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"
if model.max_seq_length < max_returned_tokens - 1:
raise NotImplementedError(
f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
)
raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")
input_pos = torch.tensor([T], device=device)
list_output = [[] for i in range(8)]
# Generate tokens
tokens_A, token_T = next_token_batch(
model,
audio_feature.to(torch.float32).to(model.device),
audio_feature.to(torch.float32).to(self.device),
input_ids,
[T - 3, T - 3],
["A1T2", "A1T2"],
input_pos=torch.arange(0, T, device=device),
input_pos=torch.arange(0, T, device=self.device),
temperature=temperature,
top_k=top_k,
top_p=top_p,
@@ -484,11 +511,12 @@ class OmniInference:
list_output[i].append(tokens_A[i].tolist()[0])
list_output[7].append(token_T.tolist()[0])
# Prepare model input IDs for the next iterations
model_input_ids = [[] for i in range(8)]
for i in range(7):
tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=device))
model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=self.device))
model_input_ids[i] = torch.stack(model_input_ids[i])
model_input_ids[-1].append(token_T.clone().to(torch.int32))
@@ -514,7 +542,7 @@ class OmniInference:
)
if text_end:
token_T = torch.tensor([_pad_t], device=device)
token_T = torch.tensor([_pad_t], device=self.device)
if tokens_A[-1] == eos_id_a:
break
@@ -528,11 +556,9 @@ class OmniInference:
model_input_ids = [[] for i in range(8)]
for i in range(7):
tokens_A[i] = tokens_A[i].clone() +padded_text_vocabsize + i * padded_audio_vocabsize
model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
model_input_ids[i].append(
torch.tensor([layershift(4097, i)], device=device)
)
tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=self.device))
model_input_ids[i] = torch.stack(model_input_ids[i])
model_input_ids[-1].append(token_T.clone().to(torch.int32))
@@ -552,6 +578,7 @@ class OmniInference:
input_pos = input_pos.add_(1)
index += 1
text = self.text_tokenizer.decode(torch.tensor(list_output[-1]))
print(f"text output: {text}")
model.clear_kv_cache()
@@ -716,4 +743,4 @@ def test_infer():
if __name__ == "__main__":
test_infer()
test_infer()

View File

@@ -453,7 +453,7 @@ class LLaMAMLP(nn.Module):
class whisperMLP(nn.Module):
def __init__(self, config: Config) -> None:
def __init__(self, config) -> None:
super().__init__()
self.fc_1 = nn.Linear(config.whisper_adapter_dim, config.intermediate_size, bias=config.bias)
self.fc_2 = nn.Linear(config.whisper_adapter_dim, config.intermediate_size, bias=config.bias)
@@ -461,25 +461,33 @@ class whisperMLP(nn.Module):
self.config = config
# Pooling layer for dimensionality reduction (if necessary)
self.pooling = nn.AdaptiveAvgPool1d(768)
def forward(self, x: torch.Tensor) -> torch.Tensor:
device = x.device
# Handle MPS-specific operations
if device.type == 'mps':
# Move the tensor to CPU if necessary (for unsupported ops)
x = x.to('cpu')
x = x.view(x.size(0), -1)
x = x.unsqueeze(1)
x = self.pooling(x).squeeze(1) # Remove the channel dimension after pooling
# Reshape the tensor for pooling
x = x.view(x.size(0), -1) # Flatten dimensions except the batch size
x = x.unsqueeze(1) # Add a channel dimension for pooling
x = self.pooling(x).squeeze(1) # Apply pooling and remove the channel dimension after pooling
if torch.backends.mps.is_available():
# Make sure the tensor is back on MPS if available
if device.type == 'mps':
x = x.to('mps')
# Perform the standard operations
x_fc_1 = self.fc_1(x)
x_fc_2 = self.fc_2(x)
x = torch.nn.functional.silu(x_fc_1) * x_fc_2
return self.proj(x)
x = F.silu(x_fc_1) * x_fc_2
# Final projection
return self.proj(x)
class GemmaMLP(LLaMAMLP):
def forward(self, x: torch.Tensor) -> torch.Tensor:

View File

@@ -2,18 +2,20 @@ import flask
import base64
import tempfile
import traceback
import torch
from flask import Flask, Response, stream_with_context
from inference import OmniInference
class OmniChatServer(object):
def __init__(self, ip='0.0.0.0', port=60808, run_app=True,
ckpt_dir='./checkpoint', device='mps') -> None:
ckpt_dir='./checkpoint', device=None) -> None:
server = Flask(__name__)
# CORS(server, resources=r"/*")
# server.config["JSON_AS_ASCII"] = False
self.client = OmniInference(ckpt_dir, device)
self.device = self.get_device(device)
self.client = OmniInference(ckpt_dir, self.device)
self.client.warm_up()
server.route("/chat", methods=["POST"])(self.chat)
@@ -23,6 +25,22 @@ class OmniChatServer(object):
else:
self.server = server
def get_device(self, device):
if device is None:
if torch.cuda.is_available():
return 'cuda'
elif torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
else:
if device == 'cuda' and torch.cuda.is_available():
return 'cuda'
elif device == 'mps' and torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
def chat(self) -> Response:
req_data = flask.request.get_json()
@@ -50,12 +68,11 @@ def create_app():
return server.server
def serve(ip='0.0.0.0', port=60808):
OmniChatServer(ip, port=port, run_app=True)
def serve(ip='0.0.0.0', port=60808, device=None):
OmniChatServer(ip, port=port, run_app=True, device=device)
if __name__ == "__main__":
import fire
fire.Fire(serve)
fire.Fire(serve)