Allow seting CORS origin

This commit is contained in:
Craig Quiter 2024-08-18 14:17:26 -07:00
parent 446b2a334a
commit d2f10d50bf
8 changed files with 45 additions and 14 deletions

View File

@ -38,6 +38,7 @@ from config import (
AUDIO_TTS_MODEL, AUDIO_TTS_MODEL,
AUDIO_TTS_VOICE, AUDIO_TTS_VOICE,
AppConfig, AppConfig,
CORS_ALLOW_ORIGIN,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from utils.utils import ( from utils.utils import (
@ -52,7 +53,7 @@ log.setLevel(SRC_LOG_LEVELS["AUDIO"])
app = FastAPI() app = FastAPI()
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["*"], allow_origins=CORS_ALLOW_ORIGIN,
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],

View File

@ -51,6 +51,7 @@ from config import (
IMAGE_SIZE, IMAGE_SIZE,
IMAGE_STEPS, IMAGE_STEPS,
AppConfig, AppConfig,
CORS_ALLOW_ORIGIN,
) )
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -62,7 +63,7 @@ IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
app = FastAPI() app = FastAPI()
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["*"], allow_origins=CORS_ALLOW_ORIGIN,
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],

View File

@ -41,6 +41,7 @@ from config import (
MODEL_FILTER_LIST, MODEL_FILTER_LIST,
UPLOAD_DIR, UPLOAD_DIR,
AppConfig, AppConfig,
CORS_ALLOW_ORIGIN,
) )
from utils.misc import ( from utils.misc import (
calculate_sha256, calculate_sha256,
@ -55,7 +56,7 @@ log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
app = FastAPI() app = FastAPI()
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["*"], allow_origins=CORS_ALLOW_ORIGIN,
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],

View File

@ -32,6 +32,7 @@ from config import (
ENABLE_MODEL_FILTER, ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST, MODEL_FILTER_LIST,
AppConfig, AppConfig,
CORS_ALLOW_ORIGIN,
) )
from typing import Optional, Literal, overload from typing import Optional, Literal, overload
@ -45,7 +46,7 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"])
app = FastAPI() app = FastAPI()
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["*"], allow_origins=CORS_ALLOW_ORIGIN,
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],

View File

@ -129,6 +129,7 @@ from config import (
RAG_WEB_SEARCH_RESULT_COUNT, RAG_WEB_SEARCH_RESULT_COUNT,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS, RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
RAG_EMBEDDING_OPENAI_BATCH_SIZE, RAG_EMBEDDING_OPENAI_BATCH_SIZE,
CORS_ALLOW_ORIGIN,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -240,12 +241,9 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
) )
origins = ["*"]
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=origins, allow_origins=CORS_ALLOW_ORIGIN,
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],

View File

@ -47,6 +47,7 @@ from config import (
OAUTH_USERNAME_CLAIM, OAUTH_USERNAME_CLAIM,
OAUTH_PICTURE_CLAIM, OAUTH_PICTURE_CLAIM,
OAUTH_EMAIL_CLAIM, OAUTH_EMAIL_CLAIM,
CORS_ALLOW_ORIGIN,
) )
from apps.socket.main import get_event_call, get_event_emitter from apps.socket.main import get_event_call, get_event_emitter
@ -59,8 +60,6 @@ from pydantic import BaseModel
app = FastAPI() app = FastAPI()
origins = ["*"]
app.state.config = AppConfig() app.state.config = AppConfig()
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
@ -93,7 +92,7 @@ app.state.FUNCTIONS = {}
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=origins, allow_origins=CORS_ALLOW_ORIGIN,
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],

View File

@ -3,6 +3,8 @@ import sys
import logging import logging
import importlib.metadata import importlib.metadata
import pkgutil import pkgutil
from urllib.parse import urlparse
import chromadb import chromadb
from chromadb import Settings from chromadb import Settings
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
@ -840,6 +842,35 @@ ENABLE_COMMUNITY_SHARING = PersistentConfig(
os.environ.get("ENABLE_COMMUNITY_SHARING", "True").lower() == "true", os.environ.get("ENABLE_COMMUNITY_SHARING", "True").lower() == "true",
) )
def validate_cors_origins(origins):
for origin in origins:
if origin != "*":
validate_cors_origin(origin)
def validate_cors_origin(origin):
parsed_url = urlparse(origin)
# Check if the scheme is either http or https
if parsed_url.scheme not in ["http", "https"]:
raise ValueError(f"Invalid scheme in CORS_ALLOW_ORIGIN: '{origin}'. Only 'http' and 'https' are allowed.")
# Ensure that the netloc (domain + port) is present, indicating it's a valid URL
if not parsed_url.netloc:
raise ValueError(f"Invalid URL structure in CORS_ALLOW_ORIGIN: '{origin}'.")
# For production, you should only need one host as
# fastapi serves the svelte-kit built frontend and backend from the same host and port.
# To test CORS_ALLOW_ORIGIN locally, you can set something like
# CORS_ALLOW_ORIGIN=http://localhost:5173;http://localhost:8080
# in your .env file depending on your frontend port, 5173 in this case.
CORS_ALLOW_ORIGIN = os.environ.get("CORS_ALLOW_ORIGIN", "*").split(";")
if "*" in CORS_ALLOW_ORIGIN:
log.warning("\n\nWARNING: CORS_ALLOW_ORIGIN IS SET TO '*' - NOT RECOMMENDED FOR PRODUCTION DEPLOYMENTS.\n")
validate_cors_origins(CORS_ALLOW_ORIGIN)
class BannerModel(BaseModel): class BannerModel(BaseModel):
id: str id: str

View File

@ -119,6 +119,7 @@ from config import (
WEBUI_SESSION_COOKIE_SECURE, WEBUI_SESSION_COOKIE_SECURE,
ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_CHAT_ACCESS,
AppConfig, AppConfig,
CORS_ALLOW_ORIGIN,
) )
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
@ -209,8 +210,6 @@ app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
app.state.MODELS = {} app.state.MODELS = {}
origins = ["*"]
################################## ##################################
# #
@ -833,7 +832,7 @@ app.add_middleware(PipelineMiddleware)
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=origins, allow_origins=CORS_ALLOW_ORIGIN,
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],