fix device

This commit is contained in:
mini-omni 2024-09-04 23:00:47 +03:00
parent 494177255d
commit ae04f97733
3 changed files with 15 additions and 11 deletions

1
.gitignore vendored
View File

@ -4,6 +4,7 @@
checkpoint/ checkpoint/
checkpoint_bak/ checkpoint_bak/
output/ output/
.DS_Store
__pycache__/ __pycache__/
*.py[cod] *.py[cod]

View File

@ -494,7 +494,7 @@ class OmniInference:
if current_index == nums_generate: if current_index == nums_generate:
current_index = 0 current_index = 0
snac = get_snac(list_output, index, nums_generate) snac = get_snac(list_output, index, nums_generate)
audio_stream = generate_audio_data(snac, self.snacmodel) audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
yield audio_stream yield audio_stream
input_pos = input_pos.add_(1) input_pos = input_pos.add_(1)

View File

@ -21,8 +21,8 @@ def layershift(input_id, layer, stride=4160, shift=152000):
return input_id + shift + layer * stride return input_id + shift + layer * stride
def generate_audio_data(snac_tokens, snacmodel): def generate_audio_data(snac_tokens, snacmodel, device=None):
audio = reconstruct_tensors(snac_tokens) audio = reconstruct_tensors(snac_tokens, device)
with torch.inference_mode(): with torch.inference_mode():
audio_hat = snacmodel.decode(audio) audio_hat = snacmodel.decode(audio)
audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0 audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
@ -55,9 +55,12 @@ def reconscruct_snac(output_list):
return output return output
def reconstruct_tensors(flattened_output): def reconstruct_tensors(flattened_output, device=None):
"""Reconstructs the list of tensors from the flattened output.""" """Reconstructs the list of tensors from the flattened output."""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def count_elements_between_hashes(lst): def count_elements_between_hashes(lst):
try: try:
# Find the index of the first '#' # Find the index of the first '#'
@ -107,9 +110,9 @@ def reconstruct_tensors(flattened_output):
tensor3.append(flattened_output[i + 6]) tensor3.append(flattened_output[i + 6])
tensor3.append(flattened_output[i + 7]) tensor3.append(flattened_output[i + 7])
codes = [ codes = [
list_to_torch_tensor(tensor1).cuda(), list_to_torch_tensor(tensor1).to(device),
list_to_torch_tensor(tensor2).cuda(), list_to_torch_tensor(tensor2).to(device),
list_to_torch_tensor(tensor3).cuda(), list_to_torch_tensor(tensor3).to(device),
] ]
if n_tensors == 15: if n_tensors == 15:
@ -133,10 +136,10 @@ def reconstruct_tensors(flattened_output):
tensor4.append(flattened_output[i + 15]) tensor4.append(flattened_output[i + 15])
codes = [ codes = [
list_to_torch_tensor(tensor1).cuda(), list_to_torch_tensor(tensor1).to(device),
list_to_torch_tensor(tensor2).cuda(), list_to_torch_tensor(tensor2).to(device),
list_to_torch_tensor(tensor3).cuda(), list_to_torch_tensor(tensor3).to(device),
list_to_torch_tensor(tensor4).cuda(), list_to_torch_tensor(tensor4).to(device),
] ]
return codes return codes