From eb0e683b4762434efe189dc9aad9cef1f3f36cd3 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 10 Sep 2024 01:34:27 +0100 Subject: [PATCH 1/2] refac --- backend/open_webui/config.py | 73 +++++++++++++++++++----------------- 1 file changed, 38 insertions(+), 35 deletions(-) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 5ccb40d47..8070d3cab 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -923,24 +923,9 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig( #################################### -# RAG document content extraction +# Vector Database #################################### -CONTENT_EXTRACTION_ENGINE = PersistentConfig( - "CONTENT_EXTRACTION_ENGINE", - "rag.CONTENT_EXTRACTION_ENGINE", - os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(), -) - -TIKA_SERVER_URL = PersistentConfig( - "TIKA_SERVER_URL", - "rag.tika_server_url", - os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment -) - -#################################### -# RAG -#################################### CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT) @@ -958,6 +943,43 @@ else: CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true" # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2) + +if CHROMA_HTTP_HOST != "": + CHROMA_CLIENT = chromadb.HttpClient( + host=CHROMA_HTTP_HOST, + port=CHROMA_HTTP_PORT, + headers=CHROMA_HTTP_HEADERS, + ssl=CHROMA_HTTP_SSL, + tenant=CHROMA_TENANT, + database=CHROMA_DATABASE, + settings=Settings(allow_reset=True, anonymized_telemetry=False), + ) +else: + CHROMA_CLIENT = chromadb.PersistentClient( + path=CHROMA_DATA_PATH, + settings=Settings(allow_reset=True, anonymized_telemetry=False), + tenant=CHROMA_TENANT, + database=CHROMA_DATABASE, + ) + + +#################################### +# RAG +#################################### + +# RAG Content Extraction +CONTENT_EXTRACTION_ENGINE = PersistentConfig( + "CONTENT_EXTRACTION_ENGINE", + "rag.CONTENT_EXTRACTION_ENGINE", + os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(), +) + +TIKA_SERVER_URL = PersistentConfig( + "TIKA_SERVER_URL", + "rag.tika_server_url", + os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment +) + RAG_TOP_K = PersistentConfig( "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "5")) ) @@ -1049,25 +1071,6 @@ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = ( ) -if CHROMA_HTTP_HOST != "": - CHROMA_CLIENT = chromadb.HttpClient( - host=CHROMA_HTTP_HOST, - port=CHROMA_HTTP_PORT, - headers=CHROMA_HTTP_HEADERS, - ssl=CHROMA_HTTP_SSL, - tenant=CHROMA_TENANT, - database=CHROMA_DATABASE, - settings=Settings(allow_reset=True, anonymized_telemetry=False), - ) -else: - CHROMA_CLIENT = chromadb.PersistentClient( - path=CHROMA_DATA_PATH, - settings=Settings(allow_reset=True, anonymized_telemetry=False), - tenant=CHROMA_TENANT, - database=CHROMA_DATABASE, - ) - - # device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false") From 28087ccf406c2ccf47491f9329834e07874f0874 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 10 Sep 2024 01:37:36 +0100 Subject: [PATCH 2/2] refac --- backend/open_webui/apps/audio/main.py | 4 +-- backend/open_webui/apps/rag/main.py | 3 +- backend/open_webui/config.py | 43 --------------------------- backend/open_webui/env.py | 9 ++++++ 4 files changed, 12 insertions(+), 47 deletions(-) diff --git a/backend/open_webui/apps/audio/main.py b/backend/open_webui/apps/audio/main.py index 1fc44b28f..4734b0d95 100644 --- a/backend/open_webui/apps/audio/main.py +++ b/backend/open_webui/apps/audio/main.py @@ -21,14 +21,14 @@ from open_webui.config import ( AUDIO_TTS_VOICE, CACHE_DIR, CORS_ALLOW_ORIGIN, - DEVICE_TYPE, WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE, WHISPER_MODEL_DIR, AppConfig, ) + from open_webui.constants import ERROR_MESSAGES -from open_webui.env import SRC_LOG_LEVELS +from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse diff --git a/backend/open_webui/apps/rag/main.py b/backend/open_webui/apps/rag/main.py index 6c064fe81..5ca42a1ac 100644 --- a/backend/open_webui/apps/rag/main.py +++ b/backend/open_webui/apps/rag/main.py @@ -41,7 +41,6 @@ from open_webui.config import ( CHUNK_SIZE, CONTENT_EXTRACTION_ENGINE, CORS_ALLOW_ORIGIN, - DEVICE_TYPE, DOCS_DIR, ENABLE_RAG_HYBRID_SEARCH, ENABLE_RAG_LOCAL_WEB_FETCH, @@ -84,7 +83,7 @@ from open_webui.config import ( AppConfig, ) from open_webui.constants import ERROR_MESSAGES -from open_webui.env import SRC_LOG_LEVELS +from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status from fastapi.middleware.cors import CORSMiddleware from langchain.text_splitter import RecursiveCharacterTextSplitter diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 8070d3cab..4047d8aa2 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -540,40 +540,6 @@ Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True) FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions") Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True) - -#################################### -# LITELLM_CONFIG -#################################### - - -def create_config_file(file_path): - directory = os.path.dirname(file_path) - - # Check if directory exists, if not, create it - if not os.path.exists(directory): - os.makedirs(directory) - - # Data to write into the YAML file - config_data = { - "general_settings": {}, - "litellm_settings": {}, - "model_list": [], - "router_settings": {}, - } - - # Write data to YAML file - with open(file_path, "w") as file: - yaml.dump(config_data, file) - - -LITELLM_CONFIG_PATH = f"{DATA_DIR}/litellm/config.yaml" - -# if not os.path.exists(LITELLM_CONFIG_PATH): -# log.info("Config file doesn't exist. Creating...") -# create_config_file(LITELLM_CONFIG_PATH) -# log.info("Config file created successfully.") - - #################################### # OLLAMA_BASE_URL #################################### @@ -1070,15 +1036,6 @@ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = ( os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" ) - -# device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance -USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false") - -if USE_CUDA.lower() == "true": - DEVICE_TYPE = "cuda" -else: - DEVICE_TYPE = "cpu" - CHUNK_SIZE = PersistentConfig( "CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1500")) ) diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 8683bb370..d99a80df4 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -32,6 +32,15 @@ except ImportError: print("dotenv not installed, skipping...") +# device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance +USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false") + +if USE_CUDA.lower() == "true": + DEVICE_TYPE = "cuda" +else: + DEVICE_TYPE = "cpu" + + #################################### # LOGGING ####################################