Mirror of https://github.com/gpt-omni/mini-omni (synced 2024-11-21 15:27:37 +00:00)
fix device

parent 494177255d
commit ae04f97733
.gitignore (vendored)
@@ -4,6 +4,7 @@
 checkpoint/
+checkpoint_bak/
 output/
 .DS_Store
 
 __pycache__/
 *.py[cod]
@@ -494,7 +494,7 @@ class OmniInference:
                 if current_index == nums_generate:
                     current_index = 0
                     snac = get_snac(list_output, index, nums_generate)
-                    audio_stream = generate_audio_data(snac, self.snacmodel)
+                    audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
                     yield audio_stream
 
             input_pos = input_pos.add_(1)
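
The call site now forwards self.device instead of letting generate_audio_data assume CUDA. A minimal sketch (not part of this commit; the real OmniInference constructor is not shown in the diff) of how such a class might record the device it forwards:

import torch

class _DeviceAwareInference:
    def __init__(self, device: str = "cuda:0"):
        # Hypothetical: fall back to CPU when no GPU is present, so that a call
        # like generate_audio_data(snac, self.snacmodel, self.device) works anywhere.
        self.device = device if torch.cuda.is_available() else "cpu"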
@@ -21,8 +21,8 @@ def layershift(input_id, layer, stride=4160, shift=152000):
     return input_id + shift + layer * stride
 
 
-def generate_audio_data(snac_tokens, snacmodel):
-    audio = reconstruct_tensors(snac_tokens)
+def generate_audio_data(snac_tokens, snacmodel, device=None):
+    audio = reconstruct_tensors(snac_tokens, device)
     with torch.inference_mode():
         audio_hat = snacmodel.decode(audio)
     audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
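
generate_audio_data now takes an optional device and hands it to reconstruct_tensors; decoding itself still runs under torch.inference_mode(). The scaling by 32768.0 suggests the rest of the function (not shown in this hunk) packs the samples into 16-bit PCM; a rough standalone sketch of that step, under that assumption:

import numpy as np

samples = np.array([0.0, 0.5, -0.5])                              # floats in [-1, 1]
pcm16 = (samples * 32768.0).clip(-32768, 32767).astype(np.int16)  # scale and clamp
audio_bytes = pcm16.tobytes()                                     # raw little-endian PCM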
@@ -55,9 +55,12 @@ def reconscruct_snac(output_list):
     return output
 
 
-def reconstruct_tensors(flattened_output):
+def reconstruct_tensors(flattened_output, device=None):
     """Reconstructs the list of tensors from the flattened output."""
 
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     def count_elements_between_hashes(lst):
         try:
             # Find the index of the first '#'
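
The new default keeps the old behaviour on GPU machines while making CPU-only runs possible. The same fallback pattern, shown standalone as a small runnable sketch (pick_device is an illustrative name, not a function from the repo):

import torch

def pick_device(device=None):
    # Prefer an explicitly requested device; otherwise use CUDA when
    # available and fall back to the CPU.
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return device

print(pick_device())       # cuda on GPU machines, cpu elsewhere
print(pick_device("cpu"))  # always cpu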
@@ -107,9 +110,9 @@ def reconstruct_tensors(flattened_output):
             tensor3.append(flattened_output[i + 6])
             tensor3.append(flattened_output[i + 7])
             codes = [
-                list_to_torch_tensor(tensor1).cuda(),
-                list_to_torch_tensor(tensor2).cuda(),
-                list_to_torch_tensor(tensor3).cuda(),
+                list_to_torch_tensor(tensor1).to(device),
+                list_to_torch_tensor(tensor2).to(device),
+                list_to_torch_tensor(tensor3).to(device),
             ]
 
     if n_tensors == 15:
@@ -133,10 +136,10 @@ def reconstruct_tensors(flattened_output):
             tensor4.append(flattened_output[i + 15])
 
             codes = [
-                list_to_torch_tensor(tensor1).cuda(),
-                list_to_torch_tensor(tensor2).cuda(),
-                list_to_torch_tensor(tensor3).cuda(),
-                list_to_torch_tensor(tensor4).cuda(),
+                list_to_torch_tensor(tensor1).to(device),
+                list_to_torch_tensor(tensor2).to(device),
+                list_to_torch_tensor(tensor3).to(device),
+                list_to_torch_tensor(tensor4).to(device),
             ]
 
     return codes
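
Replacing .cuda() with .to(device) is what actually removes the hard CUDA dependency: .cuda() raises on CPU-only machines, while .to(device) moves the tensor to whatever device the caller resolved. A small illustration of the difference:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
t = torch.tensor([1, 2, 3])
t = t.to(device)   # no-op on CPU, host-to-device copy when device is "cuda"
print(t.device)
# t.cuda() here would raise if PyTorch was built without CUDA or no GPU is visible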