diff --git a/CHANGELOG.md b/CHANGELOG.md index a62e18871..98ba0c4c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,39 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.4.8] - 2024-12-07 + +### Added + +- **🔓 Bypass Model Access Control**: Introduced the 'BYPASS_MODEL_ACCESS_CONTROL' environment variable. Easily bypass model access controls for user roles when access control isn't required, simplifying workflows for trusted environments. +- **📝 Markdown in Banners**: Now supports markdown for banners, enabling richer, more visually engaging announcements. +- **🌐 Internationalization Updates**: Enhanced translations across multiple languages, further improving accessibility and global user experience. +- **🎨 Styling Enhancements**: General UI style refinements for a cleaner and more polished interface. +- **📋 Rich Text Reliability**: Improved the reliability and stability of rich text input across chats for smoother interactions. + +### Fixed + +- **💡 Tailwind Build Issue**: Resolved a breaking bug caused by Tailwind, ensuring smoother builds and overall system reliability. +- **📚 Knowledge Collection Query Fix**: Addressed API endpoint issues with querying knowledge collections, ensuring accurate and reliable information retrieval. + +## [0.4.7] - 2024-12-01 + +### Added + +- **✨ Prompt Input Auto-Completion**: Type a prompt and let AI intelligently suggest and complete your inputs. Simply press 'Tab' or swipe right on mobile to confirm. Available only with Rich Text Input (default setting). Disable via Admin Settings for full control. +- **🌍 Improved Translations**: Enhanced localization for multiple languages, ensuring a more polished and accessible experience for international users. 
+ +### Fixed + +- **🛠️ Tools Export Issue**: Resolved a critical issue where exporting tools wasn’t functioning, restoring seamless export capabilities. +- **🔗 Model ID Registration**: Fixed an issue where model IDs weren’t registering correctly in the model editor, ensuring reliable model setup and tracking. +- **🖋️ Textarea Auto-Expansion**: Corrected a bug where textareas didn’t expand automatically on certain browsers, improving usability for multi-line inputs. +- **🔧 Ollama Embed Endpoint**: Addressed the /ollama/embed endpoint malfunction, ensuring consistent performance and functionality. + +### Changed + +- **🎨 Knowledge Base Styling**: Refined knowledge base visuals for a cleaner, more modern look, laying the groundwork for further enhancements in upcoming releases. + ## [0.4.6] - 2024-11-26 ### Added diff --git a/backend/open_webui/apps/audio/main.py b/backend/open_webui/apps/audio/main.py index 5c24c2633..a3972f19f 100644 --- a/backend/open_webui/apps/audio/main.py +++ b/backend/open_webui/apps/audio/main.py @@ -45,7 +45,7 @@ from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse from pydantic import BaseModel -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user # Constants MAX_FILE_SIZE_MB = 25 diff --git a/backend/open_webui/apps/images/main.py b/backend/open_webui/apps/images/main.py index 62c76425d..14209df2f 100644 --- a/backend/open_webui/apps/images/main.py +++ b/backend/open_webui/apps/images/main.py @@ -40,7 +40,7 @@ from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS from fastapi import Depends, FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import 
get_admin_user, get_verified_user log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["IMAGES"]) @@ -117,7 +117,7 @@ class OpenAIConfigForm(BaseModel): class Automatic1111ConfigForm(BaseModel): AUTOMATIC1111_BASE_URL: str AUTOMATIC1111_API_AUTH: str - AUTOMATIC1111_CFG_SCALE: Optional[str] + AUTOMATIC1111_CFG_SCALE: Optional[str | float | int] AUTOMATIC1111_SAMPLER: Optional[str] AUTOMATIC1111_SCHEDULER: Optional[str] diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index 0ac1f0401..48142fd9f 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -24,6 +24,7 @@ from open_webui.config import ( from open_webui.env import ( AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST, + BYPASS_MODEL_ACCESS_CONTROL, ) @@ -44,7 +45,7 @@ from open_webui.utils.payload import ( apply_model_params_to_body_openai, apply_model_system_prompt_to_body, ) -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access log = logging.getLogger(__name__) @@ -359,7 +360,7 @@ async def get_ollama_tags( detail=error_detail, ) - if user.role == "user": + if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: # Filter models based on user access control filtered_models = [] for model in models.get("models", []): @@ -432,6 +433,26 @@ async def get_ollama_versions(url_idx: Optional[int] = None): return {"version": False} +@app.get("/api/ps") +async def get_ollama_loaded_models(user=Depends(get_verified_user)): + """ + List models that are currently loaded into Ollama memory, and which node they are loaded on. 
+ """ + if app.state.config.ENABLE_OLLAMA_API: + tasks = [ + aiohttp_get( + f"{url}/api/ps", + app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None), + ) + for url in app.state.config.OLLAMA_BASE_URLS + ] + responses = await asyncio.gather(*tasks) + + return dict(zip(app.state.config.OLLAMA_BASE_URLS, responses)) + else: + return {} + + class ModelNameForm(BaseModel): name: str @@ -706,7 +727,7 @@ async def generate_embeddings( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - return generate_ollama_batch_embeddings(form_data, url_idx) + return await generate_ollama_batch_embeddings(form_data, url_idx) @app.post("/api/embeddings") @@ -946,6 +967,9 @@ async def generate_chat_completion( user=Depends(get_verified_user), bypass_filter: Optional[bool] = False, ): + if BYPASS_MODEL_ACCESS_CONTROL: + bypass_filter = True + payload = {**form_data.model_dump(exclude_none=True)} log.debug(f"generate_chat_completion() - 1.payload = {payload}") if "metadata" in payload: @@ -1031,6 +1055,82 @@ class OpenAIChatCompletionForm(BaseModel): model_config = ConfigDict(extra="allow") +class OpenAICompletionForm(BaseModel): + model: str + prompt: str + + model_config = ConfigDict(extra="allow") + + +@app.post("/v1/completions") +@app.post("/v1/completions/{url_idx}") +async def generate_openai_completion( + form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user) +): + try: + form_data = OpenAICompletionForm(**form_data) + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=400, + detail=str(e), + ) + + payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])} + if "metadata" in payload: + del payload["metadata"] + + model_id = form_data.model + if ":" not in model_id: + model_id = f"{model_id}:latest" + + model_info = Models.get_model_by_id(model_id) + if model_info: + if model_info.base_model_id: + payload["model"] = model_info.base_model_id + params = model_info.params.model_dump() + + 
if params: + payload = apply_model_params_to_body_openai(params, payload) + + # Check if user has access to the model + if user.role == "user": + if not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + else: + if user.role != "admin": + raise HTTPException( + status_code=403, + detail="Model not found", + ) + + if ":" not in payload["model"]: + payload["model"] = f"{payload['model']}:latest" + + url = await get_ollama_url(url_idx, payload["model"]) + log.info(f"url: {url}") + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + prefix_id = api_config.get("prefix_id", None) + + if prefix_id: + payload["model"] = payload["model"].replace(f"{prefix_id}.", "") + + return await post_streaming_url( + f"{url}/v1/completions", + json.dumps(payload), + stream=payload.get("stream", False), + ) + + @app.post("/v1/chat/completions") @app.post("/v1/chat/completions/{url_idx}") async def generate_openai_chat_completion( @@ -1156,7 +1256,7 @@ async def get_openai_models( detail=error_detail, ) - if user.role == "user": + if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: # Filter models based on user access control filtered_models = [] for model in models: diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/apps/openai/main.py index 31c36a8a1..b64e7b28d 100644 --- a/backend/open_webui/apps/openai/main.py +++ b/backend/open_webui/apps/openai/main.py @@ -24,6 +24,7 @@ from open_webui.env import ( AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST, ENABLE_FORWARD_USER_INFO_HEADERS, + BYPASS_MODEL_ACCESS_CONTROL, ) from open_webui.constants import ERROR_MESSAGES @@ -39,7 +40,7 @@ from open_webui.utils.payload import ( apply_model_system_prompt_to_body, ) -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, 
get_verified_user from open_webui.utils.access_control import has_access @@ -422,7 +423,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us error_detail = f"Unexpected error: {str(e)}" raise HTTPException(status_code=500, detail=error_detail) - if user.role == "user": + if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: # Filter models based on user access control filtered_models = [] for model in models.get("data", []): diff --git a/backend/open_webui/apps/retrieval/loaders/youtube.py b/backend/open_webui/apps/retrieval/loaders/youtube.py index ad1088be0..8eb48488b 100644 --- a/backend/open_webui/apps/retrieval/loaders/youtube.py +++ b/backend/open_webui/apps/retrieval/loaders/youtube.py @@ -1,7 +1,12 @@ +import logging + from typing import Any, Dict, Generator, List, Optional, Sequence, Union from urllib.parse import parse_qs, urlparse from langchain_core.documents import Document +from open_webui.env import SRC_LOG_LEVELS +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) ALLOWED_SCHEMES = {"http", "https"} ALLOWED_NETLOCS = { @@ -51,12 +56,14 @@ class YoutubeLoader: self, video_id: str, language: Union[str, Sequence[str]] = "en", + proxy_url: Optional[str] = None, ): """Initialize with YouTube video ID.""" _video_id = _parse_video_id(video_id) self.video_id = _video_id if _video_id is not None else video_id self._metadata = {"source": video_id} self.language = language + self.proxy_url = proxy_url if isinstance(language, str): self.language = [language] else: @@ -76,10 +83,22 @@ class YoutubeLoader: "Please install it with `pip install youtube-transcript-api`." 
) + if self.proxy_url: + youtube_proxies = { + "http": self.proxy_url, + "https": self.proxy_url, + } + # Don't log complete URL because it might contain secrets + log.debug(f"Using proxy URL: {self.proxy_url[:14]}...") + else: + youtube_proxies = None + try: - transcript_list = YouTubeTranscriptApi.list_transcripts(self.video_id) + transcript_list = YouTubeTranscriptApi.list_transcripts( + self.video_id, proxies=youtube_proxies + ) except Exception as e: - print(e) + log.exception("Loading YouTube transcript failed") return [] try: diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py index 63bc18190..cfbc5beee 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/apps/retrieval/main.py @@ -29,6 +29,7 @@ from open_webui.apps.retrieval.loaders.youtube import YoutubeLoader from open_webui.apps.retrieval.web.main import SearchResult from open_webui.apps.retrieval.web.utils import get_web_loader from open_webui.apps.retrieval.web.brave import search_brave +from open_webui.apps.retrieval.web.kagi import search_kagi from open_webui.apps.retrieval.web.mojeek import search_mojeek from open_webui.apps.retrieval.web.duckduckgo import search_duckduckgo from open_webui.apps.retrieval.web.google_pse import search_google_pse @@ -54,6 +55,7 @@ from open_webui.apps.retrieval.utils import ( from open_webui.apps.webui.models.files import Files from open_webui.config import ( BRAVE_SEARCH_API_KEY, + KAGI_SEARCH_API_KEY, MOJEEK_SEARCH_API_KEY, TIKTOKEN_ENCODING_NAME, RAG_TEXT_SPLITTER, @@ -105,6 +107,7 @@ from open_webui.config import ( TIKA_SERVER_URL, UPLOAD_DIR, YOUTUBE_LOADER_LANGUAGE, + YOUTUBE_LOADER_PROXY_URL, DEFAULT_LOCALE, AppConfig, ) @@ -120,7 +123,7 @@ from open_webui.utils.misc import ( extract_folders_after_data_docs, sanitize_filename, ) -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user from 
langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter from langchain_core.documents import Document @@ -171,6 +174,7 @@ app.state.config.OLLAMA_API_KEY = RAG_OLLAMA_API_KEY app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE +app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL app.state.YOUTUBE_LOADER_TRANSLATION = None @@ -182,6 +186,7 @@ app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY +app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS @@ -471,6 +476,7 @@ async def get_rag_config(user=Depends(get_admin_user)): "youtube": { "language": app.state.config.YOUTUBE_LOADER_LANGUAGE, "translation": app.state.YOUTUBE_LOADER_TRANSLATION, + "proxy_url": app.state.config.YOUTUBE_LOADER_PROXY_URL, }, "web": { "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, @@ -481,6 +487,7 @@ async def get_rag_config(user=Depends(get_admin_user)): "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY, "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID, "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY, + "kagi_search_api_key": app.state.config.KAGI_SEARCH_API_KEY, "mojeek_search_api_key": app.state.config.MOJEEK_SEARCH_API_KEY, "serpstack_api_key": app.state.config.SERPSTACK_API_KEY, "serpstack_https": app.state.config.SERPSTACK_HTTPS, @@ -518,6 +525,7 @@ class ChunkParamUpdateForm(BaseModel): class YoutubeLoaderConfig(BaseModel): language: list[str] translation: Optional[str] = None + proxy_url: str = "" class WebSearchConfig(BaseModel): @@ -527,6 +535,7 
@@ class WebSearchConfig(BaseModel): google_pse_api_key: Optional[str] = None google_pse_engine_id: Optional[str] = None brave_search_api_key: Optional[str] = None + kagi_search_api_key: Optional[str] = None mojeek_search_api_key: Optional[str] = None serpstack_api_key: Optional[str] = None serpstack_https: Optional[bool] = None @@ -580,6 +589,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ if form_data.youtube is not None: app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language + app.state.config.YOUTUBE_LOADER_PROXY_URL = form_data.youtube.proxy_url app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation if form_data.web is not None: @@ -598,6 +608,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ app.state.config.BRAVE_SEARCH_API_KEY = ( form_data.web.search.brave_search_api_key ) + app.state.config.KAGI_SEARCH_API_KEY = form_data.web.search.kagi_search_api_key app.state.config.MOJEEK_SEARCH_API_KEY = ( form_data.web.search.mojeek_search_api_key ) @@ -640,6 +651,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ }, "youtube": { "language": app.state.config.YOUTUBE_LOADER_LANGUAGE, + "proxy_url": app.state.config.YOUTUBE_LOADER_PROXY_URL, "translation": app.state.YOUTUBE_LOADER_TRANSLATION, }, "web": { @@ -651,6 +663,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY, "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID, "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY, + "kagi_search_api_key": app.state.config.KAGI_SEARCH_API_KEY, "mojeek_search_api_key": app.state.config.MOJEEK_SEARCH_API_KEY, "serpstack_api_key": app.state.config.SERPSTACK_API_KEY, "serpstack_https": app.state.config.SERPSTACK_HTTPS, @@ -867,7 +880,7 @@ def save_docs_to_vector_db( return True except Exception as e: log.exception(e) - return False 
+ raise e class ProcessFileForm(BaseModel): @@ -897,7 +910,7 @@ def process_file( docs = [ Document( - page_content=form_data.content, + page_content=form_data.content.replace("<br/>", "\n"),
metadata={ **file.meta, "name": file.filename, @@ -1081,7 +1094,9 @@ def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_u collection_name = calculate_sha256_string(form_data.url)[:63] loader = YoutubeLoader( - form_data.url, language=app.state.config.YOUTUBE_LOADER_LANGUAGE + form_data.url, + language=app.state.config.YOUTUBE_LOADER_LANGUAGE, + proxy_url=app.state.config.YOUTUBE_LOADER_PROXY_URL, ) docs = loader.load() @@ -1154,6 +1169,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]: - SEARXNG_QUERY_URL - GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID - BRAVE_SEARCH_API_KEY + - KAGI_SEARCH_API_KEY - MOJEEK_SEARCH_API_KEY - SERPSTACK_API_KEY - SERPER_API_KEY @@ -1201,6 +1217,16 @@ def search_web(engine: str, query: str) -> list[SearchResult]: ) else: raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables") + elif engine == "kagi": + if app.state.config.KAGI_SEARCH_API_KEY: + return search_kagi( + app.state.config.KAGI_SEARCH_API_KEY, + query, + app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + ) + else: + raise Exception("No KAGI_SEARCH_API_KEY found in environment variables") elif engine == "mojeek": if app.state.config.MOJEEK_SEARCH_API_KEY: return search_mojeek( @@ -1362,8 +1388,7 @@ def query_doc_handler( else: return query_doc( collection_name=form_data.collection_name, - query=form_data.query, - embedding_function=app.state.EMBEDDING_FUNCTION, + query_embedding=app.state.EMBEDDING_FUNCTION(form_data.query), k=form_data.k if form_data.k else app.state.config.TOP_K, ) except Exception as e: @@ -1391,7 +1416,7 @@ def query_collection_handler( if app.state.config.ENABLE_RAG_HYBRID_SEARCH: return query_collection_with_hybrid_search( collection_names=form_data.collection_names, - query=form_data.query, + queries=[form_data.query], embedding_function=app.state.EMBEDDING_FUNCTION, k=form_data.k if form_data.k else app.state.config.TOP_K,
reranking_function=app.state.sentence_transformer_rf, @@ -1402,7 +1427,7 @@ def query_collection_handler( else: return query_collection( collection_names=form_data.collection_names, - query=form_data.query, + queries=[form_data.query], embedding_function=app.state.EMBEDDING_FUNCTION, k=form_data.k if form_data.k else app.state.config.TOP_K, ) diff --git a/backend/open_webui/apps/retrieval/web/kagi.py b/backend/open_webui/apps/retrieval/web/kagi.py new file mode 100644 index 000000000..c8c2699ed --- /dev/null +++ b/backend/open_webui/apps/retrieval/web/kagi.py @@ -0,0 +1,50 @@ +import logging +from typing import Optional + +import requests +from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +def search_kagi( + api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None +) -> list[SearchResult]: + """Search using Kagi's Search API and return the results as a list of SearchResult objects. + + The Search API will inherit the settings in your account, including results personalization and snippet length. 
+ + Args: + api_key (str): A Kagi Search API key + query (str): The query to search for + count (int): The number of results to return + """ + url = "https://kagi.com/api/v0/search" + headers = { + "Authorization": f"Bot {api_key}", + } + params = {"q": query, "limit": count} + + response = requests.get(url, headers=headers, params=params) + response.raise_for_status() + json_response = response.json() + search_results = json_response.get("data", []) + + results = [ + SearchResult( + link=result["url"], + title=result["title"], + snippet=result.get("snippet") + ) + for result in search_results + if result["t"] == 0 + ] + + print(results) + + if filter_list: + results = get_filtered_results(results, filter_list) + + return results diff --git a/backend/open_webui/apps/socket/main.py b/backend/open_webui/apps/socket/main.py index 5c284f18d..8ec8937a1 100644 --- a/backend/open_webui/apps/socket/main.py +++ b/backend/open_webui/apps/socket/main.py @@ -12,7 +12,7 @@ from open_webui.env import ( WEBSOCKET_MANAGER, WEBSOCKET_REDIS_URL, ) -from open_webui.utils.utils import decode_token +from open_webui.utils.auth import decode_token from open_webui.apps.socket.utils import RedisDict from open_webui.env import ( diff --git a/backend/open_webui/apps/webui/models/auths.py b/backend/open_webui/apps/webui/models/auths.py index ead897347..391b2e9ec 100644 --- a/backend/open_webui/apps/webui/models/auths.py +++ b/backend/open_webui/apps/webui/models/auths.py @@ -7,7 +7,7 @@ from open_webui.apps.webui.models.users import UserModel, Users from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel from sqlalchemy import Boolean, Column, String, Text -from open_webui.utils.utils import verify_password +from open_webui.utils.auth import verify_password log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) diff --git a/backend/open_webui/apps/webui/models/tools.py b/backend/open_webui/apps/webui/models/tools.py index b628f4f9f..8f798c317 100644 --- 
a/backend/open_webui/apps/webui/models/tools.py +++ b/backend/open_webui/apps/webui/models/tools.py @@ -76,6 +76,10 @@ class ToolModel(BaseModel): #################### +class ToolUserModel(ToolModel): + user: Optional[UserResponse] = None + + class ToolResponse(BaseModel): id: str user_id: str @@ -138,13 +142,13 @@ class ToolsTable: except Exception: return None - def get_tools(self) -> list[ToolUserResponse]: + def get_tools(self) -> list[ToolUserModel]: with get_db() as db: tools = [] for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all(): user = Users.get_user_by_id(tool.user_id) tools.append( - ToolUserResponse.model_validate( + ToolUserModel.model_validate( { **ToolModel.model_validate(tool).model_dump(), "user": user.model_dump() if user else None, @@ -155,7 +159,7 @@ class ToolsTable: def get_tools_by_user_id( self, user_id: str, permission: str = "write" - ) -> list[ToolUserResponse]: + ) -> list[ToolUserModel]: tools = self.get_tools() return [ diff --git a/backend/open_webui/apps/webui/routers/auths.py b/backend/open_webui/apps/webui/routers/auths.py index 8f175f366..094ce568f 100644 --- a/backend/open_webui/apps/webui/routers/auths.py +++ b/backend/open_webui/apps/webui/routers/auths.py @@ -3,6 +3,7 @@ import uuid import time import datetime import logging +from aiohttp import ClientSession from open_webui.apps.webui.models.auths import ( AddUserForm, @@ -29,10 +30,14 @@ from open_webui.env import ( SRC_LOG_LEVELS, ) from fastapi import APIRouter, Depends, HTTPException, Request, status -from fastapi.responses import Response +from fastapi.responses import RedirectResponse, Response +from open_webui.config import ( + OPENID_PROVIDER_URL, + ENABLE_OAUTH_SIGNUP, +) from pydantic import BaseModel from open_webui.utils.misc import parse_duration, validate_email_format -from open_webui.utils.utils import ( +from open_webui.utils.auth import ( create_api_key, create_token, get_admin_user, @@ -498,8 +503,31 @@ async def signup(request: Request, 
response: Response, form_data: SignupForm): @router.get("/signout") -async def signout(response: Response): +async def signout(request: Request, response: Response): response.delete_cookie("token") + + if ENABLE_OAUTH_SIGNUP.value: + oauth_id_token = request.cookies.get("oauth_id_token") + if oauth_id_token: + try: + async with ClientSession() as session: + async with session.get(OPENID_PROVIDER_URL.value) as resp: + if resp.status == 200: + openid_data = await resp.json() + logout_url = openid_data.get("end_session_endpoint") + if logout_url: + response.delete_cookie("oauth_id_token") + return RedirectResponse( + url=f"{logout_url}?id_token_hint={oauth_id_token}" + ) + else: + raise HTTPException( + status_code=resp.status, + detail="Failed to fetch OpenID configuration", + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + return {"status": True} diff --git a/backend/open_webui/apps/webui/routers/chats.py b/backend/open_webui/apps/webui/routers/chats.py index db95337d5..ec5dae4bf 100644 --- a/backend/open_webui/apps/webui/routers/chats.py +++ b/backend/open_webui/apps/webui/routers/chats.py @@ -19,7 +19,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic import BaseModel -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_permission log = logging.getLogger(__name__) @@ -607,7 +607,6 @@ async def add_tag_by_id_and_tag_name( detail=ERROR_MESSAGES.DEFAULT("Tag name cannot be 'None'"), ) - print(tags, tag_id) if tag_id not in tags: Chats.add_chat_tag_by_id_and_user_id_and_tag_name( id, user.id, form_data.name diff --git a/backend/open_webui/apps/webui/routers/configs.py b/backend/open_webui/apps/webui/routers/configs.py index 7466e6fda..ef6c4d8c1 100644 --- a/backend/open_webui/apps/webui/routers/configs.py +++ b/backend/open_webui/apps/webui/routers/configs.py @@ 
-3,7 +3,7 @@ from pydantic import BaseModel from typing import Optional -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.config import get_config, save_config from open_webui.config import BannerModel diff --git a/backend/open_webui/apps/webui/routers/evaluations.py b/backend/open_webui/apps/webui/routers/evaluations.py index b9e3bff29..0bcee2a79 100644 --- a/backend/open_webui/apps/webui/routers/evaluations.py +++ b/backend/open_webui/apps/webui/routers/evaluations.py @@ -11,7 +11,7 @@ from open_webui.apps.webui.models.feedbacks import ( ) from open_webui.constants import ERROR_MESSAGES -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user router = APIRouter() diff --git a/backend/open_webui/apps/webui/routers/files.py b/backend/open_webui/apps/webui/routers/files.py index e7459a15f..4b7cf1ed4 100644 --- a/backend/open_webui/apps/webui/routers/files.py +++ b/backend/open_webui/apps/webui/routers/files.py @@ -25,7 +25,7 @@ from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status from fastapi.responses import FileResponse, StreamingResponse -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) diff --git a/backend/open_webui/apps/webui/routers/folders.py b/backend/open_webui/apps/webui/routers/folders.py index 36075c357..f05781476 100644 --- a/backend/open_webui/apps/webui/routers/folders.py +++ b/backend/open_webui/apps/webui/routers/folders.py @@ -24,7 +24,7 @@ from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status from fastapi.responses import FileResponse, StreamingResponse -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth 
import get_admin_user, get_verified_user log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) diff --git a/backend/open_webui/apps/webui/routers/functions.py b/backend/open_webui/apps/webui/routers/functions.py index aeaceecfb..bdd422b95 100644 --- a/backend/open_webui/apps/webui/routers/functions.py +++ b/backend/open_webui/apps/webui/routers/functions.py @@ -12,7 +12,7 @@ from open_webui.apps.webui.utils import load_function_module_by_id, replace_impo from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user router = APIRouter() diff --git a/backend/open_webui/apps/webui/routers/groups.py b/backend/open_webui/apps/webui/routers/groups.py index 59d7d0052..ef392fb6a 100644 --- a/backend/open_webui/apps/webui/routers/groups.py +++ b/backend/open_webui/apps/webui/routers/groups.py @@ -12,7 +12,7 @@ from open_webui.apps.webui.models.groups import ( from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user router = APIRouter() diff --git a/backend/open_webui/apps/webui/routers/knowledge.py b/backend/open_webui/apps/webui/routers/knowledge.py index 1b063cda2..d572e83b7 100644 --- a/backend/open_webui/apps/webui/routers/knowledge.py +++ b/backend/open_webui/apps/webui/routers/knowledge.py @@ -16,7 +16,7 @@ from open_webui.apps.retrieval.main import process_file, ProcessFileForm from open_webui.constants import ERROR_MESSAGES -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user from 
open_webui.utils.access_control import has_access, has_permission diff --git a/backend/open_webui/apps/webui/routers/memories.py b/backend/open_webui/apps/webui/routers/memories.py index ccf84a9d4..60993607f 100644 --- a/backend/open_webui/apps/webui/routers/memories.py +++ b/backend/open_webui/apps/webui/routers/memories.py @@ -5,7 +5,7 @@ from typing import Optional from open_webui.apps.webui.models.memories import Memories, MemoryModel from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT -from open_webui.utils.utils import get_verified_user +from open_webui.utils.auth import get_verified_user from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/apps/webui/routers/models.py b/backend/open_webui/apps/webui/routers/models.py index 6a8085385..2e073219a 100644 --- a/backend/open_webui/apps/webui/routers/models.py +++ b/backend/open_webui/apps/webui/routers/models.py @@ -11,7 +11,7 @@ from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission diff --git a/backend/open_webui/apps/webui/routers/prompts.py b/backend/open_webui/apps/webui/routers/prompts.py index 7cacde606..89a60fd95 100644 --- a/backend/open_webui/apps/webui/routers/prompts.py +++ b/backend/open_webui/apps/webui/routers/prompts.py @@ -8,7 +8,7 @@ from open_webui.apps.webui.models.prompts import ( ) from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, status, Request -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission router = APIRouter() diff --git a/backend/open_webui/apps/webui/routers/tools.py 
b/backend/open_webui/apps/webui/routers/tools.py index d0523ddac..410f12d64 100644 --- a/backend/open_webui/apps/webui/routers/tools.py +++ b/backend/open_webui/apps/webui/routers/tools.py @@ -13,7 +13,7 @@ from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status from open_webui.utils.tools import get_tools_specs -from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission diff --git a/backend/open_webui/apps/webui/routers/users.py b/backend/open_webui/apps/webui/routers/users.py index b6b91a5c3..92131b9ad 100644 --- a/backend/open_webui/apps/webui/routers/users.py +++ b/backend/open_webui/apps/webui/routers/users.py @@ -14,7 +14,7 @@ from open_webui.constants import ERROR_MESSAGES from open_webui.env import SRC_LOG_LEVELS from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic import BaseModel -from open_webui.utils.utils import get_admin_user, get_password_hash, get_verified_user +from open_webui.utils.auth import get_admin_user, get_password_hash, get_verified_user log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) diff --git a/backend/open_webui/apps/webui/routers/utils.py b/backend/open_webui/apps/webui/routers/utils.py index 0ab0f6b15..a4c33a03b 100644 --- a/backend/open_webui/apps/webui/routers/utils.py +++ b/backend/open_webui/apps/webui/routers/utils.py @@ -9,7 +9,7 @@ from pydantic import BaseModel from starlette.responses import FileResponse from open_webui.utils.misc import get_gravatar_url from open_webui.utils.pdf_generator import PDFGenerator -from open_webui.utils.utils import get_admin_user +from open_webui.utils.auth import get_admin_user router = APIRouter() diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 
0a76626c1..905d2472a 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -429,6 +429,12 @@ OAUTH_ADMIN_ROLES = PersistentConfig( [role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")], ) +OAUTH_ALLOWED_DOMAINS = PersistentConfig( + "OAUTH_ALLOWED_DOMAINS", + "oauth.allowed_domains", + [domain.strip() for domain in os.environ.get("OAUTH_ALLOWED_DOMAINS", "*").split(",")], +) + def load_oauth_providers(): OAUTH_PROVIDERS.clear() @@ -583,6 +589,12 @@ OLLAMA_API_BASE_URL = os.environ.get( ) OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "") +if OLLAMA_BASE_URL: + # Remove trailing slash + OLLAMA_BASE_URL = ( + OLLAMA_BASE_URL[:-1] if OLLAMA_BASE_URL.endswith("/") else OLLAMA_BASE_URL + ) + K8S_FLAG = os.environ.get("K8S_FLAG", "") USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false") @@ -696,6 +708,7 @@ ENABLE_LOGIN_FORM = PersistentConfig( os.environ.get("ENABLE_LOGIN_FORM", "True").lower() == "true", ) + DEFAULT_LOCALE = PersistentConfig( "DEFAULT_LOCALE", "ui.default_locale", @@ -752,7 +765,6 @@ DEFAULT_USER_ROLE = PersistentConfig( os.getenv("DEFAULT_USER_ROLE", "pending"), ) - USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS = ( os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS", "False").lower() == "true" @@ -998,6 +1010,66 @@ Strictly return in JSON format: """ +ENABLE_AUTOCOMPLETE_GENERATION = PersistentConfig( + "ENABLE_AUTOCOMPLETE_GENERATION", + "task.autocomplete.enable", + os.environ.get("ENABLE_AUTOCOMPLETE_GENERATION", "True").lower() == "true", +) + +AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = PersistentConfig( + "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH", + "task.autocomplete.input_max_length", + int(os.environ.get("AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH", "-1")), +) + +AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = PersistentConfig( + "AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE", + "task.autocomplete.prompt_template", + os.environ.get("AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE", ""), +) + + 
+DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = """### Task: +You are an autocompletion system. Continue the text in `` based on the **completion type** in `` and the given language. + +### **Instructions**: +1. Analyze `` for context and meaning. +2. Use `` to guide your output: + - **General**: Provide a natural, concise continuation. + - **Search Query**: Complete as if generating a realistic search query. +3. Start as if you are directly continuing ``. Do **not** repeat, paraphrase, or respond as a model. Simply complete the text. +4. Ensure the continuation: + - Flows naturally from ``. + - Avoids repetition, overexplaining, or unrelated ideas. +5. If unsure, return: `{ "text": "" }`. + +### **Output Rules**: +- Respond only in JSON format: `{ "text": "" }`. + +### **Examples**: +#### Example 1: +Input: +General +The sun was setting over the horizon, painting the sky +Output: +{ "text": "with vibrant shades of orange and pink." } + +#### Example 2: +Input: +Search Query +Top-rated restaurants in +Output: +{ "text": "New York City for Italian cuisine." 
} + +--- +### Context: + +{{MESSAGES:END:6}} + +{{TYPE}} +{{PROMPT}} +#### Output: +""" TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig( "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", @@ -1259,6 +1331,12 @@ YOUTUBE_LOADER_LANGUAGE = PersistentConfig( os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","), ) +YOUTUBE_LOADER_PROXY_URL = PersistentConfig( + "YOUTUBE_LOADER_PROXY_URL", + "rag.youtube_loader_proxy_url", + os.getenv("YOUTUBE_LOADER_PROXY_URL", ""), +) + ENABLE_RAG_WEB_SEARCH = PersistentConfig( "ENABLE_RAG_WEB_SEARCH", @@ -1308,6 +1386,12 @@ BRAVE_SEARCH_API_KEY = PersistentConfig( os.getenv("BRAVE_SEARCH_API_KEY", ""), ) +KAGI_SEARCH_API_KEY = PersistentConfig( + "KAGI_SEARCH_API_KEY", + "rag.web.search.kagi_search_api_key", + os.getenv("KAGI_SEARCH_API_KEY", ""), +) + MOJEEK_SEARCH_API_KEY = PersistentConfig( "MOJEEK_SEARCH_API_KEY", "rag.web.search.mojeek_search_api_key", diff --git a/backend/open_webui/constants.py b/backend/open_webui/constants.py index 9c7d6f9e9..c5fdfabfb 100644 --- a/backend/open_webui/constants.py +++ b/backend/open_webui/constants.py @@ -113,5 +113,6 @@ class TASKS(str, Enum): TAGS_GENERATION = "tags_generation" EMOJI_GENERATION = "emoji_generation" QUERY_GENERATION = "query_generation" + AUTOCOMPLETE_GENERATION = "autocomplete_generation" FUNCTION_CALLING = "function_calling" MOA_RESPONSE_GENERATION = "moa_response_generation" diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 6e591311d..e1b350ead 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -329,6 +329,9 @@ WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( ) WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None) +BYPASS_MODEL_ACCESS_CONTROL = ( + os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true" +) #################################### # WEBUI_SECRET_KEY @@ -373,7 +376,7 @@ else: AIOHTTP_CLIENT_TIMEOUT = 300 AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get( - 
"AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "3" + "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "5" ) if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "": @@ -384,7 +387,7 @@ else: AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST ) except Exception: - AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 3 + AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 5 #################################### # OFFLINE_MODE diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index aa936db47..253a7a165 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -89,6 +89,10 @@ from open_webui.config import ( DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, TITLE_GENERATION_PROMPT_TEMPLATE, TAGS_GENERATION_PROMPT_TEMPLATE, + ENABLE_AUTOCOMPLETE_GENERATION, + AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, + AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, + DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, WEBHOOK_URL, WEBUI_AUTH, @@ -108,6 +112,7 @@ from open_webui.env import ( WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE, WEBUI_URL, + BYPASS_MODEL_ACCESS_CONTROL, RESET_CONFIG_ON_START, OFFLINE_MODE, ) @@ -127,13 +132,14 @@ from open_webui.utils.task import ( rag_template, title_generation_template, query_generation_template, + autocomplete_generation_template, tags_generation_template, emoji_generation_template, moa_response_generation_template, tools_function_calling_generation_template, ) from open_webui.utils.tools import get_tools -from open_webui.utils.utils import ( +from open_webui.utils.auth import ( decode_token, get_admin_user, get_current_user, @@ -207,6 +213,11 @@ app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE +app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION +app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = ( + AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH +) + 
app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE @@ -215,6 +226,10 @@ app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE +app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = ( + AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE +) + app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE ) @@ -465,8 +480,6 @@ async def chat_completion_tools_handler( except Exception as e: tool_output = str(e) - print(tools[tool_function_name]["citation"]) - if isinstance(tool_output, str): if tools[tool_function_name]["citation"]: sources.append( @@ -607,7 +620,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ) model_info = Models.get_model_by_id(model["id"]) - if user.role == "user": + if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: if model.get("arena"): if not has_access( user.id, @@ -1210,7 +1223,7 @@ async def get_models(user=Depends(get_verified_user)): ) # Filter out models that the user does not have access to - if user.role == "user": + if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: filtered_models = [] for model in models: if model.get("arena"): @@ -1252,6 +1265,9 @@ async def get_base_models(user=Depends(get_admin_user)): async def generate_chat_completions( form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False ): + if BYPASS_MODEL_ACCESS_CONTROL: + bypass_filter = True + model_list = await get_all_models() models = {model["id"]: model for model in model_list} @@ -1665,6 +1681,8 @@ async def get_task_config(user=Depends(get_verified_user)): "TASK_MODEL": app.state.config.TASK_MODEL, "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL, "TITLE_GENERATION_PROMPT_TEMPLATE": 
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "ENABLE_AUTOCOMPLETE_GENERATION": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, + "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, "ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION, "ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION, @@ -1678,6 +1696,8 @@ class TaskConfigForm(BaseModel): TASK_MODEL: Optional[str] TASK_MODEL_EXTERNAL: Optional[str] TITLE_GENERATION_PROMPT_TEMPLATE: str + ENABLE_AUTOCOMPLETE_GENERATION: bool + AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int TAGS_GENERATION_PROMPT_TEMPLATE: str ENABLE_TAGS_GENERATION: bool ENABLE_SEARCH_QUERY_GENERATION: bool @@ -1693,6 +1713,14 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = ( form_data.TITLE_GENERATION_PROMPT_TEMPLATE ) + + app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ( + form_data.ENABLE_AUTOCOMPLETE_GENERATION + ) + app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = ( + form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH + ) + app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = ( form_data.TAGS_GENERATION_PROMPT_TEMPLATE ) @@ -1715,6 +1743,8 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u "TASK_MODEL": app.state.config.TASK_MODEL, "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL, "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "ENABLE_AUTOCOMPLETE_GENERATION": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, + "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, "ENABLE_TAGS_GENERATION": app.state.config.ENABLE_TAGS_GENERATION, 
"ENABLE_SEARCH_QUERY_GENERATION": app.state.config.ENABLE_SEARCH_QUERY_GENERATION, @@ -1942,7 +1972,7 @@ async def generate_queries(form_data: dict, user=Depends(get_verified_user)): f"generating {type} queries using model {task_model_id} for user {user.email}" ) - if app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE != "": + if (app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "": template = app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE else: template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE @@ -1982,6 +2012,88 @@ async def generate_queries(form_data: dict, user=Depends(get_verified_user)): return await generate_chat_completions(form_data=payload, user=user) +@app.post("/api/task/auto/completions") +async def generate_autocompletion(form_data: dict, user=Depends(get_verified_user)): + if not app.state.config.ENABLE_AUTOCOMPLETE_GENERATION: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Autocompletion generation is disabled", + ) + + type = form_data.get("type") + prompt = form_data.get("prompt") + messages = form_data.get("messages") + + if app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH > 0: + if len(prompt) > app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Input prompt exceeds maximum length of {app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}", + ) + + model_list = await get_all_models() + models = {model["id"]: model for model in model_list} + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + app.state.config.TASK_MODEL, + app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug( + f"generating autocompletion using model {task_model_id} for 
user {user.email}" + ) + + if (app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "": + template = app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE + else: + template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE + + content = autocomplete_generation_template( + template, prompt, messages, type, {"name": user.name} + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "metadata": { + "task": str(TASKS.AUTOCOMPLETE_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None), + }, + } + + # Handle pipeline filters + try: + payload = filter_pipeline(payload, user, models) + except Exception as e: + if len(e.args) > 1: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + else: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + if "chat_id" in payload: + del payload["chat_id"] + + return await generate_chat_completions(form_data=payload, user=user) + + @app.post("/api/task/emoji/completions") async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): diff --git a/backend/open_webui/static/assets/pdf-style.css b/backend/open_webui/static/assets/pdf-style.css index db9ac83dd..85c36271c 100644 --- a/backend/open_webui/static/assets/pdf-style.css +++ b/backend/open_webui/static/assets/pdf-style.css @@ -26,7 +26,7 @@ html { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'NotoSans', 'NotoSansJP', 'NotoSansKR', - 'NotoSansSC', 'STSong-Light', 'MSung-Light', 'HeiseiMin-W3', 'HYSMyeongJo-Medium', Roboto, + 'NotoSansSC', 'Twemoji', 'STSong-Light', 'MSung-Light', 'HeiseiMin-W3', 'HYSMyeongJo-Medium', Roboto, 'Helvetica Neue', Arial, sans-serif; font-size: 14px; /* Default font size */ line-height: 1.5; @@ -40,7 +40,7 @@ html { body { margin: 0; - color: #212529; + padding: 0; background-color: #fff; width: auto; } diff --git 
a/backend/open_webui/static/fonts/Twemoji.ttf b/backend/open_webui/static/fonts/Twemoji.ttf new file mode 100644 index 000000000..281d356d9 Binary files /dev/null and b/backend/open_webui/static/fonts/Twemoji.ttf differ diff --git a/backend/open_webui/test/apps/webui/routers/test_auths.py b/backend/open_webui/test/apps/webui/routers/test_auths.py index bc14fb8dd..cee68228e 100644 --- a/backend/open_webui/test/apps/webui/routers/test_auths.py +++ b/backend/open_webui/test/apps/webui/routers/test_auths.py @@ -26,7 +26,7 @@ class TestAuths(AbstractPostgresTest): } def test_update_profile(self): - from open_webui.utils.utils import get_password_hash + from open_webui.utils.auth import get_password_hash user = self.auths.insert_new_auth( email="john.doe@openwebui.com", @@ -47,7 +47,7 @@ class TestAuths(AbstractPostgresTest): assert db_user.profile_image_url == "/user2.png" def test_update_password(self): - from open_webui.utils.utils import get_password_hash + from open_webui.utils.auth import get_password_hash user = self.auths.insert_new_auth( email="john.doe@openwebui.com", @@ -74,7 +74,7 @@ class TestAuths(AbstractPostgresTest): assert new_auth is not None def test_signin(self): - from open_webui.utils.utils import get_password_hash + from open_webui.utils.auth import get_password_hash user = self.auths.insert_new_auth( email="john.doe@openwebui.com", diff --git a/backend/open_webui/test/util/mock_user.py b/backend/open_webui/test/util/mock_user.py index 96456a2c8..ba8e24d4e 100644 --- a/backend/open_webui/test/util/mock_user.py +++ b/backend/open_webui/test/util/mock_user.py @@ -13,7 +13,7 @@ def mock_webui_user(**kwargs): @contextmanager def mock_user(app: FastAPI, **kwargs): - from open_webui.utils.utils import ( + from open_webui.utils.auth import ( get_current_user, get_verified_user, get_admin_user, diff --git a/backend/open_webui/utils/utils.py b/backend/open_webui/utils/auth.py similarity index 100% rename from backend/open_webui/utils/utils.py rename to 
backend/open_webui/utils/auth.py diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 722b1ea73..37dc5b788 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -26,6 +26,7 @@ from open_webui.config import ( OAUTH_USERNAME_CLAIM, OAUTH_ALLOWED_ROLES, OAUTH_ADMIN_ROLES, + OAUTH_ALLOWED_DOMAINS, WEBHOOK_URL, JWT_EXPIRES_IN, AppConfig, @@ -33,7 +34,7 @@ from open_webui.config import ( from open_webui.constants import ERROR_MESSAGES from open_webui.env import WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE from open_webui.utils.misc import parse_duration -from open_webui.utils.utils import get_password_hash, create_token +from open_webui.utils.auth import get_password_hash, create_token from open_webui.utils.webhook import post_webhook log = logging.getLogger(__name__) @@ -49,6 +50,7 @@ auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM auth_manager_config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES +auth_manager_config.OAUTH_ALLOWED_DOMAINS = OAUTH_ALLOWED_DOMAINS auth_manager_config.WEBHOOK_URL = WEBHOOK_URL auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN @@ -156,6 +158,9 @@ class OAuthManager: if not email: log.warning(f"OAuth callback failed, email is missing: {user_data}") raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + if "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS and email.split("@")[-1] not in auth_manager_config.OAUTH_ALLOWED_DOMAINS: + log.warning(f"OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}") + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) # Check if the user exists user = Users.get_user_by_oauth_sub(provider_sub) @@ -253,9 +258,18 @@ class OAuthManager: secure=WEBUI_SESSION_COOKIE_SECURE, ) + if ENABLE_OAUTH_SIGNUP.value: + oauth_id_token = 
token.get("id_token") + response.set_cookie( + key="oauth_id_token", + value=oauth_id_token, + httponly=True, + samesite=WEBUI_SESSION_COOKIE_SAME_SITE, + secure=WEBUI_SESSION_COOKIE_SECURE, + ) # Redirect back to the frontend with the JWT token redirect_url = f"{request.base_url}auth#token={jwt_token}" - return RedirectResponse(url=redirect_url) + return RedirectResponse(url=redirect_url, headers=response.headers) oauth_manager = OAuthManager() diff --git a/backend/open_webui/utils/pdf_generator.py b/backend/open_webui/utils/pdf_generator.py index fb6cd57d5..6b7d506e6 100644 --- a/backend/open_webui/utils/pdf_generator.py +++ b/backend/open_webui/utils/pdf_generator.py @@ -51,21 +51,25 @@ class PDFGenerator: # extends pymdownx extension to convert markdown to html. # - https://facelessuser.github.io/pymdown-extensions/usage_notes/ - html_content = markdown(content, extensions=["pymdownx.extra"]) + # html_content = markdown(content, extensions=["pymdownx.extra"]) html_message = f""" -
{date_str}
-
+
-

+

{role.title()} - {model} -

+ {model} + +
{date_str}
-
+                
+
+ +
{content} -
+
+
""" return html_message @@ -74,18 +78,15 @@ class PDFGenerator: return f""" - - + -
-
-

{self.form_data.title}

-
-
- {self.messages_html} -
+
+
+

{self.form_data.title}

+ {self.messages_html}
+
""" @@ -114,9 +115,12 @@ class PDFGenerator: pdf.add_font("NotoSansKR", "", f"{FONTS_DIR}/NotoSansKR-Regular.ttf") pdf.add_font("NotoSansJP", "", f"{FONTS_DIR}/NotoSansJP-Regular.ttf") pdf.add_font("NotoSansSC", "", f"{FONTS_DIR}/NotoSansSC-Regular.ttf") + pdf.add_font("Twemoji", "", f"{FONTS_DIR}/Twemoji.ttf") pdf.set_font("NotoSans", size=12) - pdf.set_fallback_fonts(["NotoSansKR", "NotoSansJP", "NotoSansSC"]) + pdf.set_fallback_fonts( + ["NotoSansKR", "NotoSansJP", "NotoSansSC", "Twemoji"] + ) pdf.set_auto_page_break(auto=True, margin=15) diff --git a/backend/open_webui/utils/security_headers.py b/backend/open_webui/utils/security_headers.py index bcef773a5..fbcf7d697 100644 --- a/backend/open_webui/utils/security_headers.py +++ b/backend/open_webui/utils/security_headers.py @@ -27,6 +27,7 @@ def set_security_headers() -> Dict[str, str]: - x-download-options - x-frame-options - x-permitted-cross-domain-policies + - content-security-policy Each environment variable is associated with a specific setter function that constructs the header. 
If the environment variable is set, the @@ -45,6 +46,7 @@ def set_security_headers() -> Dict[str, str]: "XDOWNLOAD_OPTIONS": set_xdownload_options, "XFRAME_OPTIONS": set_xframe, "XPERMITTED_CROSS_DOMAIN_POLICIES": set_xpermitted_cross_domain_policies, + "CONTENT_SECURITY_POLICY": set_content_security_policy, } for env_var, setter in header_setters.items(): @@ -124,3 +126,8 @@ def set_xpermitted_cross_domain_policies(value: str): if not match: value = "none" return {"X-Permitted-Cross-Domain-Policies": value} + + +# Set Content-Security-Policy response header +def set_content_security_policy(value: str): + return {"Content-Security-Policy": value} diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index 3b71ba746..604161a31 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -53,7 +53,9 @@ def prompt_template( def replace_prompt_variable(template: str, prompt: str) -> str: def replacement_function(match): - full_match = match.group(0) + full_match = match.group( + 0 + ).lower() # Normalize to lowercase for consistent handling start_length = match.group(1) end_length = match.group(2) middle_length = match.group(3) @@ -73,20 +75,23 @@ def replace_prompt_variable(template: str, prompt: str) -> str: return f"{start}...{end}" return "" - template = re.sub( - r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}", - replacement_function, - template, - ) + # Updated regex pattern to make it case-insensitive with the `(?i)` flag + pattern = r"(?i){{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}" + template = re.sub(pattern, replacement_function, template) return template -def replace_messages_variable(template: str, messages: list[str]) -> str: +def replace_messages_variable( + template: str, messages: Optional[list[str]] = None +) -> str: def replacement_function(match): full_match = match.group(0) start_length = match.group(1) 
end_length = match.group(2) middle_length = match.group(3) + # If messages is None, handle it as an empty list + if messages is None: + return "" # Process messages based on the number of messages required if full_match == "{{MESSAGES}}": @@ -122,7 +127,7 @@ def replace_messages_variable(template: str, messages: list[str]) -> str: def rag_template(template: str, context: str, query: str): - if template == "": + if template.strip() == "": template = DEFAULT_RAG_TEMPLATE if "[context]" not in template and "{{CONTEXT}}" not in template: @@ -212,6 +217,28 @@ def emoji_generation_template( return template +def autocomplete_generation_template( + template: str, + prompt: str, + messages: Optional[list[dict]] = None, + type: Optional[str] = None, + user: Optional[dict] = None, +) -> str: + template = template.replace("{{TYPE}}", type if type else "") + template = replace_prompt_variable(template, prompt) + template = replace_messages_variable(template, messages) + + template = prompt_template( + template, + **( + {"user_name": user.get("name"), "user_location": user.get("location")} + if user + else {} + ), + ) + return template + + def query_generation_template( template: str, messages: list[dict], user: Optional[dict] = None ) -> str: diff --git a/backend/requirements.txt b/backend/requirements.txt index c83e6b3b7..79e898c6a 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,7 +1,7 @@ fastapi==0.111.0 uvicorn[standard]==0.30.6 pydantic==2.9.2 -python-multipart==0.0.17 +python-multipart==0.0.18 Flask==3.0.3 Flask-Cors==5.0.0 @@ -11,13 +11,13 @@ python-jose==3.3.0 passlib[bcrypt]==1.7.4 requests==2.32.3 -aiohttp==3.10.8 +aiohttp==3.11.8 async-timeout aiocache aiofiles sqlalchemy==2.0.32 -alembic==1.13.2 +alembic==1.14.0 peewee==3.17.6 peewee-migrate==1.12.2 psycopg2-binary==2.9.9 @@ -44,11 +44,11 @@ langchain-chroma==0.1.4 fake-useragent==1.5.1 chromadb==0.5.15 -pymilvus==2.4.9 +pymilvus==2.5.0 qdrant-client~=1.12.0 opensearch-py==2.7.1 
-sentence-transformers==3.2.0 +sentence-transformers==3.3.1 colbert-ai==0.2.21 einops==0.8.0 diff --git a/package-lock.json b/package-lock.json index eaa39b6db..020cd0f53 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "open-webui", - "version": "0.4.6", + "version": "0.4.8", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.4.6", + "version": "0.4.8", "dependencies": { "@codemirror/lang-javascript": "^6.2.2", "@codemirror/lang-python": "^6.1.6", @@ -1836,9 +1836,10 @@ } }, "node_modules/@polka/url": { - "version": "1.0.0-next.25", - "resolved": "https://registry.npmjs.org/@polka/url/-/url-1.0.0-next.25.tgz", - "integrity": "sha512-j7P6Rgr3mmtdkeDGTe0E/aYyWEWVtc5yFXtHCRHs28/jptDEWfaVOc5T7cblqy1XKPPfCxJc/8DwQ5YgLOZOVQ==" + "version": "1.0.0-next.28", + "resolved": "https://registry.npmjs.org/@polka/url/-/url-1.0.0-next.28.tgz", + "integrity": "sha512-8LduaNlMZGwdZ6qWrKlfa+2M4gahzFkprZiAt2TF8uS0qQgBizKXpXURqvTJ4WtmupWxaLqjRb2UCTe72mu+Aw==", + "license": "MIT" }, "node_modules/@popperjs/core": { "version": "2.11.8", @@ -2248,31 +2249,33 @@ } }, "node_modules/@sveltejs/adapter-static": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/@sveltejs/adapter-static/-/adapter-static-3.0.2.tgz", - "integrity": "sha512-/EBFydZDwfwFfFEuF1vzUseBoRziwKP7AoHAwv+Ot3M084sE/HTVBHf9mCmXfdM9ijprY5YEugZjleflncX5fQ==", + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/@sveltejs/adapter-static/-/adapter-static-3.0.6.tgz", + "integrity": "sha512-MGJcesnJWj7FxDcB/GbrdYD3q24Uk0PIL4QIX149ku+hlJuj//nxUbb0HxUTpjkecWfHjVveSUnUaQWnPRXlpg==", "dev": true, + "license": "MIT", "peerDependencies": { "@sveltejs/kit": "^2.0.0" } }, "node_modules/@sveltejs/kit": { - "version": "2.6.2", - "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.6.2.tgz", - "integrity": "sha512-ruogrSPXjckn5poUiZU8VYNCSPHq66SFR1AATvOikQxtP6LNI4niAZVX/AWZRe/EPDG3oY2DNJ9c5z7u0t2NAQ==", + "version": "2.9.0", + 
"resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.9.0.tgz", + "integrity": "sha512-W3E7ed3ChB6kPqRs2H7tcHp+Z7oiTFC6m+lLyAQQuyXeqw6LdNuuwEUla+5VM0OGgqQD+cYD6+7Xq80vVm17Vg==", "hasInstallScript": true, + "license": "MIT", "dependencies": { "@types/cookie": "^0.6.0", - "cookie": "^0.7.0", + "cookie": "^0.6.0", "devalue": "^5.1.0", - "esm-env": "^1.0.0", + "esm-env": "^1.2.1", "import-meta-resolve": "^4.1.0", "kleur": "^4.1.5", "magic-string": "^0.30.5", "mrmime": "^2.0.0", "sade": "^1.8.1", "set-cookie-parser": "^2.6.0", - "sirv": "^2.0.4", + "sirv": "^3.0.0", "tiny-glob": "^0.2.9" }, "bin": { @@ -2282,9 +2285,9 @@ "node": ">=18.13" }, "peerDependencies": { - "@sveltejs/vite-plugin-svelte": "^3.0.0 || ^4.0.0-next.1", + "@sveltejs/vite-plugin-svelte": "^3.0.0 || ^4.0.0-next.1 || ^5.0.0", "svelte": "^4.0.0 || ^5.0.0-next.0", - "vite": "^5.0.3" + "vite": "^5.0.3 || ^6.0.0" } }, "node_modules/@sveltejs/vite-plugin-svelte": { @@ -4391,9 +4394,10 @@ "dev": true }, "node_modules/cookie": { - "version": "0.7.1", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.1.tgz", - "integrity": "sha512-6DnInpx7SJ2AK3+CTUE/ZM0vWTUboZCegxhC2xiIydHR9jNuTAASBrfEpHhiGOZw/nX51bHt6YQl8jsGo4y/0w==", + "version": "0.6.0", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.6.0.tgz", + "integrity": "sha512-U71cyTamuh1CRNCfpGY6to28lxvNwPG4Guz/EVjgf3Jmzv0vlDp1atT9eS5dDjMYHucpHbWns6Lwf3BKz6svdw==", + "license": "MIT", "engines": { "node": ">= 0.6" } @@ -5690,9 +5694,10 @@ } }, "node_modules/esm-env": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/esm-env/-/esm-env-1.0.0.tgz", - "integrity": "sha512-Cf6VksWPsTuW01vU9Mk/3vRue91Zevka5SjyNf3nEpokFRuqt/KjUQoGAwq9qMmhpLTHmXzSIrFRw8zxWzmFBA==" + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/esm-env/-/esm-env-1.2.1.tgz", + "integrity": "sha512-U9JedYYjCnadUlXk7e1Kr+aENQhtUaoaV9+gZm1T8LC/YBAPJx3NSPIAurFOC0U5vrdSevnUJS2/wUVxGwPhng==", + "license": "MIT" }, "node_modules/espree": { "version": 
"9.6.1", @@ -8228,6 +8233,7 @@ "version": "2.0.0", "resolved": "https://registry.npmjs.org/mrmime/-/mrmime-2.0.0.tgz", "integrity": "sha512-eu38+hdgojoyq63s+yTpN4XMBdt5l8HhMhc4VKLO9KM5caLIBvUm4thi7fFaxyTmCKeNnXZ5pAlBwCUnhA09uw==", + "license": "MIT", "engines": { "node": ">=10" } @@ -10359,16 +10365,17 @@ } }, "node_modules/sirv": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/sirv/-/sirv-2.0.4.tgz", - "integrity": "sha512-94Bdh3cC2PKrbgSOUqTiGPWVZeSiXfKOVZNJniWoqrWrRkB1CJzBU3NEbiTsPcYy1lDsANA/THzS+9WBiy5nfQ==", + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/sirv/-/sirv-3.0.0.tgz", + "integrity": "sha512-BPwJGUeDaDCHihkORDchNyyTvWFhcusy1XMmhEVTQTwGeybFbp8YEmB+njbPnth1FibULBSBVwCQni25XlCUDg==", + "license": "MIT", "dependencies": { "@polka/url": "^1.0.0-next.24", "mrmime": "^2.0.0", "totalist": "^3.0.0" }, "engines": { - "node": ">= 10" + "node": ">=18" } }, "node_modules/slash": { @@ -11260,6 +11267,7 @@ "version": "3.0.1", "resolved": "https://registry.npmjs.org/totalist/-/totalist-3.0.1.tgz", "integrity": "sha512-sf4i37nQ2LBx4m3wB74y+ubopq6W/dIzXg0FDGjsYnZHVa1Da8FH853wlL2gtUhg+xJXjfk3kUZS3BRoQeoQBQ==", + "license": "MIT", "engines": { "node": ">=6" } diff --git a/package.json b/package.json index 80e7b4fc3..c131e1f91 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.4.6", + "version": "0.4.8", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", diff --git a/pyproject.toml b/pyproject.toml index 0dc8e856d..0554baa9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ dependencies = [ "fastapi==0.111.0", "uvicorn[standard]==0.30.6", "pydantic==2.9.2", - "python-multipart==0.0.17", + "python-multipart==0.0.18", "Flask==3.0.3", "Flask-Cors==5.0.0", @@ -19,13 +19,13 @@ dependencies = [ "passlib[bcrypt]==1.7.4", "requests==2.32.3", - "aiohttp==3.10.8", + "aiohttp==3.11.8", "async-timeout", "aiocache", "aiofiles", "sqlalchemy==2.0.32", - 
"alembic==1.13.2", + "alembic==1.14.0", "peewee==3.17.6", "peewee-migrate==1.12.2", "psycopg2-binary==2.9.9", @@ -51,11 +51,11 @@ dependencies = [ "fake-useragent==1.5.1", "chromadb==0.5.15", - "pymilvus==2.4.9", + "pymilvus==2.5.0", "qdrant-client~=1.12.0", "opensearch-py==2.7.1", - "sentence-transformers==3.2.0", + "sentence-transformers==3.3.1", "colbert-ai==0.2.21", "einops==0.8.0", diff --git a/src/app.css b/src/app.css index 659498add..cf0afea4f 100644 --- a/src/app.css +++ b/src/app.css @@ -45,15 +45,15 @@ math { } .input-prose { - @apply prose dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; + @apply prose dark:prose-invert prose-headings:font-semibold prose-hr:my-4 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; } .input-prose-sm { - @apply prose dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line text-sm; + @apply prose dark:prose-invert prose-headings:font-semibold prose-hr:my-4 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line text-sm; } .markdown-prose { - @apply prose dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; + @apply prose dark:prose-invert prose-headings:font-semibold prose-hr:my-4 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 
prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; } .markdown a { @@ -211,7 +211,15 @@ input[type='number'] { float: left; color: #adb5bd; pointer-events: none; - height: 0; + + @apply line-clamp-1 absolute; +} + +.ai-autocompletion::after { + color: #a0a0a0; + + content: attr(data-suggestion); + pointer-events: none; } .tiptap > pre > code { diff --git a/src/app.html b/src/app.html index f6e46c9cf..537e28dbe 100644 --- a/src/app.html +++ b/src/app.html @@ -2,9 +2,12 @@ - - - + + + + + + { + const controller = new AbortController(); + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/task/auto/completions`, { + signal: controller.signal, + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + model: model, + prompt: prompt, + ...(messages && { messages: messages }), + type: type, + stream: false + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } + return null; + }); + + if (error) { + throw error; + } + + const response = res?.choices[0]?.message?.content ?? 
''; + + try { + const jsonStartIndex = response.indexOf('{'); + const jsonEndIndex = response.lastIndexOf('}'); + + if (jsonStartIndex !== -1 && jsonEndIndex !== -1) { + const jsonResponse = response.substring(jsonStartIndex, jsonEndIndex + 1); + + // Step 5: Parse the JSON block + const parsed = JSON.parse(jsonResponse); + + // Step 6: If there's a "queries" key, return the queries array; otherwise, return an empty array + if (parsed && parsed.text) { + return parsed.text; + } else { + return ''; + } + } + + // If no valid JSON block found, return response as is + return response; + } catch (e) { + // Catch and safely return empty array on any parsing errors + console.error('Failed to parse response: ', e); + return response; + } +}; + export const generateMoACompletion = async ( token: string = '', model: string, diff --git a/src/lib/apis/retrieval/index.ts b/src/lib/apis/retrieval/index.ts index 6c6b18b9f..21ae792fa 100644 --- a/src/lib/apis/retrieval/index.ts +++ b/src/lib/apis/retrieval/index.ts @@ -40,6 +40,7 @@ type ContentExtractConfigForm = { type YoutubeConfigForm = { language: string[]; translation?: string | null; + proxy_url: string; }; type RAGConfigForm = { diff --git a/src/lib/components/admin/Settings/Images.svelte b/src/lib/components/admin/Settings/Images.svelte index c76e192bf..b0492f24b 100644 --- a/src/lib/components/admin/Settings/Images.svelte +++ b/src/lib/components/admin/Settings/Images.svelte @@ -105,10 +105,15 @@ }; const updateConfigHandler = async () => { - const res = await updateConfig(localStorage.token, config).catch((error) => { - toast.error(error); - return null; - }); + const res = await updateConfig(localStorage.token, config) + .catch((error) => { + toast.error(error); + return null; + }) + .catch((error) => { + toast.error(error); + return null; + }); if (res) { config = res; diff --git a/src/lib/components/admin/Settings/Interface.svelte b/src/lib/components/admin/Settings/Interface.svelte index 2fee518ee..9c669dae5 100644 
--- a/src/lib/components/admin/Settings/Interface.svelte +++ b/src/lib/components/admin/Settings/Interface.svelte @@ -24,6 +24,8 @@ TASK_MODEL: '', TASK_MODEL_EXTERNAL: '', TITLE_GENERATION_PROMPT_TEMPLATE: '', + ENABLE_AUTOCOMPLETE_GENERATION: true, + AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: -1, TAGS_GENERATION_PROMPT_TEMPLATE: '', ENABLE_TAGS_GENERATION: true, ENABLE_SEARCH_QUERY_GENERATION: true, @@ -138,11 +140,42 @@
-
+
- {$i18n.t('Enable Tags Generation')} + {$i18n.t('Autocomplete Generation')} +
+ + + + +
+ + {#if taskConfig.ENABLE_AUTOCOMPLETE_GENERATION} +
+
+ {$i18n.t('Autocomplete Generation Input Max Length')} +
+ + + + +
+ {/if} + +
+ +
+
+ {$i18n.t('Tags Generation')}
@@ -166,11 +199,11 @@
{/if} -
+
- {$i18n.t('Enable Retrieval Query Generation')} + {$i18n.t('Retrieval Query Generation')}
@@ -178,7 +211,7 @@
- {$i18n.t('Enable Web Search Query Generation')} + {$i18n.t('Web Search Query Generation')}
@@ -201,7 +234,7 @@
-
+
diff --git a/src/lib/components/admin/Settings/Models.svelte b/src/lib/components/admin/Settings/Models.svelte index 06058fa9f..f084de65a 100644 --- a/src/lib/components/admin/Settings/Models.svelte +++ b/src/lib/components/admin/Settings/Models.svelte @@ -44,10 +44,10 @@ filteredModels = models .filter((m) => searchValue === '' || m.name.toLowerCase().includes(searchValue.toLowerCase())) .sort((a, b) => { - // Check if either model is inactive and push them to the bottom - if ((a.is_active ?? true) !== (b.is_active ?? true)) { - return (b.is_active ?? true) - (a.is_active ?? true); - } + // // Check if either model is inactive and push them to the bottom + // if ((a.is_active ?? true) !== (b.is_active ?? true)) { + // return (b.is_active ?? true) - (a.is_active ?? true); + // } // If both models' active states are the same, sort alphabetically return a.name.localeCompare(b.name); }); @@ -137,7 +137,7 @@ }); - + {#if models !== null} {#if selectedModelId === null} diff --git a/src/lib/components/admin/Settings/Models/ConfigureModelsModal.svelte b/src/lib/components/admin/Settings/Models/ConfigureModelsModal.svelte index 761ffefba..23865c184 100644 --- a/src/lib/components/admin/Settings/Models/ConfigureModelsModal.svelte +++ b/src/lib/components/admin/Settings/Models/ConfigureModelsModal.svelte @@ -18,7 +18,7 @@ import Plus from '$lib/components/icons/Plus.svelte'; export let show = false; - export let init = () => {}; + export let initHandler = () => {}; let config = null; @@ -29,26 +29,11 @@ let loading = false; let showResetModal = false; - const submitHandler = async () => { - loading = true; + $: if (show) { + init(); + } - const res = await setModelsConfig(localStorage.token, { - DEFAULT_MODELS: defaultModelIds.join(','), - MODEL_ORDER_LIST: modelIds - }); - - if (res) { - toast.success($i18n.t('Models configuration saved successfully')); - init(); - show = false; - } else { - toast.error($i18n.t('Failed to save models configuration')); - } - - loading = 
false; - }; - - onMount(async () => { + const init = async () => { config = await getModelsConfig(localStorage.token); if (config?.DEFAULT_MODELS) { @@ -68,6 +53,28 @@ // Add remaining IDs not in MODEL_ORDER_LIST, sorted alphabetically ...allModelIds.filter((id) => !orderedSet.has(id)).sort((a, b) => a.localeCompare(b)) ]; + }; + const submitHandler = async () => { + loading = true; + + const res = await setModelsConfig(localStorage.token, { + DEFAULT_MODELS: defaultModelIds.join(','), + MODEL_ORDER_LIST: modelIds + }); + + if (res) { + toast.success($i18n.t('Models configuration saved successfully')); + initHandler(); + show = false; + } else { + toast.error($i18n.t('Failed to save models configuration')); + } + + loading = false; + }; + + onMount(async () => { + init(); }); @@ -79,7 +86,7 @@ const res = deleteAllModels(localStorage.token); if (res) { toast.success($i18n.t('All models deleted successfully')); - init(); + initHandler(); } }} /> @@ -213,6 +220,7 @@ showResetModal = true; }} > + {$i18n.t('Reset All Models')} diff --git a/src/lib/components/admin/Settings/WebSearch.svelte b/src/lib/components/admin/Settings/WebSearch.svelte index d8b1a33d1..58eb09da3 100644 --- a/src/lib/components/admin/Settings/WebSearch.svelte +++ b/src/lib/components/admin/Settings/WebSearch.svelte @@ -16,6 +16,7 @@ 'searxng', 'google_pse', 'brave', + 'kagi', 'mojeek', 'serpstack', 'serper', @@ -29,13 +30,15 @@ let youtubeLanguage = 'en'; let youtubeTranslation = null; + let youtubeProxyUrl = ''; const submitHandler = async () => { const res = await updateRAGConfig(localStorage.token, { web: webConfig, youtube: { language: youtubeLanguage.split(',').map((lang) => lang.trim()), - translation: youtubeTranslation + translation: youtubeTranslation, + proxy_url: youtubeProxyUrl } }); }; @@ -48,6 +51,7 @@ youtubeLanguage = res.youtube.language.join(','); youtubeTranslation = res.youtube.translation; + youtubeProxyUrl = res.youtube.proxy_url; } }); @@ -152,6 +156,17 @@ 
bind:value={webConfig.search.brave_search_api_key} />
+ {:else if webConfig.search.engine === 'kagi'} +
+
+ {$i18n.t('Kagi Search API Key')} +
+ + +
{:else if webConfig.search.engine === 'mojeek'}
@@ -358,6 +373,21 @@
+ +
+
+
{$i18n.t('Proxy URL')}
+
+ +
+
+
{/if} diff --git a/src/lib/components/admin/Users/UserList/UserChatsModal.svelte b/src/lib/components/admin/Users/UserList/UserChatsModal.svelte index 4c8447829..7cf03b4b7 100644 --- a/src/lib/components/admin/Users/UserList/UserChatsModal.svelte +++ b/src/lib/components/admin/Users/UserList/UserChatsModal.svelte @@ -9,13 +9,14 @@ import Modal from '$lib/components/common/Modal.svelte'; import Tooltip from '$lib/components/common/Tooltip.svelte'; + import Spinner from '$lib/components/common/Spinner.svelte'; const i18n = getContext('i18n'); export let show = false; export let user; - let chats = []; + let chats = null; const deleteChatHandler = async (chatId) => { const res = await deleteChatById(localStorage.token, chatId).catch((error) => { @@ -31,6 +32,8 @@ chats = await getChatListByUserId(localStorage.token, user.id); } })(); + } else { + chats = null; } let sortKey = 'updated_at'; // default sort key @@ -46,33 +49,32 @@ -
-
-
- {$i18n.t("{{user}}'s Chats", { user: user.name })} -
- +
+
+ {$i18n.t("{{user}}'s Chats", { user: user.name })}
-
+ +
-
-
+
+
+ {#if chats} {#if chats.length > 0}
@@ -176,7 +178,9 @@ {$i18n.t('has no conversations.')}
{/if} -
+ {:else} + + {/if}
diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 03080d9b5..e6a653420 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -2284,7 +2284,7 @@
-
+
- {$i18n.t('LLMs can make mistakes. Verify important information.')} +
{:else} diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 16e3cdb91..800059055 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -18,7 +18,7 @@ showControls } from '$lib/stores'; - import { blobToFile, findWordIndices } from '$lib/utils'; + import { blobToFile, createMessagesList, findWordIndices } from '$lib/utils'; import { transcribeAudio } from '$lib/apis/audio'; import { uploadFile } from '$lib/apis/files'; import { getTools } from '$lib/apis/tools'; @@ -34,6 +34,8 @@ import Commands from './MessageInput/Commands.svelte'; import XMark from '../icons/XMark.svelte'; import RichTextInput from '../common/RichTextInput.svelte'; + import { generateAutoCompletion } from '$lib/apis'; + import { error, text } from '@sveltejs/kit'; const i18n = getContext('i18n'); @@ -47,6 +49,9 @@ export let atSelectedModel: Model | undefined; export let selectedModels: ['']; + let selectedModelIds = []; + $: selectedModelIds = atSelectedModel !== undefined ? [atSelectedModel.id] : selectedModels; + export let history; export let prompt = ''; @@ -266,8 +271,8 @@ {#if loaded}
-
-
+
+
{#if autoScroll === false && history?.currentId}
{#if atSelectedModel !== undefined || selectedToolIds.length > 0 || webSearchEnabled}
{#if selectedToolIds.length > 0}
@@ -405,7 +410,7 @@
-
+
{#if files.length > 0} @@ -542,7 +547,7 @@ {/if}
-
+
@@ -577,10 +582,11 @@ {#if $settings?.richTextInput ?? true}
{ + if (selectedModelIds.length === 0 || !selectedModelIds.at(0)) { + toast.error($i18n.t('Please select a model first.')); + } + + const res = await generateAutoCompletion( + localStorage.token, + selectedModelIds.at(0), + text, + history?.currentId + ? createMessagesList(history, history.currentId) + : null + ).catch((error) => { + console.log(error); + + return null; + }); + + console.log(res); + return res; + }} on:keydown={async (e) => { e = e.detail.event; @@ -754,7 +781,7 @@