diff --git a/CHANGELOG.md b/CHANGELOG.md
index dad583399..505ded309 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,19 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.1.121] - 2024-04-24
+
+### Fixed
+
+- **🔧 Translation Issues**: Addressed various translation discrepancies.
+- **🔒 LiteLLM Security Fix**: Updated LiteLLM version to resolve a security vulnerability.
+- **🖥️ HTML Tag Display**: Rectified the issue where the '< br >' tag wasn't displaying correctly.
+- **🔗 WebSocket Connection**: Resolved the failure of WebSocket connection under HTTPS security for ComfyUI server.
+- **📜 FileReader Optimization**: Implemented FileReader initialization per image in multi-file drag & drop to ensure reusability.
+- **🏷️ Tag Display**: Corrected tag display inconsistencies.
+- **📦 Archived Chat Styling**: Fixed styling issues in archived chat.
+- **🔖 Safari Copy Button Bug**: Addressed the bug where the copy button failed to copy links in Safari.
+
 ## [0.1.120] - 2024-04-20
 
 ### Added
diff --git a/Dockerfile b/Dockerfile
index f19952909..a8f664ada 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -8,8 +8,8 @@ ARG USE_CUDA_VER=cu121
 # any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
 # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
 # for better performance and multilanguage support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
-# IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
-ARG USE_EMBEDDING_MODEL=all-MiniLM-L6-v2
+# IMPORTANT: If you change the default model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
+ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2 ######## WebUI frontend ######## FROM --platform=$BUILDPLATFORM node:21-alpine3.19 as build @@ -98,13 +98,13 @@ RUN pip3 install uv && \ # If you use CUDA the whisper and embedding model will be downloaded on first use pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \ uv pip install --system -r requirements.txt --no-cache-dir && \ - python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \ - python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device='cpu')"; \ + python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \ + python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ else \ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \ uv pip install --system -r requirements.txt --no-cache-dir && \ - python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \ - python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device='cpu')"; \ + python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \ + python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ fi diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index a3939d206..2059ac3c0 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -35,8 +35,8 @@ from config import ( ENABLE_IMAGE_GENERATION, AUTOMATIC1111_BASE_URL, COMFYUI_BASE_URL, - OPENAI_API_BASE_URL, - OPENAI_API_KEY, + IMAGES_OPENAI_API_BASE_URL, + IMAGES_OPENAI_API_KEY, ) @@ -58,8 +58,8 @@ app.add_middleware( app.state.ENGINE = "" app.state.ENABLED = ENABLE_IMAGE_GENERATION -app.state.OPENAI_API_BASE_URL = OPENAI_API_BASE_URL -app.state.OPENAI_API_KEY = OPENAI_API_KEY +app.state.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL +app.state.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY app.state.MODEL = "" @@ -135,27 +135,33 @@ async def update_engine_url( } -class OpenAIKeyUpdateForm(BaseModel): +class OpenAIConfigUpdateForm(BaseModel): + url: str key: str -@app.get("/key") -async def get_openai_key(user=Depends(get_admin_user)): - return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} +@app.get("/openai/config") +async def get_openai_config(user=Depends(get_admin_user)): + return { + "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.OPENAI_API_KEY, + } -@app.post("/key/update") -async def update_openai_key( - form_data: OpenAIKeyUpdateForm, user=Depends(get_admin_user) +@app.post("/openai/config/update") +async def 
update_openai_config( + form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user) ): - if form_data.key == "": raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) + app.state.OPENAI_API_BASE_URL = form_data.url app.state.OPENAI_API_KEY = form_data.key + return { - "OPENAI_API_KEY": app.state.OPENAI_API_KEY, "status": True, + "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.OPENAI_API_KEY, } diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index a9922aad7..119e9107e 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -1,17 +1,25 @@ +import sys + +from fastapi import FastAPI, Depends, HTTPException +from fastapi.routing import APIRoute +from fastapi.middleware.cors import CORSMiddleware + import logging - -from litellm.proxy.proxy_server import ProxyConfig, initialize -from litellm.proxy.proxy_server import app - from fastapi import FastAPI, Request, Depends, status, Response from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.responses import StreamingResponse import json +import time +import requests -from utils.utils import get_http_authorization_cred, get_current_user +from pydantic import BaseModel, ConfigDict +from typing import Optional, List + +from utils.utils import get_verified_user, get_current_user, get_admin_user from config import SRC_LOG_LEVELS, ENV +from constants import MESSAGES log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["LITELLM"]) @@ -20,81 +28,324 @@ log.setLevel(SRC_LOG_LEVELS["LITELLM"]) from config import ( MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, + DATA_DIR, + LITELLM_PROXY_PORT, + LITELLM_PROXY_HOST, +) + +from litellm.utils import get_llm_provider + +import asyncio +import subprocess +import yaml + +app = FastAPI() + +origins = ["*"] + +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], ) -proxy_config = ProxyConfig() +LITELLM_CONFIG_DIR = f"{DATA_DIR}/litellm/config.yaml" + +with open(LITELLM_CONFIG_DIR, "r") as file: + litellm_config = yaml.safe_load(file) + +app.state.CONFIG = litellm_config + +# Global variable to store the subprocess reference +background_process = None -async def config(): - router, model_list, general_settings = await proxy_config.load_config( - router=None, config_file_path="./data/litellm/config.yaml" - ) +async def run_background_process(command): + global background_process + log.info("run_background_process") - await initialize(config="./data/litellm/config.yaml", telemetry=False) + try: + # Log the command to be executed + log.info(f"Executing command: {command}") + # Execute the command and create a subprocess + process = await asyncio.create_subprocess_exec( + *command, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + background_process = process + log.info("Subprocess started successfully.") + + # Capture STDERR for debugging purposes + stderr_output = await process.stderr.read() + stderr_text = stderr_output.decode().strip() + if stderr_text: + log.info(f"Subprocess STDERR: {stderr_text}") + + # log.info output line by line + async for line in process.stdout: + log.info(line.decode().strip()) + + # Wait for the process to finish + returncode = await process.wait() + log.info(f"Subprocess exited with return code {returncode}") + except Exception as e: + log.error(f"Failed to start subprocess: {e}") + raise # Optionally 
re-raise the exception if you want it to propagate -async def startup(): - await config() +async def start_litellm_background(): + log.info("start_litellm_background") + # Command to run in the background + command = [ + "litellm", + "--port", + str(LITELLM_PROXY_PORT), + "--host", + LITELLM_PROXY_HOST, + "--telemetry", + "False", + "--config", + LITELLM_CONFIG_DIR, + ] + + await run_background_process(command) + + +async def shutdown_litellm_background(): + log.info("shutdown_litellm_background") + global background_process + if background_process: + background_process.terminate() + await background_process.wait() # Ensure the process has terminated + log.info("Subprocess terminated") + background_process = None @app.on_event("startup") -async def on_startup(): - await startup() +async def startup_event(): + log.info("startup_event") + # TODO: Check config.yaml file and create one + asyncio.create_task(start_litellm_background()) app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST -@app.middleware("http") -async def auth_middleware(request: Request, call_next): - auth_header = request.headers.get("Authorization", "") - request.state.user = None +@app.get("/") +async def get_status(): + return {"status": True} + + +async def restart_litellm(): + """ + Endpoint to restart the litellm background service. + """ + log.info("Requested restart of litellm service.") + try: + # Shut down the existing process if it is running + await shutdown_litellm_background() + log.info("litellm service shutdown complete.") + + # Restart the background service + + asyncio.create_task(start_litellm_background()) + log.info("litellm service restart complete.") + + return { + "status": "success", + "message": "litellm service restarted successfully.", + } + except Exception as e: + log.info(f"Error restarting litellm service: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) + ) + + +@app.get("/restart") +async def restart_litellm_handler(user=Depends(get_admin_user)): + return await restart_litellm() + + +@app.get("/config") +async def get_config(user=Depends(get_admin_user)): + return app.state.CONFIG + + +class LiteLLMConfigForm(BaseModel): + general_settings: Optional[dict] = None + litellm_settings: Optional[dict] = None + model_list: Optional[List[dict]] = None + router_settings: Optional[dict] = None + + model_config = ConfigDict(protected_namespaces=()) + + +@app.post("/config/update") +async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)): + app.state.CONFIG = form_data.model_dump(exclude_none=True) + + with open(LITELLM_CONFIG_DIR, "w") as file: + yaml.dump(app.state.CONFIG, file) + + await restart_litellm() + return app.state.CONFIG + + +@app.get("/models") +@app.get("/v1/models") +async def get_models(user=Depends(get_current_user)): + while not background_process: + await asyncio.sleep(0.1) + + url = f"http://localhost:{LITELLM_PROXY_PORT}/v1" + r = None + try: + r = requests.request(method="GET", url=f"{url}/models") + r.raise_for_status() + + data = r.json() + + if app.state.MODEL_FILTER_ENABLED: + if user and user.role == "user": + data["data"] = list( + filter( + lambda model: model["id"] in app.state.MODEL_FILTER_LIST, + data["data"], + ) + ) + + return data + except Exception as e: + + log.exception(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"External: {res['error']}" + 
except: + error_detail = f"External: {e}" + + return { + "data": [ + { + "id": model["model_name"], + "object": "model", + "created": int(time.time()), + "owned_by": "openai", + } + for model in app.state.CONFIG["model_list"] + ], + "object": "list", + } + + +@app.get("/model/info") +async def get_model_list(user=Depends(get_admin_user)): + return {"data": app.state.CONFIG["model_list"]} + + +class AddLiteLLMModelForm(BaseModel): + model_name: str + litellm_params: dict + + model_config = ConfigDict(protected_namespaces=()) + + +@app.post("/model/new") +async def add_model_to_config( + form_data: AddLiteLLMModelForm, user=Depends(get_admin_user) +): + try: + get_llm_provider(model=form_data.model_name) + app.state.CONFIG["model_list"].append(form_data.model_dump()) + + with open(LITELLM_CONFIG_DIR, "w") as file: + yaml.dump(app.state.CONFIG, file) + + await restart_litellm() + + return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)} + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) + ) + + +class DeleteLiteLLMModelForm(BaseModel): + id: str + + +@app.post("/model/delete") +async def delete_model_from_config( + form_data: DeleteLiteLLMModelForm, user=Depends(get_admin_user) +): + app.state.CONFIG["model_list"] = [ + model + for model in app.state.CONFIG["model_list"] + if model["model_name"] != form_data.id + ] + + with open(LITELLM_CONFIG_DIR, "w") as file: + yaml.dump(app.state.CONFIG, file) + + await restart_litellm() + + return {"message": MESSAGES.MODEL_DELETED(form_data.id)} + + +@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) +async def proxy(path: str, request: Request, user=Depends(get_verified_user)): + body = await request.body() + + url = f"http://localhost:{LITELLM_PROXY_PORT}" + + target_url = f"{url}/{path}" + + headers = {} + # headers["Authorization"] = f"Bearer {key}" + headers["Content-Type"] = "application/json" + + r = None try: - user = get_current_user(get_http_authorization_cred(auth_header)) - log.debug(f"user: {user}") - request.state.user = user + r = requests.request( + method=request.method, + url=target_url, + data=body, + headers=headers, + stream=True, + ) + + r.raise_for_status() + + # Check if response is SSE + if "text/event-stream" in r.headers.get("Content-Type", ""): + return StreamingResponse( + r.iter_content(chunk_size=8192), + status_code=r.status_code, + headers=dict(r.headers), + ) + else: + response_data = r.json() + return response_data except Exception as e: - return JSONResponse(status_code=400, content={"detail": str(e)}) + log.exception(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" + except: + error_detail = f"External: {e}" - response = await call_next(request) - return response - - -class ModifyModelsResponseMiddleware(BaseHTTPMiddleware): - async def dispatch( - self, request: Request, call_next: RequestResponseEndpoint - ) -> Response: - - response = await call_next(request) - user = request.state.user - - if "/models" in request.url.path: - if isinstance(response, StreamingResponse): - # Read the content of the streaming response - body = b"" - async for chunk in response.body_iterator: - body += chunk - - data = json.loads(body.decode("utf-8")) - - if app.state.MODEL_FILTER_ENABLED: - if user and user.role == "user": - data["data"] = list( - filter( - 
lambda model: model["id"] - in app.state.MODEL_FILTER_LIST, - data["data"], - ) - ) - - # Modified Flag - data["modified"] = True - return JSONResponse(content=data) - - return response - - -app.add_middleware(ModifyModelsResponseMiddleware) + raise HTTPException( + status_code=r.status_code if r else 500, detail=error_detail + ) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index ac8410dbe..5da7489f1 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -13,7 +13,6 @@ import os, shutil, logging, re from pathlib import Path from typing import List -from chromadb.utils import embedding_functions from chromadb.utils.batch_utils import create_batches from langchain_community.document_loaders import ( @@ -38,6 +37,7 @@ import mimetypes import uuid import json +import sentence_transformers from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm @@ -48,11 +48,8 @@ from apps.web.models.documents import ( ) from apps.rag.utils import ( - query_doc, query_embeddings_doc, - query_collection, query_embeddings_collection, - get_embedding_model_path, generate_openai_embeddings, ) @@ -69,7 +66,7 @@ from config import ( DOCS_DIR, RAG_EMBEDDING_ENGINE, RAG_EMBEDDING_MODEL, - RAG_EMBEDDING_MODEL_AUTO_UPDATE, + RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, RAG_OPENAI_API_BASE_URL, RAG_OPENAI_API_KEY, DEVICE_TYPE, @@ -101,15 +98,12 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY app.state.PDF_EXTRACT_IMAGES = False - -app.state.sentence_transformer_ef = ( - embedding_functions.SentenceTransformerEmbeddingFunction( - model_name=get_embedding_model_path( - app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE - ), +if app.state.RAG_EMBEDDING_ENGINE == "": + app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer( + app.state.RAG_EMBEDDING_MODEL, device=DEVICE_TYPE, + trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, ) -) origins = ["*"] @@ -185,13 +179,10 @@ async def update_embedding_config( app.state.OPENAI_API_BASE_URL = form_data.openai_config.url app.state.OPENAI_API_KEY = form_data.openai_config.key else: - sentence_transformer_ef = ( - embedding_functions.SentenceTransformerEmbeddingFunction( - model_name=get_embedding_model_path( - form_data.embedding_model, True - ), - device=DEVICE_TYPE, - ) + sentence_transformer_ef = sentence_transformers.SentenceTransformer( + app.state.RAG_EMBEDDING_MODEL, + device=DEVICE_TYPE, + trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, ) app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model app.state.sentence_transformer_ef = sentence_transformer_ef @@ -294,38 +285,34 @@ def query_doc_handler( form_data: QueryDocForm, user=Depends(get_current_user), ): - try: if app.state.RAG_EMBEDDING_ENGINE == "": - return query_doc( - collection_name=form_data.collection_name, - query=form_data.query, - k=form_data.k if form_data.k else app.state.TOP_K, - embedding_function=app.state.sentence_transformer_ef, + query_embeddings = app.state.sentence_transformer_ef.encode( + form_data.query + ).tolist() + elif app.state.RAG_EMBEDDING_ENGINE == "ollama": + query_embeddings = generate_ollama_embeddings( + GenerateEmbeddingsForm( + **{ + "model": app.state.RAG_EMBEDDING_MODEL, + "prompt": form_data.query, + } + ) + ) + elif app.state.RAG_EMBEDDING_ENGINE == "openai": + query_embeddings = generate_openai_embeddings( + model=app.state.RAG_EMBEDDING_MODEL, + text=form_data.query, + key=app.state.OPENAI_API_KEY, + url=app.state.OPENAI_API_BASE_URL, ) - else: - if 
app.state.RAG_EMBEDDING_ENGINE == "ollama": - query_embeddings = generate_ollama_embeddings( - GenerateEmbeddingsForm( - **{ - "model": app.state.RAG_EMBEDDING_MODEL, - "prompt": form_data.query, - } - ) - ) - elif app.state.RAG_EMBEDDING_ENGINE == "openai": - query_embeddings = generate_openai_embeddings( - model=app.state.RAG_EMBEDDING_MODEL, - text=form_data.query, - key=app.state.OPENAI_API_KEY, - url=app.state.OPENAI_API_BASE_URL, - ) - return query_embeddings_doc( - collection_name=form_data.collection_name, - query_embeddings=query_embeddings, - k=form_data.k if form_data.k else app.state.TOP_K, - ) + return query_embeddings_doc( + collection_name=form_data.collection_name, + query=form_data.query, + query_embeddings=query_embeddings, + k=form_data.k if form_data.k else app.state.TOP_K, + ) except Exception as e: log.exception(e) @@ -348,36 +335,31 @@ def query_collection_handler( ): try: if app.state.RAG_EMBEDDING_ENGINE == "": - return query_collection( - collection_names=form_data.collection_names, - query=form_data.query, - k=form_data.k if form_data.k else app.state.TOP_K, - embedding_function=app.state.sentence_transformer_ef, - ) - else: - - if app.state.RAG_EMBEDDING_ENGINE == "ollama": - query_embeddings = generate_ollama_embeddings( - GenerateEmbeddingsForm( - **{ - "model": app.state.RAG_EMBEDDING_MODEL, - "prompt": form_data.query, - } - ) + query_embeddings = app.state.sentence_transformer_ef.encode( + form_data.query + ).tolist() + elif app.state.RAG_EMBEDDING_ENGINE == "ollama": + query_embeddings = generate_ollama_embeddings( + GenerateEmbeddingsForm( + **{ + "model": app.state.RAG_EMBEDDING_MODEL, + "prompt": form_data.query, + } ) - elif app.state.RAG_EMBEDDING_ENGINE == "openai": - query_embeddings = generate_openai_embeddings( - model=app.state.RAG_EMBEDDING_MODEL, - text=form_data.query, - key=app.state.OPENAI_API_KEY, - url=app.state.OPENAI_API_BASE_URL, - ) - - return query_embeddings_collection( - collection_names=form_data.collection_names, - query_embeddings=query_embeddings, - k=form_data.k if form_data.k else app.state.TOP_K, ) + elif app.state.RAG_EMBEDDING_ENGINE == "openai": + query_embeddings = generate_openai_embeddings( + model=app.state.RAG_EMBEDDING_MODEL, + text=form_data.query, + key=app.state.OPENAI_API_KEY, + url=app.state.OPENAI_API_BASE_URL, + ) + + return query_embeddings_collection( + collection_names=form_data.collection_names, + query_embeddings=query_embeddings, + k=form_data.k if form_data.k else app.state.TOP_K, + ) except Exception as e: log.exception(e) @@ -445,6 +427,8 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b log.info(f"store_docs_in_vector_db {docs} {collection_name}") texts = [doc.page_content for doc in docs] + texts = list(map(lambda x: x.replace("\n", " "), texts)) + metadatas = [doc.metadata for doc in docs] try: @@ -454,52 +438,38 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b log.info(f"deleting existing collection {collection_name}") CHROMA_CLIENT.delete_collection(name=collection_name) + collection = CHROMA_CLIENT.create_collection(name=collection_name) + if app.state.RAG_EMBEDDING_ENGINE == "": - - collection = CHROMA_CLIENT.create_collection( - name=collection_name, - embedding_function=app.state.sentence_transformer_ef, - ) - - for batch in create_batches( - api=CHROMA_CLIENT, - ids=[str(uuid.uuid1()) for _ in texts], - metadatas=metadatas, - documents=texts, - ): - collection.add(*batch) - - else: - collection = 
CHROMA_CLIENT.create_collection(name=collection_name) - - if app.state.RAG_EMBEDDING_ENGINE == "ollama": - embeddings = [ - generate_ollama_embeddings( - GenerateEmbeddingsForm( - **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text} - ) + embeddings = app.state.sentence_transformer_ef.encode(texts).tolist() + elif app.state.RAG_EMBEDDING_ENGINE == "ollama": + embeddings = [ + generate_ollama_embeddings( + GenerateEmbeddingsForm( + **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text} ) - for text in texts - ] - elif app.state.RAG_EMBEDDING_ENGINE == "openai": - embeddings = [ - generate_openai_embeddings( - model=app.state.RAG_EMBEDDING_MODEL, - text=text, - key=app.state.OPENAI_API_KEY, - url=app.state.OPENAI_API_BASE_URL, - ) - for text in texts - ] + ) + for text in texts + ] + elif app.state.RAG_EMBEDDING_ENGINE == "openai": + embeddings = [ + generate_openai_embeddings( + model=app.state.RAG_EMBEDDING_MODEL, + text=text, + key=app.state.OPENAI_API_KEY, + url=app.state.OPENAI_API_BASE_URL, + ) + for text in texts + ] - for batch in create_batches( - api=CHROMA_CLIENT, - ids=[str(uuid.uuid1()) for _ in texts], - metadatas=metadatas, - embeddings=embeddings, - documents=texts, - ): - collection.add(*batch) + for batch in create_batches( + api=CHROMA_CLIENT, + ids=[str(uuid.uuid1()) for _ in texts], + metadatas=metadatas, + embeddings=embeddings, + documents=texts, + ): + collection.add(*batch) return True except Exception as e: diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index f4d1246c7..0ce299279 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -1,13 +1,12 @@ -import os -import re import logging -from typing import List import requests +from typing import List -from huggingface_hub import snapshot_download -from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm - +from apps.ollama.main import ( + generate_ollama_embeddings, + GenerateEmbeddingsForm, +) from config import SRC_LOG_LEVELS, CHROMA_CLIENT @@ -16,29 +15,12 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) -def query_doc(collection_name: str, query: str, k: int, embedding_function): - try: - # if you use docker use the model from the environment variable - collection = CHROMA_CLIENT.get_collection( - name=collection_name, - embedding_function=embedding_function, - ) - result = collection.query( - query_texts=[query], - n_results=k, - ) - return result - except Exception as e: - raise e - - -def query_embeddings_doc(collection_name: str, query_embeddings, k: int): +def query_embeddings_doc(collection_name: str, query: str, query_embeddings, k: int): try: # if you use docker use the model from the environment variable log.info(f"query_embeddings_doc {query_embeddings}") - collection = CHROMA_CLIENT.get_collection( - name=collection_name, - ) + collection = CHROMA_CLIENT.get_collection(name=collection_name) + result = collection.query( query_embeddings=[query_embeddings], n_results=k, @@ -95,43 +77,20 @@ def merge_and_sort_query_results(query_results, k): return merged_query_results -def query_collection( - collection_names: List[str], query: str, k: int, embedding_function +def query_embeddings_collection( + collection_names: List[str], query: str, query_embeddings, k: int ): - results = [] - - for collection_name in collection_names: - try: - # if you use docker use the model from the environment variable - collection = CHROMA_CLIENT.get_collection( - name=collection_name, - embedding_function=embedding_function, - ) - - 
result = collection.query( - query_texts=[query], - n_results=k, - ) - results.append(result) - except: - pass - - return merge_and_sort_query_results(results, k) - - -def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int): - results = [] log.info(f"query_embeddings_collection {query_embeddings}") for collection_name in collection_names: try: - collection = CHROMA_CLIENT.get_collection(name=collection_name) - - result = collection.query( - query_embeddings=[query_embeddings], - n_results=k, + result = query_embeddings_doc( + collection_name=collection_name, + query=query, + query_embeddings=query_embeddings, + k=k, ) results.append(result) except: @@ -197,51 +156,38 @@ def rag_messages( context = doc["content"] else: if embedding_engine == "": - if doc["type"] == "collection": - context = query_collection( - collection_names=doc["collection_names"], - query=query, - k=k, - embedding_function=embedding_function, - ) - else: - context = query_doc( - collection_name=doc["collection_name"], - query=query, - k=k, - embedding_function=embedding_function, + query_embeddings = embedding_function.encode(query).tolist() + elif embedding_engine == "ollama": + query_embeddings = generate_ollama_embeddings( + GenerateEmbeddingsForm( + **{ + "model": embedding_model, + "prompt": query, + } ) + ) + elif embedding_engine == "openai": + query_embeddings = generate_openai_embeddings( + model=embedding_model, + text=query, + key=openai_key, + url=openai_url, + ) + if doc["type"] == "collection": + context = query_embeddings_collection( + collection_names=doc["collection_names"], + query=query, + query_embeddings=query_embeddings, + k=k, + ) else: - if embedding_engine == "ollama": - query_embeddings = generate_ollama_embeddings( - GenerateEmbeddingsForm( - **{ - "model": embedding_model, - "prompt": query, - } - ) - ) - elif embedding_engine == "openai": - query_embeddings = generate_openai_embeddings( - model=embedding_model, - text=query, - key=openai_key, - url=openai_url, - ) - - if doc["type"] == "collection": - context = query_embeddings_collection( - collection_names=doc["collection_names"], - query_embeddings=query_embeddings, - k=k, - ) - else: - context = query_embeddings_doc( - collection_name=doc["collection_name"], - query_embeddings=query_embeddings, - k=k, - ) + context = query_embeddings_doc( + collection_name=doc["collection_name"], + query=query, + query_embeddings=query_embeddings, + k=k, + ) except Exception as e: log.exception(e) @@ -283,46 +229,6 @@ def rag_messages( return messages -def get_embedding_model_path( - embedding_model: str, update_embedding_model: bool = False -): - # Construct huggingface_hub kwargs with local_files_only to return the snapshot path - cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME") - - local_files_only = not update_embedding_model - - snapshot_kwargs = { - "cache_dir": cache_dir, - "local_files_only": local_files_only, - } - - log.debug(f"embedding_model: {embedding_model}") - log.debug(f"snapshot_kwargs: {snapshot_kwargs}") - - # Inspiration from upstream sentence_transformers - if ( - os.path.exists(embedding_model) - or ("\\" in embedding_model or embedding_model.count("/") > 1) - and local_files_only - ): - # If fully qualified path exists, return input, else set repo_id - return embedding_model - elif "/" not in embedding_model: - # Set valid repo_id for model short-name - embedding_model = "sentence-transformers" + "/" + embedding_model - - snapshot_kwargs["repo_id"] = embedding_model - - # Attempt to query the 
huggingface_hub library to determine the local path and/or to update - try: - embedding_model_repo_path = snapshot_download(**snapshot_kwargs) - log.debug(f"embedding_model_repo_path: {embedding_model_repo_path}") - return embedding_model_repo_path - except Exception as e: - log.exception(f"Cannot determine embedding model snapshot path: {e}") - return embedding_model - - def generate_openai_embeddings( model: str, text: str, key: str, url: str = "https://api.openai.com/v1" ): diff --git a/backend/apps/web/routers/chats.py b/backend/apps/web/routers/chats.py index 678c9aea7..bbe3d84b9 100644 --- a/backend/apps/web/routers/chats.py +++ b/backend/apps/web/routers/chats.py @@ -28,7 +28,7 @@ from apps.web.models.tags import ( from constants import ERROR_MESSAGES -from config import SRC_LOG_LEVELS +from config import SRC_LOG_LEVELS, ENABLE_ADMIN_EXPORT log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -79,6 +79,11 @@ async def get_all_user_chats(user=Depends(get_current_user)): @router.get("/all/db", response_model=List[ChatResponse]) async def get_all_user_chats_in_db(user=Depends(get_admin_user)): + if not ENABLE_ADMIN_EXPORT: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) return [ ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) for chat in Chats.get_all_chats() diff --git a/backend/apps/web/routers/utils.py b/backend/apps/web/routers/utils.py index 0ee75cfe6..284f350a0 100644 --- a/backend/apps/web/routers/utils.py +++ b/backend/apps/web/routers/utils.py @@ -91,7 +91,11 @@ async def download_chat_as_pdf( @router.get("/db/download") async def download_db(user=Depends(get_admin_user)): - + if not ENABLE_ADMIN_EXPORT: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) return FileResponse( f"{DATA_DIR}/webui.db", media_type="application/octet-stream", diff --git a/backend/config.py b/backend/config.py index fb9063eb7..f421c8aea 100644 --- a/backend/config.py +++ b/backend/config.py @@ -382,6 +382,8 @@ MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")] WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "") +ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true" + #################################### # WEBUI_VERSION #################################### @@ -416,18 +418,19 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "": #################################### CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" -# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2) +# this uses the model defined in the Dockerfile ENV variable. 
If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2) RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "") -RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2") +RAG_EMBEDDING_MODEL = os.environ.get( + "RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2" +) log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"), -RAG_EMBEDDING_MODEL_AUTO_UPDATE = ( - os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true" +RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = ( + os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" ) - # device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false") @@ -484,9 +487,24 @@ AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "") COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "") +IMAGES_OPENAI_API_BASE_URL = os.getenv( + "IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL +) +IMAGES_OPENAI_API_KEY = os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY) + + #################################### # Audio #################################### AUDIO_OPENAI_API_BASE_URL = os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL) AUDIO_OPENAI_API_KEY = os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY) + +#################################### +# LiteLLM +#################################### + +LITELLM_PROXY_PORT = int(os.getenv("LITELLM_PROXY_PORT", "14365")) +if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535: + raise ValueError("Invalid port number for LITELLM_PROXY_PORT") +LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1") diff --git a/backend/constants.py b/backend/constants.py index da1ee0b3f..310c13311 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -3,6 +3,10 @@ from enum import Enum class MESSAGES(str, Enum): DEFAULT = lambda msg="": f"{msg if msg else ''}" + MODEL_ADDED = lambda model="": f"The model '{model}' has been added successfully." + MODEL_DELETED = ( + lambda model="": f"The model '{model}' has been deleted successfully." 
+ ) class WEBHOOK_MESSAGES(str, Enum): diff --git a/backend/main.py b/backend/main.py index 8b5fd76bc..c7c78e18d 100644 --- a/backend/main.py +++ b/backend/main.py @@ -20,12 +20,17 @@ from starlette.middleware.base import BaseHTTPMiddleware from apps.ollama.main import app as ollama_app from apps.openai.main import app as openai_app -from apps.litellm.main import app as litellm_app, startup as litellm_app_startup +from apps.litellm.main import ( + app as litellm_app, + start_litellm_background, + shutdown_litellm_background, +) from apps.audio.main import app as audio_app from apps.images.main import app as images_app from apps.rag.main import app as rag_app from apps.web.main import app as webui_app +import asyncio from pydantic import BaseModel from typing import List @@ -47,6 +52,7 @@ from config import ( GLOBAL_LOG_LEVEL, SRC_LOG_LEVELS, WEBHOOK_URL, + ENABLE_ADMIN_EXPORT, ) from constants import ERROR_MESSAGES @@ -170,7 +176,7 @@ async def check_url(request: Request, call_next): @app.on_event("startup") async def on_startup(): - await litellm_app_startup() + asyncio.create_task(start_litellm_background()) app.mount("/api/v1", webui_app) @@ -202,6 +208,7 @@ async def get_app_config(): "default_models": webui_app.state.DEFAULT_MODELS, "default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS, "trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), + "admin_export_enabled": ENABLE_ADMIN_EXPORT, } @@ -315,3 +322,8 @@ app.mount( SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True), name="spa-static-files", ) + + +@app.on_event("shutdown") +async def shutdown_event(): + await shutdown_litellm_background() diff --git a/backend/requirements.txt b/backend/requirements.txt index c815d93da..10bcc3b69 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -17,7 +17,9 @@ peewee peewee-migrate bcrypt -litellm==1.30.7 +litellm==1.35.17 +litellm[proxy]==1.35.17 + boto3 argon2-cffi @@ -25,6 +27,7 @@ apscheduler google-generativeai langchain +langchain-chroma langchain-community fake_useragent chromadb @@ -43,6 +46,7 @@ opencv-python-headless rapidocr-onnxruntime fpdf2 +rank_bm25 faster-whisper diff --git a/package-lock.json b/package-lock.json index a310c609d..55b35dd58 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "open-webui", - "version": "0.1.120", + "version": "0.1.121", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.1.120", + "version": "0.1.121", "dependencies": { "@sveltejs/adapter-node": "^1.3.1", "async": "^3.2.5", diff --git a/package.json b/package.json index 12afea0f4..777f0f07b 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.1.120", + "version": "0.1.121", "private": true, "scripts": { "dev": "vite dev --host", diff --git a/src/lib/apis/images/index.ts b/src/lib/apis/images/index.ts index aadfafd14..3f624704e 100644 --- a/src/lib/apis/images/index.ts +++ b/src/lib/apis/images/index.ts @@ -72,10 +72,10 @@ export const updateImageGenerationConfig = async ( return res; }; -export const getOpenAIKey = async (token: string = '') => { +export const getOpenAIConfig = async (token: string = '') => { let error = null; - const res = await fetch(`${IMAGES_API_BASE_URL}/key`, { + const res = await fetch(`${IMAGES_API_BASE_URL}/openai/config`, { method: 'GET', headers: { Accept: 'application/json', @@ -101,13 +101,13 @@ export const getOpenAIKey = async (token: string = '') => { throw error; } - return 
res.OPENAI_API_KEY;
+	return res;
 };
 
-export const updateOpenAIKey = async (token: string = '', key: string) => {
+export const updateOpenAIConfig = async (token: string = '', url: string, key: string) => {
 	let error = null;
 
-	const res = await fetch(`${IMAGES_API_BASE_URL}/key/update`, {
+	const res = await fetch(`${IMAGES_API_BASE_URL}/openai/config/update`, {
 		method: 'POST',
 		headers: {
 			Accept: 'application/json',
@@ -115,6 +115,7 @@ export const updateOpenAIKey = async (token: string = '', key: string) => {
 			...(token && { authorization: `Bearer ${token}` })
 		},
 		body: JSON.stringify({
+			url: url,
 			key: key
 		})
 	})
@@ -136,7 +137,7 @@ export const updateOpenAIKey = async (token: string = '', key: string) => {
 		throw error;
 	}
 
-	return res.OPENAI_API_KEY;
+	return res;
 };
 
 export const getImageGenerationEngineUrls = async (token: string = '') => {
diff --git a/src/lib/apis/streaming/index.ts b/src/lib/apis/streaming/index.ts
new file mode 100644
index 000000000..5b89a4668
--- /dev/null
+++ b/src/lib/apis/streaming/index.ts
@@ -0,0 +1,70 @@
+type TextStreamUpdate = {
+	done: boolean;
+	value: string;
+};
+
+// createOpenAITextStream takes a ReadableStreamDefaultReader from an SSE response,
+// and returns an async generator that emits delta updates with large deltas chunked into random sized chunks
+export async function createOpenAITextStream(
+	messageStream: ReadableStreamDefaultReader,
+	splitLargeDeltas: boolean
+): Promise<AsyncGenerator<TextStreamUpdate>> {
+	let iterator = openAIStreamToIterator(messageStream);
+	if (splitLargeDeltas) {
+		iterator = streamLargeDeltasAsRandomChunks(iterator);
+	}
+	return iterator;
+}
+
+async function* openAIStreamToIterator(
+	reader: ReadableStreamDefaultReader
+): AsyncGenerator<TextStreamUpdate> {
+	while (true) {
+		const { value, done } = await reader.read();
+		if (done) {
+			yield { done: true, value: '' };
+			break;
+		}
+		const lines = value.split('\n');
+		for (const line of lines) {
+			if (line !== '') {
+				console.log(line);
+				if (line === 'data: [DONE]') {
+					yield { done: true, value: '' };
+				} else {
+					const data = JSON.parse(line.replace(/^data: /, ''));
+					console.log(data);
+
+					yield { done: false, value: data.choices[0].delta.content ?? '' };
+				}
+			}
+		}
+	}
+}
+
+// streamLargeDeltasAsRandomChunks will chunk large deltas (length > 5) into random sized chunks between 1-3 characters
+// This is to simulate a more fluid streaming, even though some providers may send large chunks of text at once
+async function* streamLargeDeltasAsRandomChunks(
+	iterator: AsyncGenerator<TextStreamUpdate>
+): AsyncGenerator<TextStreamUpdate> {
+	for await (const textStreamUpdate of iterator) {
+		if (textStreamUpdate.done) {
+			yield textStreamUpdate;
+			return;
+		}
+		let content = textStreamUpdate.value;
+		if (content.length < 5) {
+			yield { done: false, value: content };
+			continue;
+		}
+		while (content != '') {
+			const chunkSize = Math.min(Math.floor(Math.random() * 3) + 1, content.length);
+			const chunk = content.slice(0, chunkSize);
+			yield { done: false, value: chunk };
+			await sleep(5);
+			content = content.slice(chunkSize);
+		}
+	}
+}
+
+const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
diff --git a/src/lib/components/admin/Settings/Database.svelte b/src/lib/components/admin/Settings/Database.svelte
index 7d3a34444..06a0d595c 100644
--- a/src/lib/components/admin/Settings/Database.svelte
+++ b/src/lib/components/admin/Settings/Database.svelte
@@ -1,6 +1,7 @@
@@ -643,9 +633,9 @@
 >
-							{#if Object.keys(modelDownloadStatus).length > 0}
-								{#each Object.keys(modelDownloadStatus) as model}
-									{#if 'pullProgress' in modelDownloadStatus[model]}
+							{#if Object.keys($MODEL_DOWNLOAD_POOL).length > 0}
+								{#each Object.keys($MODEL_DOWNLOAD_POOL) as model}
+									{#if 'pullProgress' in $MODEL_DOWNLOAD_POOL[model]}
 											{model}
@@ -655,10 +645,10 @@
 												class="dark:bg-gray-600 bg-gray-500 text-xs font-medium text-gray-100 text-center p-0.5 leading-none rounded-full"
 												style="width: {Math.max(
 													15,
-													modelDownloadStatus[model].pullProgress ?? 0
+													$MODEL_DOWNLOAD_POOL[model].pullProgress ?? 0
 												)}%"
 											>
-												{modelDownloadStatus[model].pullProgress ?? 0}%
+												{$MODEL_DOWNLOAD_POOL[model].pullProgress ?? 0}%
@@ -689,9 +679,9 @@
-										{#if 'digest' in modelDownloadStatus[model]}
+										{#if 'digest' in $MODEL_DOWNLOAD_POOL[model]}
-											{modelDownloadStatus[model].digest}
+											{$MODEL_DOWNLOAD_POOL[model].digest}
 										{/if}
@@ -1099,14 +1089,14 @@
diff --git a/src/lib/components/layout/Sidebar/ArchivedChatsModal.svelte b/src/lib/components/layout/Sidebar/ArchivedChatsModal.svelte
index 6ae5286b4..51bcf1ad6 100644
--- a/src/lib/components/layout/Sidebar/ArchivedChatsModal.svelte
+++ b/src/lib/components/layout/Sidebar/ArchivedChatsModal.svelte
@@ -67,7 +67,7 @@
 			{#if chats.length > 0}
-
+
= writable(undefined);
+export const user: Writable<SessionUser | undefined> = writable(undefined);
 
 // Frontend
 export const MODEL_DOWNLOAD_POOL = writable({});
@@ -14,10 +14,10 @@ export const chatId = writable('');
 export const chats = writable([]);
 export const tags = writable([]);
-export const models = writable([]);
+export const models: Writable<Model[]> = writable([]);
 export const modelfiles = writable([]);
-export const prompts = writable([]);
+export const prompts: Writable<Prompt[]> = writable([]);
 export const documents = writable([
 	{
 		collection_name: 'collection_name',
@@ -33,6 +33,109 @@ export const documents = writable([
 	}
 ]);
-export const settings = writable({});
+export const settings: Writable<Settings> = writable({});
 export const showSettings = writable(false);
 export const showChangelog = writable(false);
+
+type Model = OpenAIModel | OllamaModel;
+
+type OpenAIModel = {
+	id: string;
+	name: string;
+	external: boolean;
+	source?: string;
+};
+
+type OllamaModel = {
+	id: string;
+	name: string;
+
+	// Ollama specific fields
+	details: OllamaModelDetails;
+	size: number;
+	description: string;
+	model: string;
+	modified_at: string;
+	digest: string;
+};
+
+type OllamaModelDetails = {
+	parent_model: string;
+	format: string;
+	family: string;
+	families: string[] | null;
+	parameter_size: string;
+	quantization_level: string;
+};
+
+type Settings = {
+	models?: string[];
+	conversationMode?: boolean;
+	speechAutoSend?: boolean;
+	responseAutoPlayback?: boolean;
+	audio?: AudioSettings;
+	showUsername?: boolean;
+	saveChatHistory?: boolean;
+	notificationEnabled?: boolean;
+	title?: TitleSettings;
+
+	system?: string;
+	requestFormat?: string;
+	keepAlive?: string;
+	seed?: number;
+	temperature?: string;
+	repeat_penalty?: string;
+	top_k?: string;
+	top_p?: string;
+	num_ctx?: string;
+	options?: ModelOptions;
+};
+
+type ModelOptions = {
+	stop?: boolean;
+};
+
+type AudioSettings = {
+	STTEngine?: string;
+	TTSEngine?: string;
+	speaker?: string;
+};
+
+type TitleSettings = {
+	auto?: boolean;
+	model?: string;
+	modelExternal?: string;
+	prompt?: string;
+};
+
+type Prompt = {
+	command: string;
+	user_id: string;
+	title: string;
+	content: string;
+	timestamp: number;
+};
+
+type Config = {
+	status?: boolean;
+	name?: string;
+	version?: string;
+	default_locale?: string;
+	images?: boolean;
+	default_models?: string[];
+	default_prompt_suggestions?: PromptSuggestion[];
+	trusted_header_auth?: boolean;
+};
+
+type PromptSuggestion = {
+	content: string;
+	title: [string, string];
+};
+
+type SessionUser = {
+	id: string;
+	email: string;
+	name: string;
+	role: string;
+	profile_image_url: string;
+};
diff --git a/src/lib/utils/index.ts b/src/lib/utils/index.ts
index a24834c33..04cc22079 100644
--- a/src/lib/utils/index.ts
+++ b/src/lib/utils/index.ts
@@ -35,7 +35,6 @@ export const sanitizeResponseContent = (content: string) => {
 		.replace(/<\|[a-z]+\|$/, '')
 		.replace(/<$/, '')
 		.replaceAll(/<\|[a-z]+\|>/g, ' ')
-		.replaceAll(/<br\s?\/?>/gi, '\n')
 		.replaceAll('<', '&lt;')
 		.trim();
 };
diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte
index b4a375be7..0596a2a2a 100644
--- a/src/routes/(app)/+page.svelte
+++ b/src/routes/(app)/+page.svelte
@@ -39,6 +39,7 @@
 	import { RAGTemplate } from '$lib/utils/rag';
 	import { LITELLM_API_BASE_URL, OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL } from '$lib/constants';
 	import { WEBUI_BASE_URL } from '$lib/constants';
+	import { createOpenAITextStream } from '$lib/apis/streaming';
 
 	const i18n = getContext('i18n');
 
@@ -599,38 +600,22 @@
 					.pipeThrough(splitStream('\n'))
.getReader(); - while (true) { - const { value, done } = await reader.read(); + const textStream = await createOpenAITextStream(reader, $settings.splitLargeChunks); + console.log(textStream); + + for await (const update of textStream) { + const { value, done } = update; if (done || stopResponseFlag || _chatId !== $chatId) { responseMessage.done = true; messages = messages; break; } - try { - let lines = value.split('\n'); - - for (const line of lines) { - if (line !== '') { - console.log(line); - if (line === 'data: [DONE]') { - responseMessage.done = true; - messages = messages; - } else { - let data = JSON.parse(line.replace(/^data: /, '')); - console.log(data); - - if (responseMessage.content == '' && data.choices[0].delta.content == '\n') { - continue; - } else { - responseMessage.content += data.choices[0].delta.content ?? ''; - messages = messages; - } - } - } - } - } catch (error) { - console.log(error); + if (responseMessage.content == '' && value == '\n') { + continue; + } else { + responseMessage.content += value; + messages = messages; } if ($settings.notificationEnabled && !document.hasFocus()) { diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte index 7b3ffd7a2..105ee958c 100644 --- a/src/routes/(app)/c/[id]/+page.svelte +++ b/src/routes/(app)/c/[id]/+page.svelte @@ -42,6 +42,7 @@ OLLAMA_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; + import { createOpenAITextStream } from '$lib/apis/streaming'; const i18n = getContext('i18n'); @@ -611,38 +612,22 @@ .pipeThrough(splitStream('\n')) .getReader(); - while (true) { - const { value, done } = await reader.read(); + const textStream = await createOpenAITextStream(reader, $settings.splitLargeChunks); + console.log(textStream); + + for await (const update of textStream) { + const { value, done } = update; if (done || stopResponseFlag || _chatId !== $chatId) { responseMessage.done = true; messages = messages; break; } - try { - let lines = value.split('\n'); - - for (const line of lines) { - if (line !== '') { - console.log(line); - if (line === 'data: [DONE]') { - responseMessage.done = true; - messages = messages; - } else { - let data = JSON.parse(line.replace(/^data: /, '')); - console.log(data); - - if (responseMessage.content == '' && data.choices[0].delta.content == '\n') { - continue; - } else { - responseMessage.content += data.choices[0].delta.content ?? ''; - messages = messages; - } - } - } - } - } catch (error) { - console.log(error); + if (responseMessage.content == '' && value == '\n') { + continue; + } else { + responseMessage.content += value; + messages = messages; } if ($settings.notificationEnabled && !document.hasFocus()) { diff --git a/static/manifest.json b/static/manifest.json index e69de29bb..0967ef424 100644 --- a/static/manifest.json +++ b/static/manifest.json @@ -0,0 +1 @@ +{}