From d2f10d50bf30e77a0355bf9511bd56060580656d Mon Sep 17 00:00:00 2001 From: Craig Quiter Date: Sun, 18 Aug 2024 14:17:26 -0700 Subject: [PATCH] Allow seting CORS origin --- backend/apps/audio/main.py | 3 ++- backend/apps/images/main.py | 3 ++- backend/apps/ollama/main.py | 3 ++- backend/apps/openai/main.py | 3 ++- backend/apps/rag/main.py | 6 ++---- backend/apps/webui/main.py | 5 ++--- backend/config.py | 31 +++++++++++++++++++++++++++++++ backend/main.py | 5 ++--- 8 files changed, 45 insertions(+), 14 deletions(-) diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py index 2121ffe6f..d66a9fa11 100644 --- a/backend/apps/audio/main.py +++ b/backend/apps/audio/main.py @@ -38,6 +38,7 @@ from config import ( AUDIO_TTS_MODEL, AUDIO_TTS_VOICE, AppConfig, + CORS_ALLOW_ORIGIN, ) from constants import ERROR_MESSAGES from utils.utils import ( @@ -52,7 +53,7 @@ log.setLevel(SRC_LOG_LEVELS["AUDIO"]) app = FastAPI() app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=CORS_ALLOW_ORIGIN, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index d2f5ddd5d..ad7d1a234 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -51,6 +51,7 @@ from config import ( IMAGE_SIZE, IMAGE_STEPS, AppConfig, + CORS_ALLOW_ORIGIN, ) log = logging.getLogger(__name__) @@ -62,7 +63,7 @@ IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True) app = FastAPI() app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=CORS_ALLOW_ORIGIN, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 810a05999..cc82089b5 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -41,6 +41,7 @@ from config import ( MODEL_FILTER_LIST, UPLOAD_DIR, AppConfig, + CORS_ALLOW_ORIGIN, ) from utils.misc import ( calculate_sha256, @@ -55,7 +56,7 @@ log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) app = FastAPI() app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=CORS_ALLOW_ORIGIN, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index d344c6622..9ad67c40c 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -32,6 +32,7 @@ from config import ( ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, AppConfig, + CORS_ALLOW_ORIGIN, ) from typing import Optional, Literal, overload @@ -45,7 +46,7 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"]) app = FastAPI() app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=CORS_ALLOW_ORIGIN, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index f9788556b..7b2fbc679 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -129,6 +129,7 @@ from config import ( RAG_WEB_SEARCH_RESULT_COUNT, RAG_WEB_SEARCH_CONCURRENT_REQUESTS, RAG_EMBEDDING_OPENAI_BATCH_SIZE, + CORS_ALLOW_ORIGIN, ) from constants import ERROR_MESSAGES @@ -240,12 +241,9 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, ) -origins = ["*"] - - app.add_middleware( CORSMiddleware, - allow_origins=origins, + allow_origins=CORS_ALLOW_ORIGIN, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index dddf3fbb2..2fd73a22c 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -47,6 +47,7 @@ from config import ( OAUTH_USERNAME_CLAIM, OAUTH_PICTURE_CLAIM, OAUTH_EMAIL_CLAIM, + CORS_ALLOW_ORIGIN, ) from apps.socket.main import get_event_call, get_event_emitter @@ -59,8 +60,6 @@ from pydantic import BaseModel app = FastAPI() -origins = ["*"] - app.state.config = AppConfig() app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP @@ -93,7 +92,7 @@ app.state.FUNCTIONS = {} app.add_middleware( CORSMiddleware, - allow_origins=origins, + allow_origins=CORS_ALLOW_ORIGIN, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], diff --git a/backend/config.py b/backend/config.py index 3453ee24b..46e9b1a73 100644 --- a/backend/config.py +++ b/backend/config.py @@ -3,6 +3,8 @@ import sys import logging import importlib.metadata import pkgutil +from urllib.parse import urlparse + import chromadb from chromadb import Settings from bs4 import BeautifulSoup @@ -840,6 +842,35 @@ ENABLE_COMMUNITY_SHARING = PersistentConfig( os.environ.get("ENABLE_COMMUNITY_SHARING", "True").lower() == "true", ) +def validate_cors_origins(origins): + for origin in origins: + if origin != "*": + validate_cors_origin(origin) + + +def validate_cors_origin(origin): + parsed_url = urlparse(origin) + + # Check if the scheme is either http or https + if parsed_url.scheme not in ["http", "https"]: + raise ValueError(f"Invalid scheme in CORS_ALLOW_ORIGIN: '{origin}'. Only 'http' and 'https' are allowed.") + + # Ensure that the netloc (domain + port) is present, indicating it's a valid URL + if not parsed_url.netloc: + raise ValueError(f"Invalid URL structure in CORS_ALLOW_ORIGIN: '{origin}'.") + + +# For production, you should only need one host as +# fastapi serves the svelte-kit built frontend and backend from the same host and port. +# To test CORS_ALLOW_ORIGIN locally, you can set something like +# CORS_ALLOW_ORIGIN=http://localhost:5173;http://localhost:8080 +# in your .env file depending on your frontend port, 5173 in this case. +CORS_ALLOW_ORIGIN = os.environ.get("CORS_ALLOW_ORIGIN", "*").split(";") + +if "*" in CORS_ALLOW_ORIGIN: + log.warning("\n\nWARNING: CORS_ALLOW_ORIGIN IS SET TO '*' - NOT RECOMMENDED FOR PRODUCTION DEPLOYMENTS.\n") + +validate_cors_origins(CORS_ALLOW_ORIGIN) class BannerModel(BaseModel): id: str diff --git a/backend/main.py b/backend/main.py index d539834ed..7377cc2df 100644 --- a/backend/main.py +++ b/backend/main.py @@ -119,6 +119,7 @@ from config import ( WEBUI_SESSION_COOKIE_SECURE, ENABLE_ADMIN_CHAT_ACCESS, AppConfig, + CORS_ALLOW_ORIGIN, ) from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS @@ -209,8 +210,6 @@ app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( app.state.MODELS = {} -origins = ["*"] - ################################## # @@ -833,7 +832,7 @@ app.add_middleware(PipelineMiddleware) app.add_middleware( CORSMiddleware, - allow_origins=origins, + allow_origins=CORS_ALLOW_ORIGIN, allow_credentials=True, allow_methods=["*"], allow_headers=["*"],