diff --git a/.gitignore b/.gitignore index efd1a54..d0158c6 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ checkpoint/ checkpoint_bak/ output/ +.DS_Store __pycache__/ *.py[cod] diff --git a/inference.py b/inference.py index 598c6cb..4d721d0 100644 --- a/inference.py +++ b/inference.py @@ -494,7 +494,7 @@ class OmniInference: if current_index == nums_generate: current_index = 0 snac = get_snac(list_output, index, nums_generate) - audio_stream = generate_audio_data(snac, self.snacmodel) + audio_stream = generate_audio_data(snac, self.snacmodel, self.device) yield audio_stream input_pos = input_pos.add_(1) diff --git a/utils/snac_utils.py b/utils/snac_utils.py index a2a4a6a..e66eccc 100644 --- a/utils/snac_utils.py +++ b/utils/snac_utils.py @@ -21,8 +21,8 @@ def layershift(input_id, layer, stride=4160, shift=152000): return input_id + shift + layer * stride -def generate_audio_data(snac_tokens, snacmodel): - audio = reconstruct_tensors(snac_tokens) +def generate_audio_data(snac_tokens, snacmodel, device=None): + audio = reconstruct_tensors(snac_tokens, device) with torch.inference_mode(): audio_hat = snacmodel.decode(audio) audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0 @@ -55,9 +55,12 @@ def reconscruct_snac(output_list): return output -def reconstruct_tensors(flattened_output): +def reconstruct_tensors(flattened_output, device=None): """Reconstructs the list of tensors from the flattened output.""" + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + def count_elements_between_hashes(lst): try: # Find the index of the first '#' @@ -107,9 +110,9 @@ def reconstruct_tensors(flattened_output): tensor3.append(flattened_output[i + 6]) tensor3.append(flattened_output[i + 7]) codes = [ - list_to_torch_tensor(tensor1).cuda(), - list_to_torch_tensor(tensor2).cuda(), - list_to_torch_tensor(tensor3).cuda(), + list_to_torch_tensor(tensor1).to(device), + list_to_torch_tensor(tensor2).to(device), + list_to_torch_tensor(tensor3).to(device), ] if n_tensors == 15: @@ -133,10 +136,10 @@ def reconstruct_tensors(flattened_output): tensor4.append(flattened_output[i + 15]) codes = [ - list_to_torch_tensor(tensor1).cuda(), - list_to_torch_tensor(tensor2).cuda(), - list_to_torch_tensor(tensor3).cuda(), - list_to_torch_tensor(tensor4).cuda(), + list_to_torch_tensor(tensor1).to(device), + list_to_torch_tensor(tensor2).to(device), + list_to_torch_tensor(tensor3).to(device), + list_to_torch_tensor(tensor4).to(device), ] return codes