diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index 87ecc292b..31bfc0f5d 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -250,7 +250,7 @@ class GenerateImageForm(BaseModel): model: Optional[str] = None prompt: str n: int = 1 - size: str = "512x512" + size: Optional[str] = None negative_prompt: Optional[str] = None @@ -278,8 +278,7 @@ def generate_image( user=Depends(get_current_user), ): - print(form_data) - + r = None try: if app.state.ENGINE == "openai": @@ -291,10 +290,9 @@ def generate_image( "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2", "prompt": form_data.prompt, "n": form_data.n, - "size": form_data.size, + "size": form_data.size if form_data.size else app.state.IMAGE_SIZE, "response_format": "b64_json", } - r = requests.post( url=f"https://api.openai.com/v1/images/generations", json=data, @@ -359,4 +357,6 @@ def generate_image( except Exception as e: print(e) + if r: + print(r.json()) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index f8f166d01..5ecbaa297 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -15,7 +15,7 @@ import asyncio from apps.web.models.users import Users from constants import ERROR_MESSAGES from utils.utils import decode_token, get_current_user, get_admin_user -from config import OLLAMA_BASE_URLS +from config import OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST from typing import Optional, List, Union @@ -29,6 +29,10 @@ app.add_middleware( allow_headers=["*"], ) + +app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED +app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST + app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.MODELS = {} @@ -129,9 +133,19 @@ async def get_all_models(): async def get_ollama_tags( url_idx: Optional[int] = None, user=Depends(get_current_user) ): - if url_idx == None: - return await get_all_models() + models = await get_all_models() + + if app.state.MODEL_FILTER_ENABLED: + if user.role == "user": + models["models"] = list( + filter( + lambda model: model["name"] in app.state.MODEL_FILTER_LIST, + models["models"], + ) + ) + return models + return models else: url = app.state.OLLAMA_BASE_URLS[url_idx] try: diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 6b9c542ee..e902bea27 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -18,7 +18,13 @@ from utils.utils import ( get_verified_user, get_admin_user, ) -from config import OPENAI_API_BASE_URLS, OPENAI_API_KEYS, CACHE_DIR +from config import ( + OPENAI_API_BASE_URLS, + OPENAI_API_KEYS, + CACHE_DIR, + MODEL_FILTER_ENABLED, + MODEL_FILTER_LIST, +) from typing import List, Optional @@ -34,6 +40,9 @@ app.add_middleware( allow_headers=["*"], ) +app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED +app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST + app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS app.state.OPENAI_API_KEYS = OPENAI_API_KEYS @@ -186,12 +195,21 @@ async def get_all_models(): return models -# , user=Depends(get_current_user) @app.get("/models") @app.get("/models/{url_idx}") -async def get_models(url_idx: Optional[int] = None): +async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)): if url_idx == None: - return await get_all_models() + models = await get_all_models() + if app.state.MODEL_FILTER_ENABLED: + if user.role == "user": + models["data"] = list( + filter( + lambda model: model["id"] in app.state.MODEL_FILTER_LIST, + models["data"], + ) + ) + return models + return models else: url = app.state.OPENAI_API_BASE_URLS[url_idx] try: diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 45ad69707..6781a9a14 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -44,6 +44,8 @@ from apps.web.models.documents import ( DocumentResponse, ) +from apps.rag.utils import query_doc, query_collection + from utils.misc import ( calculate_sha256, calculate_sha256_string, @@ -248,21 +250,18 @@ class QueryDocForm(BaseModel): @app.post("/query/doc") -def query_doc( +def query_doc_handler( form_data: QueryDocForm, user=Depends(get_current_user), ): + try: - # if you use docker use the model from the environment variable - collection = CHROMA_CLIENT.get_collection( - name=form_data.collection_name, + return query_doc( + collection_name=form_data.collection_name, + query=form_data.query, + k=form_data.k if form_data.k else app.state.TOP_K, embedding_function=app.state.sentence_transformer_ef, ) - result = collection.query( - query_texts=[form_data.query], - n_results=form_data.k if form_data.k else app.state.TOP_K, - ) - return result except Exception as e: print(e) raise HTTPException( @@ -277,76 +276,16 @@ class QueryCollectionsForm(BaseModel): k: Optional[int] = None -def merge_and_sort_query_results(query_results, k): - # Initialize lists to store combined data - combined_ids = [] - combined_distances = [] - combined_metadatas = [] - combined_documents = [] - - # Combine data from each dictionary - for data in query_results: - combined_ids.extend(data["ids"][0]) - combined_distances.extend(data["distances"][0]) - combined_metadatas.extend(data["metadatas"][0]) - combined_documents.extend(data["documents"][0]) - - # Create a list of tuples (distance, id, metadata, document) - combined = list( - zip(combined_distances, combined_ids, combined_metadatas, combined_documents) - ) - - # Sort the list based on distances - combined.sort(key=lambda x: x[0]) - - # Unzip the sorted list - sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined) - - # Slicing the lists to include only k elements - sorted_distances = list(sorted_distances)[:k] - sorted_ids = list(sorted_ids)[:k] - sorted_metadatas = list(sorted_metadatas)[:k] - sorted_documents = list(sorted_documents)[:k] - - # Create the output dictionary - merged_query_results = { - "ids": [sorted_ids], - "distances": [sorted_distances], - "metadatas": [sorted_metadatas], - "documents": [sorted_documents], - "embeddings": None, - "uris": None, - "data": None, - } - - return merged_query_results - - @app.post("/query/collection") -def query_collection( +def query_collection_handler( form_data: QueryCollectionsForm, user=Depends(get_current_user), ): - results = [] - - for collection_name in form_data.collection_names: - try: - # if you use docker use the model from the environment variable - collection = CHROMA_CLIENT.get_collection( - name=collection_name, - embedding_function=app.state.sentence_transformer_ef, - ) - - result = collection.query( - query_texts=[form_data.query], - n_results=form_data.k if form_data.k else app.state.TOP_K, - ) - results.append(result) - except: - pass - - return merge_and_sort_query_results( - results, form_data.k if form_data.k else app.state.TOP_K + return query_collection( + collection_names=form_data.collection_names, + query=form_data.query, + k=form_data.k if form_data.k else app.state.TOP_K, + embedding_function=app.state.sentence_transformer_ef, ) diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py new file mode 100644 index 000000000..91b07e0aa --- /dev/null +++ b/backend/apps/rag/utils.py @@ -0,0 +1,97 @@ +import re +from typing import List + +from config import CHROMA_CLIENT + + +def query_doc(collection_name: str, query: str, k: int, embedding_function): + try: + # if you use docker use the model from the environment variable + collection = CHROMA_CLIENT.get_collection( + name=collection_name, + embedding_function=embedding_function, + ) + result = collection.query( + query_texts=[query], + n_results=k, + ) + return result + except Exception as e: + raise e + + +def merge_and_sort_query_results(query_results, k): + # Initialize lists to store combined data + combined_ids = [] + combined_distances = [] + combined_metadatas = [] + combined_documents = [] + + # Combine data from each dictionary + for data in query_results: + combined_ids.extend(data["ids"][0]) + combined_distances.extend(data["distances"][0]) + combined_metadatas.extend(data["metadatas"][0]) + combined_documents.extend(data["documents"][0]) + + # Create a list of tuples (distance, id, metadata, document) + combined = list( + zip(combined_distances, combined_ids, combined_metadatas, combined_documents) + ) + + # Sort the list based on distances + combined.sort(key=lambda x: x[0]) + + # Unzip the sorted list + sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined) + + # Slicing the lists to include only k elements + sorted_distances = list(sorted_distances)[:k] + sorted_ids = list(sorted_ids)[:k] + sorted_metadatas = list(sorted_metadatas)[:k] + sorted_documents = list(sorted_documents)[:k] + + # Create the output dictionary + merged_query_results = { + "ids": [sorted_ids], + "distances": [sorted_distances], + "metadatas": [sorted_metadatas], + "documents": [sorted_documents], + "embeddings": None, + "uris": None, + "data": None, + } + + return merged_query_results + + +def query_collection( + collection_names: List[str], query: str, k: int, embedding_function +): + + results = [] + + for collection_name in collection_names: + try: + # if you use docker use the model from the environment variable + collection = CHROMA_CLIENT.get_collection( + name=collection_name, + embedding_function=embedding_function, + ) + + result = collection.query( + query_texts=[query], + n_results=k, + ) + results.append(result) + except: + pass + + return merge_and_sort_query_results(results, k) + + +def rag_template(template: str, context: str, query: str): + template = re.sub(r"\[context\]", context, template) + template = re.sub(r"\[query\]", query, template) + + return template diff --git a/backend/config.py b/backend/config.py index 2cd016539..019e44e01 100644 --- a/backend/config.py +++ b/backend/config.py @@ -251,7 +251,7 @@ OPENAI_API_BASE_URLS = ( OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL ) -OPENAI_API_BASE_URLS = [url.strip() for url in OPENAI_API_BASE_URL.split(";")] +OPENAI_API_BASE_URLS = [url.strip() for url in OPENAI_API_BASE_URLS.split(";")] #################################### @@ -292,6 +292,11 @@ DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending") USER_PERMISSIONS = {"chat": {"deletion": True}} +MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", False) +MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "") +MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")] + + #################################### # WEBUI_VERSION #################################### diff --git a/backend/main.py b/backend/main.py index afa974ca6..c7523ec62 100644 --- a/backend/main.py +++ b/backend/main.py @@ -12,6 +12,7 @@ from fastapi import HTTPException from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.middleware.cors import CORSMiddleware from starlette.exceptions import HTTPException as StarletteHTTPException +from starlette.middleware.base import BaseHTTPMiddleware from apps.ollama.main import app as ollama_app @@ -22,8 +23,22 @@ from apps.images.main import app as images_app from apps.rag.main import app as rag_app from apps.web.main import app as webui_app +from pydantic import BaseModel +from typing import List -from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR + +from utils.utils import get_admin_user +from apps.rag.utils import query_doc, query_collection, rag_template + +from config import ( + WEBUI_NAME, + ENV, + VERSION, + CHANGELOG, + FRONTEND_BUILD_DIR, + MODEL_FILTER_ENABLED, + MODEL_FILTER_LIST, +) from constants import ERROR_MESSAGES @@ -40,6 +55,9 @@ class SPAStaticFiles(StaticFiles): app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None) +app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED +app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST + origins = ["*"] app.add_middleware( @@ -56,6 +74,126 @@ async def on_startup(): await litellm_app_startup() +class RAGMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + if request.method == "POST" and ( + "/api/chat" in request.url.path or "/chat/completions" in request.url.path + ): + print(request.url.path) + + # Read the original request body + body = await request.body() + # Decode body to string + body_str = body.decode("utf-8") + # Parse string to JSON + data = json.loads(body_str) if body_str else {} + + # Example: Add a new key-value pair or modify existing ones + # data["modified"] = True # Example modification + if "docs" in data: + docs = data["docs"] + print(docs) + + last_user_message_idx = None + for i in range(len(data["messages"]) - 1, -1, -1): + if data["messages"][i]["role"] == "user": + last_user_message_idx = i + break + + user_message = data["messages"][last_user_message_idx] + + if isinstance(user_message["content"], list): + # Handle list content input + content_type = "list" + query = "" + for content_item in user_message["content"]: + if content_item["type"] == "text": + query = content_item["text"] + break + elif isinstance(user_message["content"], str): + # Handle text content input + content_type = "text" + query = user_message["content"] + else: + # Fallback in case the input does not match expected types + content_type = None + query = "" + + relevant_contexts = [] + + for doc in docs: + context = None + + try: + if doc["type"] == "collection": + context = query_collection( + collection_names=doc["collection_names"], + query=query, + k=rag_app.state.TOP_K, + embedding_function=rag_app.state.sentence_transformer_ef, + ) + else: + context = query_doc( + collection_name=doc["collection_name"], + query=query, + k=rag_app.state.TOP_K, + embedding_function=rag_app.state.sentence_transformer_ef, + ) + except Exception as e: + print(e) + context = None + + relevant_contexts.append(context) + + context_string = "" + for context in relevant_contexts: + if context: + context_string += " ".join(context["documents"][0]) + "\n" + + ra_content = rag_template( + template=rag_app.state.RAG_TEMPLATE, + context=context_string, + query=query, + ) + + if content_type == "list": + new_content = [] + for content_item in user_message["content"]: + if content_item["type"] == "text": + # Update the text item's content with ra_content + new_content.append({"type": "text", "text": ra_content}) + else: + # Keep other types of content as they are + new_content.append(content_item) + new_user_message = {**user_message, "content": new_content} + else: + new_user_message = { + **user_message, + "content": ra_content, + } + + data["messages"][last_user_message_idx] = new_user_message + del data["docs"] + + print(data["messages"]) + + modified_body_bytes = json.dumps(data).encode("utf-8") + + # Create a new request with the modified body + scope = request.scope + scope["body"] = modified_body_bytes + request = Request(scope, receive=lambda: self._receive(modified_body_bytes)) + + response = await call_next(request) + return response + + async def _receive(self, body: bytes): + return {"type": "http.request", "body": body, "more_body": False} + + +app.add_middleware(RAGMiddleware) + + @app.middleware("http") async def check_url(request: Request, call_next): start_time = int(time.time()) @@ -90,6 +228,39 @@ async def get_app_config(): } +@app.get("/api/config/model/filter") +async def get_model_filter_config(user=Depends(get_admin_user)): + return { + "enabled": app.state.MODEL_FILTER_ENABLED, + "models": app.state.MODEL_FILTER_LIST, + } + + +class ModelFilterConfigForm(BaseModel): + enabled: bool + models: List[str] + + +@app.post("/api/config/model/filter") +async def get_model_filter_config( + form_data: ModelFilterConfigForm, user=Depends(get_admin_user) +): + + app.state.MODEL_FILTER_ENABLED = form_data.enabled + app.state.MODEL_FILTER_LIST = form_data.models + + ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED + ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST + + openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED + openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST + + return { + "enabled": app.state.MODEL_FILTER_ENABLED, + "models": app.state.MODEL_FILTER_LIST, + } + + @app.get("/api/version") async def get_app_config(): diff --git a/backend/requirements.txt b/backend/requirements.txt index 41527a78c..29fb34925 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -16,7 +16,8 @@ aiohttp peewee bcrypt -litellm +litellm==1.30.7 +argon2-cffi apscheduler google-generativeai diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index b7b346c0d..b33fb571b 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -77,3 +77,65 @@ export const getVersionUpdates = async () => { return res; }; + +export const getModelFilterConfig = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/config/model/filter`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateModelFilterConfig = async ( + token: string, + enabled: boolean, + models: string[] +) => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/config/model/filter`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + enabled: enabled, + models: models + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index 4e8e9b14c..6dcfbbe7d 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -252,7 +252,7 @@ export const queryCollection = async ( token: string, collection_names: string, query: string, - k: number + k: number | null = null ) => { let error = null; diff --git a/src/lib/components/admin/Settings/Users.svelte b/src/lib/components/admin/Settings/Users.svelte index 620555d7b..f7b79edce 100644 --- a/src/lib/components/admin/Settings/Users.svelte +++ b/src/lib/components/admin/Settings/Users.svelte @@ -1,12 +1,17 @@ @@ -23,6 +35,8 @@ on:submit|preventDefault={async () => { // console.log('submit'); await updateUserPermissions(localStorage.token, permissions); + + await updateModelFilterConfig(localStorage.token, whitelistEnabled, whitelistModels); saveHandler(); }} > @@ -71,6 +85,106 @@ + +
+ +
+
+
+
+
Manage Models
+
+
+ +
+
+
+
Model Whitelisting
+ + +
+
+ + {#if whitelistEnabled} +
+
+ {#each whitelistModels as modelId, modelIdx} +
+
+ +
+ + {#if modelIdx === 0} + + {:else} + + {/if} +
+ {/each} +
+ +
+
+ {whitelistModels.length} Model(s) Whitelisted +
+
+
+ {/if} +
+
+
diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index a32eeb2ad..f07f7f98d 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -364,12 +364,12 @@ {#if dragged}
-
+
diff --git a/src/lib/components/chat/Settings/Account.svelte b/src/lib/components/chat/Settings/Account.svelte index f3dd0efa5..ff564e58d 100644 --- a/src/lib/components/chat/Settings/Account.svelte +++ b/src/lib/components/chat/Settings/Account.svelte @@ -111,7 +111,9 @@
-
+ +
{$i18n.t('Title Generation Prompt')}