fix device

mini-omni 2024-09-04 23:00:47 +03:00
parent 494177255d
commit ae04f97733
3 changed files with 15 additions and 11 deletions

.gitignore

@@ -4,6 +4,7 @@
 checkpoint/
 checkpoint_bak/
 output/
+.DS_Store
 __pycache__/
 *.py[cod]

@@ -494,7 +494,7 @@ class OmniInference:
             if current_index == nums_generate:
                 current_index = 0
                 snac = get_snac(list_output, index, nums_generate)
-                audio_stream = generate_audio_data(snac, self.snacmodel)
+                audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
                 yield audio_stream
             input_pos = input_pos.add_(1)
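The changed call above now forwards self.device into generate_audio_data. This diff does not show where that attribute is initialized, so the snippet below is only a hedged sketch, with a hypothetical DeviceAwareInference class standing in for the real constructor, of how such a device attribute is commonly resolved: honor an explicit argument, otherwise fall back to CUDA only when it is available.

import torch

class DeviceAwareInference:
    def __init__(self, device=None):
        # Hypothetical constructor, not part of this commit: resolve the
        # device once and store it so later calls (such as the
        # generate_audio_data call above) can reuse it instead of assuming CUDA.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

print(DeviceAwareInference().device)  # cuda on GPU machines, cpu elsewhere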


@@ -21,8 +21,8 @@ def layershift(input_id, layer, stride=4160, shift=152000):
     return input_id + shift + layer * stride

-def generate_audio_data(snac_tokens, snacmodel):
-    audio = reconstruct_tensors(snac_tokens)
+def generate_audio_data(snac_tokens, snacmodel, device=None):
+    audio = reconstruct_tensors(snac_tokens, device)
     with torch.inference_mode():
         audio_hat = snacmodel.decode(audio)
     audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
@@ -55,9 +55,12 @@ def reconscruct_snac(output_list):
     return output

-def reconstruct_tensors(flattened_output):
+def reconstruct_tensors(flattened_output, device=None):
     """Reconstructs the list of tensors from the flattened output."""
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     def count_elements_between_hashes(lst):
         try:
             # Find the index of the first '#'
@@ -107,9 +110,9 @@ def reconstruct_tensors(flattened_output):
             tensor3.append(flattened_output[i + 6])
             tensor3.append(flattened_output[i + 7])
             codes = [
-                list_to_torch_tensor(tensor1).cuda(),
-                list_to_torch_tensor(tensor2).cuda(),
-                list_to_torch_tensor(tensor3).cuda(),
+                list_to_torch_tensor(tensor1).to(device),
+                list_to_torch_tensor(tensor2).to(device),
+                list_to_torch_tensor(tensor3).to(device),
             ]

     if n_tensors == 15:
@@ -133,10 +136,10 @@ def reconstruct_tensors(flattened_output):
             tensor4.append(flattened_output[i + 15])
             codes = [
-                list_to_torch_tensor(tensor1).cuda(),
-                list_to_torch_tensor(tensor2).cuda(),
-                list_to_torch_tensor(tensor3).cuda(),
-                list_to_torch_tensor(tensor4).cuda(),
+                list_to_torch_tensor(tensor1).to(device),
+                list_to_torch_tensor(tensor2).to(device),
+                list_to_torch_tensor(tensor3).to(device),
+                list_to_torch_tensor(tensor4).to(device),
             ]

     return codes
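Taken together, the commit replaces hard-coded .cuda() calls with .to(device) plus a CUDA-if-available fallback, so the reconstructed SNAC code tensors no longer require a GPU. Below is a standalone sketch of that pattern; the pick_device helper and the placeholder token values are illustrative and not part of the repository.

import torch

def pick_device(device=None):
    # Prefer an explicit device, otherwise CUDA when present, else CPU.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device)

device = pick_device()

# Before the fix: tensor.cuda() fails on CPU-only machines.
# After the fix: tensor.to(device) works on both CPU and CUDA.
codes = [torch.tensor([1, 2, 3]).unsqueeze(0).to(device) for _ in range(3)]
print([c.device for c in codes])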