mirror of
https://github.com/gpt-omni/mini-omni
synced 2024-11-24 21:14:01 +00:00
fix device
This commit is contained in:
parent
494177255d
commit
ae04f97733
1
.gitignore
vendored
1
.gitignore
vendored
@ -4,6 +4,7 @@
|
|||||||
checkpoint/
|
checkpoint/
|
||||||
checkpoint_bak/
|
checkpoint_bak/
|
||||||
output/
|
output/
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
|
@ -494,7 +494,7 @@ class OmniInference:
|
|||||||
if current_index == nums_generate:
|
if current_index == nums_generate:
|
||||||
current_index = 0
|
current_index = 0
|
||||||
snac = get_snac(list_output, index, nums_generate)
|
snac = get_snac(list_output, index, nums_generate)
|
||||||
audio_stream = generate_audio_data(snac, self.snacmodel)
|
audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
|
||||||
yield audio_stream
|
yield audio_stream
|
||||||
|
|
||||||
input_pos = input_pos.add_(1)
|
input_pos = input_pos.add_(1)
|
||||||
|
@ -21,8 +21,8 @@ def layershift(input_id, layer, stride=4160, shift=152000):
|
|||||||
return input_id + shift + layer * stride
|
return input_id + shift + layer * stride
|
||||||
|
|
||||||
|
|
||||||
def generate_audio_data(snac_tokens, snacmodel):
|
def generate_audio_data(snac_tokens, snacmodel, device=None):
|
||||||
audio = reconstruct_tensors(snac_tokens)
|
audio = reconstruct_tensors(snac_tokens, device)
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
audio_hat = snacmodel.decode(audio)
|
audio_hat = snacmodel.decode(audio)
|
||||||
audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
|
audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
|
||||||
@ -55,9 +55,12 @@ def reconscruct_snac(output_list):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def reconstruct_tensors(flattened_output):
|
def reconstruct_tensors(flattened_output, device=None):
|
||||||
"""Reconstructs the list of tensors from the flattened output."""
|
"""Reconstructs the list of tensors from the flattened output."""
|
||||||
|
|
||||||
|
if device is None:
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
def count_elements_between_hashes(lst):
|
def count_elements_between_hashes(lst):
|
||||||
try:
|
try:
|
||||||
# Find the index of the first '#'
|
# Find the index of the first '#'
|
||||||
@ -107,9 +110,9 @@ def reconstruct_tensors(flattened_output):
|
|||||||
tensor3.append(flattened_output[i + 6])
|
tensor3.append(flattened_output[i + 6])
|
||||||
tensor3.append(flattened_output[i + 7])
|
tensor3.append(flattened_output[i + 7])
|
||||||
codes = [
|
codes = [
|
||||||
list_to_torch_tensor(tensor1).cuda(),
|
list_to_torch_tensor(tensor1).to(device),
|
||||||
list_to_torch_tensor(tensor2).cuda(),
|
list_to_torch_tensor(tensor2).to(device),
|
||||||
list_to_torch_tensor(tensor3).cuda(),
|
list_to_torch_tensor(tensor3).to(device),
|
||||||
]
|
]
|
||||||
|
|
||||||
if n_tensors == 15:
|
if n_tensors == 15:
|
||||||
@ -133,10 +136,10 @@ def reconstruct_tensors(flattened_output):
|
|||||||
tensor4.append(flattened_output[i + 15])
|
tensor4.append(flattened_output[i + 15])
|
||||||
|
|
||||||
codes = [
|
codes = [
|
||||||
list_to_torch_tensor(tensor1).cuda(),
|
list_to_torch_tensor(tensor1).to(device),
|
||||||
list_to_torch_tensor(tensor2).cuda(),
|
list_to_torch_tensor(tensor2).to(device),
|
||||||
list_to_torch_tensor(tensor3).cuda(),
|
list_to_torch_tensor(tensor3).to(device),
|
||||||
list_to_torch_tensor(tensor4).cuda(),
|
list_to_torch_tensor(tensor4).to(device),
|
||||||
]
|
]
|
||||||
|
|
||||||
return codes
|
return codes
|
||||||
|
Loading…
Reference in New Issue
Block a user