fix: torch mps not working

Co-Authored-By: Rich Tong <1782087+richtong@users.noreply.github.com>
This commit is contained in:
Timothy Jaeryang Baek 2025-01-06 10:08:12 -08:00
parent 15a182c9d6
commit 960683eced

View File

@ -54,6 +54,8 @@ else:
DEVICE_TYPE = "cpu"
try:
import torch
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
DEVICE_TYPE = "mps"
except Exception: