Merge pull request #5400 from thiswillbeyourgithub/fix_fallback_cuda

fix: if CUDA is not available, fall back to CPU
Timothy Jaeryang Baek 2024-09-14 20:57:14 +01:00 committed by GitHub
commit 2f9f568dd9
2 changed files with 29 additions and 1 deletion


@@ -39,6 +39,19 @@ def serve(
"/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib", "/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib",
] ]
) )
try:
    import torch

    assert torch.cuda.is_available(), "CUDA not available"
    typer.echo("CUDA seems to be working")
except Exception as e:
    typer.echo(
        "Error when testing CUDA but USE_CUDA_DOCKER is true. "
        "Resetting USE_CUDA_DOCKER to false and removing "
        f"LD_LIBRARY_PATH modifications: {e}"
    )
    os.environ["USE_CUDA_DOCKER"] = "false"
    os.environ["LD_LIBRARY_PATH"] = ":".join(LD_LIBRARY_PATH)
import open_webui.main  # we need to set environment variables before importing main
uvicorn.run(open_webui.main.app, host=host, port=port, forwarded_allow_ips="*")
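In short, the serve command now probes CUDA at startup and, if the probe fails, resets the USE_CUDA_DOCKER flag and undoes the LD_LIBRARY_PATH additions before the app is imported. A minimal standalone sketch of that pattern, assuming torch is installed (the environment-variable names come from the diff above; everything else is illustrative, not the repository's exact code):

import os

def resolve_device() -> str:
    # CUDA is only attempted when explicitly requested via the env var.
    if os.environ.get("USE_CUDA_DOCKER", "false").lower() != "true":
        return "cpu"
    try:
        import torch  # imported lazily so a CPU-only install still works

        assert torch.cuda.is_available(), "CUDA not available"
        return "cuda"
    except Exception as exc:
        # Downgrade the environment so any code reading it later sees a CPU setup.
        print(f"CUDA requested but unusable, falling back to CPU: {exc}")
        os.environ["USE_CUDA_DOCKER"] = "false"
        return "cpu"

Running the CLI with USE_CUDA_DOCKER=true on a host without a working GPU should now emit the warning and continue on CPU rather than failing.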


@@ -36,7 +36,19 @@ except ImportError:
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
if USE_CUDA.lower() == "true":
try:
    import torch

    assert torch.cuda.is_available(), "CUDA not available"
    DEVICE_TYPE = "cuda"
except Exception as e:
    cuda_error = (
        "Error when testing CUDA but USE_CUDA_DOCKER is true. "
        f"Resetting USE_CUDA_DOCKER to false: {e}"
    )
    os.environ["USE_CUDA_DOCKER"] = "false"
    USE_CUDA = "false"
    DEVICE_TYPE = "cpu"
else:
    DEVICE_TYPE = "cpu"
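Once set, DEVICE_TYPE is an ordinary string that downstream code can hand to torch. A condensed, self-contained version of the selection logic plus a hypothetical consumer, assuming torch is installed (nothing below is taken verbatim from the repository):

import os

import torch

# "cuda" only when it was requested via the env var and is actually usable.
DEVICE_TYPE = (
    "cuda"
    if os.environ.get("USE_CUDA_DOCKER", "false").lower() == "true"
    and torch.cuda.is_available()
    else "cpu"
)

x = torch.zeros(4, device=DEVICE_TYPE)  # works on both CPU-only and GPU hosts
print(x.device)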
@@ -56,6 +68,9 @@ else:
log = logging.getLogger(__name__)
log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")

if "cuda_error" in locals():
    log.exception(cuda_error)

log_sources = [
    "AUDIO",
    "COMFYUI",