diff --git a/CHANGELOG.md b/CHANGELOG.md index d19e82c39..bfff72eed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,81 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.3.4] - 2024-06-12 + +### Fixed + +- **🔒 Mixed Content with HTTPS Issue**: Resolved a problem where mixed content (HTTP and HTTPS) was causing security warnings and blocking resources on HTTPS sites. +- **🔍 Web Search Issue**: Addressed the problem where web search functionality was not working correctly. The 'ENABLE_RAG_LOCAL_WEB_FETCH' option has been reintroduced to restore proper web searching capabilities. +- **💾 RAG Template Not Being Saved**: Fixed an issue where the RAG template was not being saved correctly, ensuring your custom templates are now preserved as expected. + +## [0.3.3] - 2024-06-12 + +### Added + +- **🛠️ Native Python Function Calling**: Introducing native Python function calling within Open WebUI. We’ve also included a built-in code editor to seamlessly develop and integrate function code within the 'Tools' workspace. With this, you can significantly enhance your LLM’s capabilities by creating custom RAG pipelines, web search tools, and even agent-like features such as sending Discord messages. +- **🌐 DuckDuckGo Integration**: Added DuckDuckGo as a web search provider, giving you more search options. +- **🌏 Enhanced Translations**: Improved translations for Vietnamese and Chinese languages, making the interface more accessible. + +### Fixed + +- **🔗 Web Search URL Error Handling**: Fixed the issue where a single URL error would disrupt the data loading process in Web Search mode. Now, such errors will be handled gracefully to ensure uninterrupted data loading. 
+- **🖥️ Frontend Responsiveness**: Resolved the problem where the frontend would stop responding if the backend encounters an error while downloading a model. Improved error handling to maintain frontend stability. +- **🔧 Dependency Issues in pip**: Fixed issues related to pip installations, ensuring all dependencies are correctly managed to prevent installation errors. + +## [0.3.2] - 2024-06-10 + +### Added + +- **🔍 Web Search Query Status**: The web search query will now persist in the results section to aid in easier debugging and tracking of search queries. +- **🌐 New Web Search Provider**: We have added Serply as a new option for web search providers, giving you more choices for your search needs. +- **🌏 Improved Translations**: We've enhanced translations for Chinese and Portuguese. + +### Fixed + +- **🎤 Audio File Upload Issue**: The bug that prevented audio files from being uploaded in chat input has been fixed, ensuring smooth communication. +- **💬 Message Input Handling**: Improved the handling of message inputs by instantly clearing images and text after sending, along with immediate visual indications when a response message is loading, enhancing user feedback. +- **⚙️ Parameter Registration and Validation**: Fixed the issue where parameters were not registering in certain cases and addressed the problem where users were unable to save due to invalid input errors. + +## [0.3.1] - 2024-06-09 + +### Fixed + +- **💬 Chat Functionality**: Resolved the issue where chat functionality was not working for specific models. + +## [0.3.0] - 2024-06-09 + +### Added + +- **📚 Knowledge Support for Models**: Attach documents directly to models from the models workspace, enhancing the information available to each model. +- **🎙️ Hands-Free Voice Call Feature**: Initiate voice calls without needing to use your hands, making interactions more seamless. 
+- **📹 Video Call Feature**: Enable video calls with supported vision models like Llava and GPT-4o, adding a visual dimension to your communications. +- **🎛️ Enhanced UI for Voice Recording**: Improved user interface for the voice recording feature, making it more intuitive and user-friendly. +- **🌐 External STT Support**: Added support for external Speech-To-Text services, providing more flexibility in choosing your STT provider. +- **⚙️ Unified Settings**: Consolidated settings including document settings under a new admin settings section for easier management. +- **🌑 Dark Mode Splash Screen**: A new splash screen for dark mode, ensuring a consistent and visually appealing experience for dark mode users. +- **📥 Upload Pipeline**: Directly upload pipelines from the admin settings > pipelines section, streamlining the pipeline management process. +- **🌍 Improved Language Support**: Enhanced support for Chinese and Ukrainian languages, better catering to a global user base. + +### Fixed + +- **🛠️ Playground Issue**: Fixed the playground not functioning properly, ensuring a smoother user experience. +- **🔥 Temperature Parameter Issue**: Corrected the issue where the temperature value '0' was not being passed correctly. +- **📝 Prompt Input Clearing**: Resolved prompt input textarea not being cleared right away, ensuring a clean slate for new inputs. +- **✨ Various UI Styling Issues**: Fixed numerous user interface styling problems for a more cohesive look. +- **👥 Active Users Display**: Fixed active users showing active sessions instead of actual users, now reflecting accurate user activity. +- **🌐 Community Platform Compatibility**: The Community Platform is back online and fully compatible with Open WebUI. + +### Changed + +- **📝 RAG Implementation**: Updated the RAG (Retrieval-Augmented Generation) implementation to use a system prompt for context, instead of overriding the user's prompt.
+- **🔄 Settings Relocation**: Moved Models, Connections, Audio, and Images settings to the admin settings for better organization. +- **✍️ Improved Title Generation**: Enhanced the default prompt for title generation, yielding better results. +- **🔧 Backend Task Management**: Tasks like title generation and search query generation are now managed on the backend side and controlled only by the admin. +- **🔍 Editable Search Query Prompt**: You can now edit the search query generation prompt, offering more control over how queries are generated. +- **📏 Prompt Length Threshold**: Set the prompt length threshold for search query generation from the admin settings, giving more customization options. +- **📣 Settings Consolidation**: Merged the Banners admin setting with the Interface admin setting for a more streamlined settings area. + ## [0.2.5] - 2024-06-05 ### Added diff --git a/README.md b/README.md index a8d79bd5c..5f6e4550b 100644 --- a/README.md +++ b/README.md @@ -29,11 +29,15 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature- - ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction. +- 🎤📹 **Hands-Free Voice/Video Call**: Experience seamless communication with integrated hands-free voice and video call features, allowing for a more dynamic and interactive chat environment. + - 🛠️ **Model Builder**: Easily create Ollama models via the Web UI. Create and add custom characters/agents, customize chat elements, and import models effortlessly through [Open WebUI Community](https://openwebui.com/) integration. +- 🐍 **Native Python Function Calling Tool**: Enhance your LLMs with built-in code editor support in the tools workspace. Bring Your Own Function (BYOF) by simply adding your pure Python functions, enabling seamless integration with LLMs. 
+ - 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query. -- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, and `serper`, and inject the results directly into your chat experience. +- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, and `TavilySearch`, and inject the results directly into your chat experience. - 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions. @@ -146,10 +150,19 @@ docker run --rm --volume /var/run/docker.sock:/var/run/docker.sock containrrr/wa In the last part of the command, replace `open-webui` with your container name if it is different. -### Moving from Ollama WebUI to Open WebUI - Check our Migration Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/migration/). +### Using the Dev Branch 🌙 + +> [!WARNING] +> The `:dev` branch contains the latest unstable features and changes. Use it at your own risk as it may have bugs or incomplete features. + +If you want to try out the latest bleeding-edge features and are okay with occasional instability, you can use the `:dev` tag like this: + +```bash +docker run -d -p 3000:8080 -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:dev +``` + ## What's Next?
🌟 Discover upcoming features on our roadmap in the [Open WebUI Documentation](https://docs.openwebui.com/roadmap/). diff --git a/TROUBLESHOOTING.md b/TROUBLESHOOTING.md index 8e8f89da0..9bf242381 100644 --- a/TROUBLESHOOTING.md +++ b/TROUBLESHOOTING.md @@ -18,6 +18,10 @@ If you're experiencing connection issues, it’s often due to the WebUI docker c docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main ``` +### Error on Slow Responses for Ollama + +Open WebUI has a default timeout of 5 minutes for Ollama to finish generating the response. If needed, this can be adjusted via the environment variable AIOHTTP_CLIENT_TIMEOUT, which sets the timeout in seconds. + ### General Connection Errors **Ensure Ollama Version is Up-to-Date**: Always start by checking that you have the latest version of Ollama. Visit [Ollama's official site](https://ollama.com/) for the latest updates. diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py index 0f65a551e..663e20c97 100644 --- a/backend/apps/audio/main.py +++ b/backend/apps/audio/main.py @@ -17,13 +17,12 @@ from fastapi.middleware.cors import CORSMiddleware from faster_whisper import WhisperModel from pydantic import BaseModel - +import uuid import requests import hashlib from pathlib import Path import json - from constants import ERROR_MESSAGES from utils.utils import ( decode_token, @@ -41,10 +40,15 @@ from config import ( WHISPER_MODEL_DIR, WHISPER_MODEL_AUTO_UPDATE, DEVICE_TYPE, - AUDIO_OPENAI_API_BASE_URL, - AUDIO_OPENAI_API_KEY, - AUDIO_OPENAI_API_MODEL, - AUDIO_OPENAI_API_VOICE, + AUDIO_STT_OPENAI_API_BASE_URL, + AUDIO_STT_OPENAI_API_KEY, + AUDIO_TTS_OPENAI_API_BASE_URL, + AUDIO_TTS_OPENAI_API_KEY, + AUDIO_STT_ENGINE, + AUDIO_STT_MODEL, + AUDIO_TTS_ENGINE, + AUDIO_TTS_MODEL, + AUDIO_TTS_VOICE, AppConfig, ) @@ -61,10 +65,17 @@ app.add_middleware( ) app.state.config = AppConfig()
-app.state.config.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL -app.state.config.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY -app.state.config.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL -app.state.config.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE + +app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL +app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY +app.state.config.STT_ENGINE = AUDIO_STT_ENGINE +app.state.config.STT_MODEL = AUDIO_STT_MODEL + +app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL +app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY +app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE +app.state.config.TTS_MODEL = AUDIO_TTS_MODEL +app.state.config.TTS_VOICE = AUDIO_TTS_VOICE # setting device type for whisper model whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu" @@ -74,41 +85,101 @@ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/") SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) -class OpenAIConfigUpdateForm(BaseModel): - url: str - key: str - model: str - speaker: str +class TTSConfigForm(BaseModel): + OPENAI_API_BASE_URL: str + OPENAI_API_KEY: str + ENGINE: str + MODEL: str + VOICE: str + + +class STTConfigForm(BaseModel): + OPENAI_API_BASE_URL: str + OPENAI_API_KEY: str + ENGINE: str + MODEL: str + + +class AudioConfigUpdateForm(BaseModel): + tts: TTSConfigForm + stt: STTConfigForm + + +from pydub import AudioSegment +from pydub.utils import mediainfo + + +def is_mp4_audio(file_path): + """Check if the given file is an MP4 audio file.""" + if not os.path.isfile(file_path): + print(f"File not found: {file_path}") + return False + + info = mediainfo(file_path) + if ( + info.get("codec_name") == "aac" + and info.get("codec_type") == "audio" + and info.get("codec_tag_string") == "mp4a" + ): + return True + return False + + +def convert_mp4_to_wav(file_path, output_path): + """Convert MP4 audio file to WAV format.""" + audio = 
AudioSegment.from_file(file_path, format="mp4") + audio.export(output_path, format="wav") + print(f"Converted {file_path} to {output_path}") @app.get("/config") -async def get_openai_config(user=Depends(get_admin_user)): +async def get_audio_config(user=Depends(get_admin_user)): return { - "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, - "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL, - "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE, + "tts": { + "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY, + "ENGINE": app.state.config.TTS_ENGINE, + "MODEL": app.state.config.TTS_MODEL, + "VOICE": app.state.config.TTS_VOICE, + }, + "stt": { + "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY, + "ENGINE": app.state.config.STT_ENGINE, + "MODEL": app.state.config.STT_MODEL, + }, } @app.post("/config/update") -async def update_openai_config( - form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user) +async def update_audio_config( + form_data: AudioConfigUpdateForm, user=Depends(get_admin_user) ): - if form_data.key == "": - raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) + app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL + app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY + app.state.config.TTS_ENGINE = form_data.tts.ENGINE + app.state.config.TTS_MODEL = form_data.tts.MODEL + app.state.config.TTS_VOICE = form_data.tts.VOICE - app.state.config.OPENAI_API_BASE_URL = form_data.url - app.state.config.OPENAI_API_KEY = form_data.key - app.state.config.OPENAI_API_MODEL = form_data.model - app.state.config.OPENAI_API_VOICE = form_data.speaker + app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL + app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY 
+ app.state.config.STT_ENGINE = form_data.stt.ENGINE + app.state.config.STT_MODEL = form_data.stt.MODEL return { - "status": True, - "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, - "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, - "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL, - "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE, + "tts": { + "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY, + "ENGINE": app.state.config.TTS_ENGINE, + "MODEL": app.state.config.TTS_MODEL, + "VOICE": app.state.config.TTS_VOICE, + }, + "stt": { + "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY, + "ENGINE": app.state.config.STT_ENGINE, + "MODEL": app.state.config.STT_MODEL, + }, } @@ -125,13 +196,21 @@ async def speech(request: Request, user=Depends(get_verified_user)): return FileResponse(file_path) headers = {} - headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}" + headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}" headers["Content-Type"] = "application/json" + try: + body = body.decode("utf-8") + body = json.loads(body) + body["model"] = app.state.config.TTS_MODEL + body = json.dumps(body).encode("utf-8") + except Exception as e: + pass + r = None try: r = requests.post( - url=f"{app.state.config.OPENAI_API_BASE_URL}/audio/speech", + url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", data=body, headers=headers, stream=True, @@ -181,41 +260,110 @@ def transcribe( ) try: - filename = file.filename - file_path = f"{UPLOAD_DIR}/{filename}" + ext = file.filename.split(".")[-1] + + id = uuid.uuid4() + filename = f"{id}.{ext}" + + file_dir = f"{CACHE_DIR}/audio/transcriptions" + os.makedirs(file_dir, exist_ok=True) + file_path = f"{file_dir}/{filename}" + + print(filename) + contents = file.file.read() with open(file_path, "wb") as f: f.write(contents) f.close() - 
whisper_kwargs = { - "model_size_or_path": WHISPER_MODEL, - "device": whisper_device_type, - "compute_type": "int8", - "download_root": WHISPER_MODEL_DIR, - "local_files_only": not WHISPER_MODEL_AUTO_UPDATE, - } + if app.state.config.STT_ENGINE == "": + whisper_kwargs = { + "model_size_or_path": WHISPER_MODEL, + "device": whisper_device_type, + "compute_type": "int8", + "download_root": WHISPER_MODEL_DIR, + "local_files_only": not WHISPER_MODEL_AUTO_UPDATE, + } - log.debug(f"whisper_kwargs: {whisper_kwargs}") + log.debug(f"whisper_kwargs: {whisper_kwargs}") - try: - model = WhisperModel(**whisper_kwargs) - except: - log.warning( - "WhisperModel initialization failed, attempting download with local_files_only=False" + try: + model = WhisperModel(**whisper_kwargs) + except: + log.warning( + "WhisperModel initialization failed, attempting download with local_files_only=False" + ) + whisper_kwargs["local_files_only"] = False + model = WhisperModel(**whisper_kwargs) + + segments, info = model.transcribe(file_path, beam_size=5) + log.info( + "Detected language '%s' with probability %f" + % (info.language, info.language_probability) ) - whisper_kwargs["local_files_only"] = False - model = WhisperModel(**whisper_kwargs) - segments, info = model.transcribe(file_path, beam_size=5) - log.info( - "Detected language '%s' with probability %f" - % (info.language, info.language_probability) - ) + transcript = "".join([segment.text for segment in list(segments)]) - transcript = "".join([segment.text for segment in list(segments)]) + data = {"text": transcript.strip()} - return {"text": transcript.strip()} + # save the transcript to a json file + transcript_file = f"{file_dir}/{id}.json" + with open(transcript_file, "w") as f: + json.dump(data, f) + + print(data) + + return data + + elif app.state.config.STT_ENGINE == "openai": + if is_mp4_audio(file_path): + print("is_mp4_audio") + os.rename(file_path, file_path.replace(".wav", ".mp4")) + # Convert MP4 audio file to WAV format + 
convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path) + + headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"} + + files = {"file": (filename, open(file_path, "rb"))} + data = {"model": "whisper-1"} + + print(files, data) + + r = None + try: + r = requests.post( + url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions", + headers=headers, + files=files, + data=data, + ) + + r.raise_for_status() + + data = r.json() + + # save the transcript to a json file + transcript_file = f"{file_dir}/{id}.json" + with open(transcript_file, "w") as f: + json.dump(data, f) + + print(data) + return data + except Exception as e: + log.exception(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"External: {res['error']['message']}" + except: + error_detail = f"External: {e}" + + raise HTTPException( + status_code=r.status_code if r != None else 500, + detail=error_detail, + ) except Exception as e: log.exception(e) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 82cd8d383..118c688d3 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -41,13 +41,12 @@ from utils.utils import ( get_admin_user, ) -from utils.models import get_model_id_from_custom_model_id - from config import ( SRC_LOG_LEVELS, OLLAMA_BASE_URLS, ENABLE_OLLAMA_API, + AIOHTTP_CLIENT_TIMEOUT, ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, UPLOAD_DIR, @@ -156,7 +155,9 @@ async def cleanup_response( async def post_streaming_url(url: str, payload: str): r = None try: - session = aiohttp.ClientSession(trust_env=True) + session = aiohttp.ClientSession( + trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + ) r = await session.post(url, data=payload) r.raise_for_status() @@ -728,7 +729,6 @@ async def generate_chat_completion( model_info = Models.get_model_by_id(model_id) if model_info: - print(model_info) if 
model_info.base_model_id: payload["model"] = model_info.base_model_id @@ -754,6 +754,14 @@ async def generate_chat_completion( if model_info.params.get("num_ctx", None): payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None) + if model_info.params.get("num_batch", None): + payload["options"]["num_batch"] = model_info.params.get( + "num_batch", None + ) + + if model_info.params.get("num_keep", None): + payload["options"]["num_keep"] = model_info.params.get("num_keep", None) + if model_info.params.get("repeat_last_n", None): payload["options"]["repeat_last_n"] = model_info.params.get( "repeat_last_n", None @@ -764,7 +772,7 @@ async def generate_chat_completion( "frequency_penalty", None ) - if model_info.params.get("temperature", None): + if model_info.params.get("temperature", None) is not None: payload["options"]["temperature"] = model_info.params.get( "temperature", None ) @@ -849,9 +857,14 @@ async def generate_chat_completion( # TODO: we should update this part once Ollama supports other types +class OpenAIChatMessageContent(BaseModel): + type: str + model_config = ConfigDict(extra="allow") + + class OpenAIChatMessage(BaseModel): role: str - content: str + content: Union[str, OpenAIChatMessageContent] model_config = ConfigDict(extra="allow") @@ -879,7 +892,6 @@ async def generate_openai_chat_completion( model_info = Models.get_model_by_id(model_id) if model_info: - print(model_info) if model_info.base_model_id: payload["model"] = model_info.base_model_id diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 472699f1d..93f913dea 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -345,113 +345,97 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use ) -@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) -async def proxy(path: str, request: Request, user=Depends(get_verified_user)): +@app.post("/chat/completions") 
+@app.post("/chat/completions/{url_idx}") +async def generate_chat_completion( + form_data: dict, + url_idx: Optional[int] = None, + user=Depends(get_verified_user), +): idx = 0 + payload = {**form_data} - body = await request.body() - # TODO: Remove below after gpt-4-vision fix from Open AI - # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision) + model_id = form_data.get("model") + model_info = Models.get_model_by_id(model_id) - payload = None + if model_info: + if model_info.base_model_id: + payload["model"] = model_info.base_model_id - try: - if "chat/completions" in path: - body = body.decode("utf-8") - body = json.loads(body) + model_info.params = model_info.params.model_dump() - payload = {**body} + if model_info.params: + if model_info.params.get("temperature", None) is not None: + payload["temperature"] = float(model_info.params.get("temperature")) - model_id = body.get("model") - model_info = Models.get_model_by_id(model_id) + if model_info.params.get("top_p", None): + payload["top_p"] = int(model_info.params.get("top_p", None)) - if model_info: - print(model_info) - if model_info.base_model_id: - payload["model"] = model_info.base_model_id + if model_info.params.get("max_tokens", None): + payload["max_tokens"] = int(model_info.params.get("max_tokens", None)) - model_info.params = model_info.params.model_dump() + if model_info.params.get("frequency_penalty", None): + payload["frequency_penalty"] = int( + model_info.params.get("frequency_penalty", None) + ) - if model_info.params: - if model_info.params.get("temperature", None): - payload["temperature"] = int( - model_info.params.get("temperature") + if model_info.params.get("seed", None): + payload["seed"] = model_info.params.get("seed", None) + + if model_info.params.get("stop", None): + payload["stop"] = ( + [ + bytes(stop, "utf-8").decode("unicode_escape") + for stop in model_info.params["stop"] + ] + if model_info.params.get("stop", None) + 
else None + ) + + if model_info.params.get("system", None): + # Check if the payload already has a system message + # If not, add a system message to the payload + if payload.get("messages"): + for message in payload["messages"]: + if message.get("role") == "system": + message["content"] = ( + model_info.params.get("system", None) + message["content"] ) + break + else: + payload["messages"].insert( + 0, + { + "role": "system", + "content": model_info.params.get("system", None), + }, + ) - if model_info.params.get("top_p", None): - payload["top_p"] = int(model_info.params.get("top_p", None)) + else: + pass - if model_info.params.get("max_tokens", None): - payload["max_tokens"] = int( - model_info.params.get("max_tokens", None) - ) + model = app.state.MODELS[payload.get("model")] + idx = model["urlIdx"] - if model_info.params.get("frequency_penalty", None): - payload["frequency_penalty"] = int( - model_info.params.get("frequency_penalty", None) - ) + if "pipeline" in model and model.get("pipeline"): + payload["user"] = {"name": user.name, "id": user.id} - if model_info.params.get("seed", None): - payload["seed"] = model_info.params.get("seed", None) + # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000 + # This is a workaround until OpenAI fixes the issue with this model + if payload.get("model") == "gpt-4-vision-preview": + if "max_tokens" not in payload: + payload["max_tokens"] = 4000 + log.debug("Modified payload:", payload) - if model_info.params.get("stop", None): - payload["stop"] = ( - [ - bytes(stop, "utf-8").decode("unicode_escape") - for stop in model_info.params["stop"] - ] - if model_info.params.get("stop", None) - else None - ) - - if model_info.params.get("system", None): - # Check if the payload already has a system message - # If not, add a system message to the payload - if payload.get("messages"): - for message in payload["messages"]: - if message.get("role") == "system": - message["content"] = ( - 
model_info.params.get("system", None) - + message["content"] - ) - break - else: - payload["messages"].insert( - 0, - { - "role": "system", - "content": model_info.params.get("system", None), - }, - ) - else: - pass - - model = app.state.MODELS[payload.get("model")] - - idx = model["urlIdx"] - - if "pipeline" in model and model.get("pipeline"): - payload["user"] = {"name": user.name, "id": user.id} - - # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000 - # This is a workaround until OpenAI fixes the issue with this model - if payload.get("model") == "gpt-4-vision-preview": - if "max_tokens" not in payload: - payload["max_tokens"] = 4000 - log.debug("Modified payload:", payload) - - # Convert the modified body back to JSON - payload = json.dumps(payload) - - except json.JSONDecodeError as e: - log.error("Error loading request body into a dictionary:", e) + # Convert the modified body back to JSON + payload = json.dumps(payload) print(payload) url = app.state.config.OPENAI_API_BASE_URLS[idx] key = app.state.config.OPENAI_API_KEYS[idx] - target_url = f"{url}/{path}" + print(payload) headers = {} headers["Authorization"] = f"Bearer {key}" @@ -464,9 +448,72 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): try: session = aiohttp.ClientSession(trust_env=True) r = await session.request( - method=request.method, - url=target_url, - data=payload if payload else body, + method="POST", + url=f"{url}/chat/completions", + data=payload, + headers=headers, + ) + + r.raise_for_status() + + # Check if response is SSE + if "text/event-stream" in r.headers.get("Content-Type", ""): + streaming = True + return StreamingResponse( + r.content, + status_code=r.status, + headers=dict(r.headers), + background=BackgroundTask( + cleanup_response, response=r, session=session + ), + ) + else: + response_data = await r.json() + return response_data + except Exception as e: + log.exception(e) + error_detail = "Open WebUI: Server Connection 
Error" + if r is not None: + try: + res = await r.json() + print(res) + if "error" in res: + error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" + except: + error_detail = f"External: {e}" + raise HTTPException(status_code=r.status if r else 500, detail=error_detail) + finally: + if not streaming and session: + if r: + r.close() + await session.close() + + +@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) +async def proxy(path: str, request: Request, user=Depends(get_verified_user)): + idx = 0 + + body = await request.body() + + url = app.state.config.OPENAI_API_BASE_URLS[idx] + key = app.state.config.OPENAI_API_KEYS[idx] + + target_url = f"{url}/{path}" + + headers = {} + headers["Authorization"] = f"Bearer {key}" + headers["Content-Type"] = "application/json" + + r = None + session = None + streaming = False + + try: + session = aiohttp.ClientSession(trust_env=True) + r = await session.request( + method=request.method, + url=target_url, + data=body, headers=headers, ) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index d405ef0b4..4bd5da86c 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -8,12 +8,15 @@ from fastapi import ( Form, ) from fastapi.middleware.cors import CORSMiddleware +import requests import os, shutil, logging, re +from datetime import datetime from pathlib import Path -from typing import List, Union, Sequence +from typing import List, Union, Sequence, Iterator, Any from chromadb.utils.batch_utils import create_batches +from langchain_core.documents import Document from langchain_community.document_loaders import ( WebBaseLoader, @@ -30,6 +33,7 @@ from langchain_community.document_loaders import ( UnstructuredExcelLoader, UnstructuredPowerPointLoader, YoutubeLoader, + OutlookMessageLoader, ) from langchain.text_splitter import RecursiveCharacterTextSplitter @@ -67,7 +71,9 @@ from apps.rag.search.main import SearchResult from 
apps.rag.search.searxng import search_searxng from apps.rag.search.serper import search_serper from apps.rag.search.serpstack import search_serpstack - +from apps.rag.search.serply import search_serply +from apps.rag.search.duckduckgo import search_duckduckgo +from apps.rag.search.tavily import search_tavily from utils.misc import ( calculate_sha256, @@ -113,6 +119,8 @@ from config import ( SERPSTACK_API_KEY, SERPSTACK_HTTPS, SERPER_API_KEY, + SERPLY_API_KEY, + TAVILY_API_KEY, RAG_WEB_SEARCH_RESULT_COUNT, RAG_WEB_SEARCH_CONCURRENT_REQUESTS, RAG_EMBEDDING_OPENAI_BATCH_SIZE, @@ -165,6 +173,8 @@ app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS app.state.config.SERPER_API_KEY = SERPER_API_KEY +app.state.config.SERPLY_API_KEY = SERPLY_API_KEY +app.state.config.TAVILY_API_KEY = TAVILY_API_KEY app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS @@ -392,6 +402,8 @@ async def get_rag_config(user=Depends(get_admin_user)): "serpstack_api_key": app.state.config.SERPSTACK_API_KEY, "serpstack_https": app.state.config.SERPSTACK_HTTPS, "serper_api_key": app.state.config.SERPER_API_KEY, + "serply_api_key": app.state.config.SERPLY_API_KEY, + "tavily_api_key": app.state.config.TAVILY_API_KEY, "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, }, @@ -419,6 +431,8 @@ class WebSearchConfig(BaseModel): serpstack_api_key: Optional[str] = None serpstack_https: Optional[bool] = None serper_api_key: Optional[str] = None + serply_api_key: Optional[str] = None + tavily_api_key: Optional[str] = None result_count: Optional[int] = None concurrent_requests: Optional[int] = None @@ -469,6 +483,8 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ 
app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key + app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key + app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = ( form_data.web.search.concurrent_requests @@ -497,6 +513,8 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ "serpstack_api_key": app.state.config.SERPSTACK_API_KEY, "serpstack_https": app.state.config.SERPSTACK_HTTPS, "serper_api_key": app.state.config.SERPER_API_KEY, + "serply_api_key": app.state.config.SERPLY_API_KEY, + "tavily_api_key": app.state.config.TAVILY_API_KEY, "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, }, @@ -693,7 +711,7 @@ def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True): # Check if the URL is valid if not validate_url(url): raise ValueError(ERROR_MESSAGES.INVALID_URL) - return WebBaseLoader( + return SafeWebBaseLoader( url, verify_ssl=verify_ssl, requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS, @@ -744,7 +762,8 @@ def search_web(engine: str, query: str) -> list[SearchResult]: - BRAVE_SEARCH_API_KEY - SERPSTACK_API_KEY - SERPER_API_KEY - + - SERPLY_API_KEY + - TAVILY_API_KEY Args: query (str): The query to search for """ @@ -802,6 +821,26 @@ def search_web(engine: str, query: str) -> list[SearchResult]: ) else: raise Exception("No SERPER_API_KEY found in environment variables") + elif engine == "serply": + if app.state.config.SERPLY_API_KEY: + return search_serply( + app.state.config.SERPLY_API_KEY, + query, + app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + ) + else: + raise Exception("No 
SERPLY_API_KEY found in environment variables") + elif engine == "duckduckgo": + return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT) + elif engine == "tavily": + if app.state.config.TAVILY_API_KEY: + return search_tavily( + app.state.config.TAVILY_API_KEY, + query, + app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + ) + else: + raise Exception("No TAVILY_API_KEY found in environment variables") else: raise Exception("No search engine API key found in environment variables") @@ -809,6 +848,9 @@ def search_web(engine: str, query: str) -> list[SearchResult]: @app.post("/web/search") def store_web_search(form_data: SearchForm, user=Depends(get_current_user)): try: + logging.info( + f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}" + ) web_results = search_web( app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query ) @@ -879,6 +921,13 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b texts = [doc.page_content for doc in docs] metadatas = [doc.metadata for doc in docs] + # ChromaDB does not like datetime formats + # for meta-data so convert them to string. 
+ for metadata in metadatas: + for key, value in metadata.items(): + if isinstance(value, datetime): + metadata[key] = str(value) + try: if overwrite: for collection in CHROMA_CLIENT.list_collections(): @@ -965,6 +1014,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str): "swift", "vue", "svelte", + "msg", ] if file_ext == "pdf": @@ -999,6 +1049,8 @@ def get_loader(filename: str, file_content_type: str, file_path: str): "application/vnd.openxmlformats-officedocument.presentationml.presentation", ] or file_ext in ["ppt", "pptx"]: loader = UnstructuredPowerPointLoader(file_path) + elif file_ext == "msg": + loader = OutlookMessageLoader(file_path) elif file_ext in known_source_ext or ( file_content_type and file_content_type.find("text/") >= 0 ): @@ -1209,6 +1261,33 @@ def reset(user=Depends(get_admin_user)) -> bool: return True +class SafeWebBaseLoader(WebBaseLoader): + """WebBaseLoader with enhanced error handling for URLs.""" + + def lazy_load(self) -> Iterator[Document]: + """Lazy load text from the url(s) in web_path with error handling.""" + for path in self.web_paths: + try: + soup = self._scrape(path, bs_kwargs=self.bs_kwargs) + text = soup.get_text(**self.bs_get_text_kwargs) + + # Build metadata + metadata = {"source": path} + if title := soup.find("title"): + metadata["title"] = title.get_text() + if description := soup.find("meta", attrs={"name": "description"}): + metadata["description"] = description.get( + "content", "No description found." 
+ ) + if html := soup.find("html"): + metadata["language"] = html.get("lang", "No language found.") + + yield Document(page_content=text, metadata=metadata) + except Exception as e: + # Log the error and continue with the next URL + log.error(f"Error loading {path}: {e}") + + if ENV == "dev": @app.get("/ef") diff --git a/backend/apps/rag/search/duckduckgo.py b/backend/apps/rag/search/duckduckgo.py new file mode 100644 index 000000000..188ae2bea --- /dev/null +++ b/backend/apps/rag/search/duckduckgo.py @@ -0,0 +1,46 @@ +import logging + +from apps.rag.search.main import SearchResult +from duckduckgo_search import DDGS +from config import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +def search_duckduckgo(query: str, count: int) -> list[SearchResult]: + """ + Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects. + Args: + query (str): The query to search for + count (int): The number of results to return + + Returns: + List[SearchResult]: A list of search results + """ + # Use the DDGS context manager to create a DDGS object + with DDGS() as ddgs: + # Use the ddgs.text() method to perform the search + ddgs_gen = ddgs.text( + query, safesearch="moderate", max_results=count, backend="api" + ) + # Check if there are search results + if ddgs_gen: + # Convert the search results into a list + search_results = [r for r in ddgs_gen] + + # Create an empty list to store the SearchResult objects + results = [] + # Iterate over each search result + for result in search_results: + # Create a SearchResult object and append it to the results list + results.append( + SearchResult( + link=result["href"], + title=result.get("title"), + snippet=result.get("body"), + ) + ) + print(results) + # Return the list of search results + return results diff --git a/backend/apps/rag/search/searxng.py b/backend/apps/rag/search/searxng.py index a62ab5089..c8ad88813 100644 --- 
a/backend/apps/rag/search/searxng.py +++ b/backend/apps/rag/search/searxng.py @@ -25,6 +25,7 @@ def search_searxng( Keyword Args: language (str): Language filter for the search results; e.g., "en-US". Defaults to an empty string. + safesearch (int): Safe search filter for safer web results; 0 = off, 1 = moderate, 2 = strict. Defaults to 1 (moderate). time_range (str): Time range for filtering results by date; e.g., "2023-04-05..today" or "all-time". Defaults to ''. categories: (Optional[List[str]]): Specific categories within which the search should be performed, defaulting to an empty string if not provided. @@ -37,6 +38,7 @@ def search_searxng( # Default values for optional parameters are provided as empty strings or None when not specified. language = kwargs.get("language", "en-US") + safesearch = kwargs.get("safesearch", "1") time_range = kwargs.get("time_range", "") categories = "".join(kwargs.get("categories", [])) @@ -44,6 +46,7 @@ def search_searxng( "q": query, "format": "json", "pageno": 1, + "safesearch": safesearch, "language": language, "time_range": time_range, "categories": categories, diff --git a/backend/apps/rag/search/serply.py b/backend/apps/rag/search/serply.py new file mode 100644 index 000000000..fccf70ecd --- /dev/null +++ b/backend/apps/rag/search/serply.py @@ -0,0 +1,68 @@ +import json +import logging + +import requests +from urllib.parse import urlencode + +from apps.rag.search.main import SearchResult +from config import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +def search_serply( + api_key: str, + query: str, + count: int, + hl: str = "us", + limit: int = 10, + device_type: str = "desktop", + proxy_location: str = "US", +) -> list[SearchResult]: + """Search using serply.io's API and return the results as a list of SearchResult objects. 
+ + Args: + api_key (str): A serply.io API key + query (str): The query to search for + hl (str): Host Language code to display results in (reference https://developers.google.com/custom-search/docs/xml_results?hl=en#wsInterfaceLanguages) + limit (int): The maximum number of results to return [10-100, defaults to 10] + """ + log.info("Searching with Serply") + + url = "https://api.serply.io/v1/search/" + + query_payload = { + "q": query, + "language": "en", + "num": limit, + "gl": proxy_location.upper(), + "hl": hl.lower(), + } + + url = f"{url}{urlencode(query_payload)}" + headers = { + "X-API-KEY": api_key, + "X-User-Agent": device_type, + "User-Agent": "open-webui", + "X-Proxy-Location": proxy_location, + } + + response = requests.request("GET", url, headers=headers) + response.raise_for_status() + + json_response = response.json() + log.info(f"results from serply search: {json_response}") + + results = sorted( + json_response.get("results", []), key=lambda x: x.get("realPosition", 0) + ) + + return [ + SearchResult( + link=result["link"], + title=result.get("title"), + snippet=result.get("description"), + ) + for result in results[:count] + ] diff --git a/backend/apps/rag/search/tavily.py b/backend/apps/rag/search/tavily.py new file mode 100644 index 000000000..b15d6ef9d --- /dev/null +++ b/backend/apps/rag/search/tavily.py @@ -0,0 +1,39 @@ +import logging + +import requests + +from apps.rag.search.main import SearchResult +from config import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]: + """Search using Tavily's Search API and return the results as a list of SearchResult objects. 
+ + Args: + api_key (str): A Tavily Search API key + query (str): The query to search for + + Returns: + List[SearchResult]: A list of search results + """ + url = "https://api.tavily.com/search" + data = {"query": query, "api_key": api_key} + + response = requests.post(url, json=data) + response.raise_for_status() + + json_response = response.json() + + raw_search_results = json_response.get("results", []) + + return [ + SearchResult( + link=result["url"], + title=result.get("title", ""), + snippet=result.get("content"), + ) + for result in raw_search_results[:count] + ] diff --git a/backend/apps/rag/search/testdata/serply.json b/backend/apps/rag/search/testdata/serply.json new file mode 100644 index 000000000..0fc2a31e4 --- /dev/null +++ b/backend/apps/rag/search/testdata/serply.json @@ -0,0 +1,206 @@ +{ + "ads": [], + "ads_count": 0, + "answers": [], + "results": [ + { + "title": "Apple", + "link": "https://www.apple.com/", + "description": "Discover the innovative world of Apple and shop everything iPhone, iPad, Apple Watch, Mac, and Apple TV, plus explore accessories, entertainment, ...", + "additional_links": [ + { + "text": "AppleApplehttps://www.apple.com", + "href": "https://www.apple.com/" + } + ], + "cite": {}, + "subdomains": [ + { + "title": "Support", + "link": "https://support.apple.com/", + "description": "SupportContact - iPhone Support - Billing and Subscriptions - Apple Repair" + }, + { + "title": "Store", + "link": "https://www.apple.com/store", + "description": "StoreShop iPhone - Shop iPad - App Store - Shop Mac - ..." + }, + { + "title": "Mac", + "link": "https://www.apple.com/mac/", + "description": "MacMacBook Air - MacBook Pro - iMac - Compare Mac models - Mac mini" + }, + { + "title": "iPad", + "link": "https://www.apple.com/ipad/", + "description": "iPadShop iPad - iPad Pro - iPad Air - Compare iPad models - ..." 
+ }, + { + "title": "Watch", + "link": "https://www.apple.com/watch/", + "description": "WatchShop Apple Watch - Series 9 - SE - Ultra 2 - Nike - Hermès - ..." + } + ], + "realPosition": 1 + }, + { + "title": "Apple", + "link": "https://www.apple.com/", + "description": "Discover the innovative world of Apple and shop everything iPhone, iPad, Apple Watch, Mac, and Apple TV, plus explore accessories, entertainment, ...", + "additional_links": [ + { + "text": "AppleApplehttps://www.apple.com", + "href": "https://www.apple.com/" + } + ], + "cite": {}, + "realPosition": 2 + }, + { + "title": "Apple Inc.", + "link": "https://en.wikipedia.org/wiki/Apple_Inc.", + "description": "Apple Inc. (formerly Apple Computer, Inc.) is an American multinational corporation and technology company headquartered in Cupertino, California, ...", + "additional_links": [ + { + "text": "Apple Inc.Wikipediahttps://en.wikipedia.org › wiki › Apple_Inc", + "href": "https://en.wikipedia.org/wiki/Apple_Inc." + }, + { + "text": "", + "href": "https://en.wikipedia.org/wiki/Apple_Inc." + }, + { + "text": "History", + "href": "https://en.wikipedia.org/wiki/History_of_Apple_Inc." + }, + { + "text": "List of Apple products", + "href": "https://en.wikipedia.org/wiki/List_of_Apple_products" + }, + { + "text": "Litigation involving Apple Inc.", + "href": "https://en.wikipedia.org/wiki/Litigation_involving_Apple_Inc." + }, + { + "text": "Apple Park", + "href": "https://en.wikipedia.org/wiki/Apple_Park" + } + ], + "cite": { + "domain": "https://en.wikipedia.org › wiki › Apple_Inc", + "span": " › wiki › Apple_Inc" + }, + "realPosition": 3 + }, + { + "title": "Apple Inc. (AAPL) Company Profile & Facts", + "link": "https://finance.yahoo.com/quote/AAPL/profile/", + "description": "Apple Inc. designs, manufactures, and markets smartphones, personal computers, tablets, wearables, and accessories worldwide. The company offers iPhone, a line ...", + "additional_links": [ + { + "text": "Apple Inc. 
(AAPL) Company Profile & FactsYahoo Financehttps://finance.yahoo.com › quote › AAPL › profile", + "href": "https://finance.yahoo.com/quote/AAPL/profile/" + } + ], + "cite": { + "domain": "https://finance.yahoo.com › quote › AAPL › profile", + "span": " › quote › AAPL › profile" + }, + "realPosition": 4 + }, + { + "title": "Apple Inc - Company Profile and News", + "link": "https://www.bloomberg.com/profile/company/AAPL:US", + "description": "Apple Inc. Apple Inc. designs, manufactures, and markets smartphones, personal computers, tablets, wearables and accessories, and sells a variety of related ...", + "additional_links": [ + { + "text": "Apple Inc - Company Profile and NewsBloomberghttps://www.bloomberg.com › company › AAPL:US", + "href": "https://www.bloomberg.com/profile/company/AAPL:US" + }, + { + "text": "", + "href": "https://www.bloomberg.com/profile/company/AAPL:US" + } + ], + "cite": { + "domain": "https://www.bloomberg.com › company › AAPL:US", + "span": " › company › AAPL:US" + }, + "realPosition": 5 + }, + { + "title": "Apple Inc. | History, Products, Headquarters, & Facts", + "link": "https://www.britannica.com/money/Apple-Inc", + "description": "May 22, 2024 — Apple Inc. is an American multinational technology company that revolutionized the technology sector through its innovation of computer ...", + "additional_links": [ + { + "text": "Apple Inc. | History, Products, Headquarters, & FactsBritannicahttps://www.britannica.com › money › Apple-Inc", + "href": "https://www.britannica.com/money/Apple-Inc" + }, + { + "text": "", + "href": "https://www.britannica.com/money/Apple-Inc" + } + ], + "cite": { + "domain": "https://www.britannica.com › money › Apple-Inc", + "span": " › money › Apple-Inc" + }, + "realPosition": 6 + } + ], + "shopping_ads": [], + "places": [ + { + "title": "Apple Inc." 
+ }, + { + "title": "Apple Inc" + }, + { + "title": "Apple Inc" + } + ], + "related_searches": { + "images": [], + "text": [ + { + "title": "apple inc full form", + "link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+Inc+full+form&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhPEAE" + }, + { + "title": "apple company history", + "link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+company+history&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhOEAE" + }, + { + "title": "apple store", + "link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+Store&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhQEAE" + }, + { + "title": "apple id", + "link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+id&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhSEAE" + }, + { + "title": "apple inc industry", + "link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+Inc+industry&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhREAE" + }, + { + "title": "apple login", + "link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+login&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhTEAE" + } + ] + }, + "image_results": [], + "carousel": [], + "total": 2450000000, + "knowledge_graph": "", + "related_questions": [ + "What does the Apple Inc do?", + "Why did Apple change to Apple Inc?", + "Who owns Apple Inc.?", + "What is Apple Inc best known for?" 
+ ], + "carousel_count": 0, + "ts": 2.491065263748169, + "device_type": null +} diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index ac52dc3d8..d0570f748 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -20,7 +20,7 @@ from langchain.retrievers import ( from typing import Optional - +from utils.misc import get_last_user_message, add_or_update_system_message from config import SRC_LOG_LEVELS, CHROMA_CLIENT log = logging.getLogger(__name__) @@ -236,10 +236,9 @@ def get_embedding_function( return lambda query: generate_multiple(query, func) -def rag_messages( +def get_rag_context( docs, messages, - template, embedding_function, k, reranking_function, @@ -247,31 +246,7 @@ def rag_messages( hybrid_search, ): log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}") - - last_user_message_idx = None - for i in range(len(messages) - 1, -1, -1): - if messages[i]["role"] == "user": - last_user_message_idx = i - break - - user_message = messages[last_user_message_idx] - - if isinstance(user_message["content"], list): - # Handle list content input - content_type = "list" - query = "" - for content_item in user_message["content"]: - if content_item["type"] == "text": - query = content_item["text"] - break - elif isinstance(user_message["content"], str): - # Handle text content input - content_type = "text" - query = user_message["content"] - else: - # Fallback in case the input does not match expected types - content_type = None - query = "" + query = get_last_user_message(messages) extracted_collections = [] relevant_contexts = [] @@ -342,33 +317,7 @@ def rag_messages( context_string = context_string.strip() - ra_content = rag_template( - template=template, - context=context_string, - query=query, - ) - - log.debug(f"ra_content: {ra_content}") - - if content_type == "list": - new_content = [] - for content_item in user_message["content"]: - if content_item["type"] == "text": - # Update the text item's content 
with ra_content - new_content.append({"type": "text", "text": ra_content}) - else: - # Keep other types of content as they are - new_content.append(content_item) - new_user_message = {**user_message, "content": new_content} - else: - new_user_message = { - **user_message, - "content": ra_content, - } - - messages[last_user_message_idx] = new_user_message - - return messages, citations + return context_string, citations def get_model_path(model: str, update_model: bool = False): diff --git a/backend/apps/socket/main.py b/backend/apps/socket/main.py index 0bc45287a..123ff31cd 100644 --- a/backend/apps/socket/main.py +++ b/backend/apps/socket/main.py @@ -10,7 +10,7 @@ app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io") # Dictionary to maintain the user pool - +SESSION_POOL = {} USER_POOL = {} USAGE_POOL = {} # Timeout duration in seconds @@ -19,8 +19,6 @@ TIMEOUT_DURATION = 3 @sio.event async def connect(sid, environ, auth): - print("connect ", sid) - user = None if auth and "token" in auth: data = decode_token(auth["token"]) @@ -29,10 +27,14 @@ async def connect(sid, environ, auth): user = Users.get_user_by_id(data["id"]) if user: - USER_POOL[sid] = user.id + SESSION_POOL[sid] = user.id + if user.id in USER_POOL: + USER_POOL[user.id].append(sid) + else: + USER_POOL[user.id] = [sid] + print(f"user {user.name}({user.id}) connected with session ID {sid}") - print(len(set(USER_POOL))) await sio.emit("user-count", {"count": len(set(USER_POOL))}) await sio.emit("usage", {"models": get_models_in_use()}) @@ -50,16 +52,20 @@ async def user_join(sid, data): user = Users.get_user_by_id(data["id"]) if user: - USER_POOL[sid] = user.id + + SESSION_POOL[sid] = user.id + if user.id in USER_POOL: + USER_POOL[user.id].append(sid) + else: + USER_POOL[user.id] = [sid] + print(f"user {user.name}({user.id}) connected with session ID {sid}") - print(len(set(USER_POOL))) await sio.emit("user-count", {"count": len(set(USER_POOL))}) @sio.on("user-count") async def user_count(sid): - 
print("user-count", sid) await sio.emit("user-count", {"count": len(set(USER_POOL))}) @@ -68,14 +74,12 @@ def get_models_in_use(): models_in_use = [] for model_id, data in USAGE_POOL.items(): models_in_use.append(model_id) - print(f"Models in use: {models_in_use}") return models_in_use @sio.on("usage") async def usage(sid, data): - print(f'Received "usage" event from {sid}: {data}') model_id = data["model"] @@ -103,7 +107,6 @@ async def usage(sid, data): async def remove_after_timeout(sid, model_id): try: - print("remove_after_timeout", sid, model_id) await asyncio.sleep(TIMEOUT_DURATION) if model_id in USAGE_POOL: print(USAGE_POOL[model_id]["sids"]) @@ -113,7 +116,6 @@ async def remove_after_timeout(sid, model_id): if len(USAGE_POOL[model_id]["sids"]) == 0: del USAGE_POOL[model_id] - print(f"Removed usage data for {model_id} due to timeout") # Broadcast the usage data to all clients await sio.emit("usage", {"models": get_models_in_use()}) except asyncio.CancelledError: @@ -123,9 +125,14 @@ async def remove_after_timeout(sid, model_id): @sio.event async def disconnect(sid): - if sid in USER_POOL: - disconnected_user = USER_POOL.pop(sid) - print(f"user {disconnected_user} disconnected with session ID {sid}") + if sid in SESSION_POOL: + user_id = SESSION_POOL[sid] + del SESSION_POOL[sid] + + USER_POOL[user_id].remove(sid) + + if len(USER_POOL[user_id]) == 0: + del USER_POOL[user_id] await sio.emit("user-count", {"count": len(USER_POOL)}) else: diff --git a/backend/apps/webui/internal/migrations/012_add_tools.py b/backend/apps/webui/internal/migrations/012_add_tools.py new file mode 100644 index 000000000..4a68eea55 --- /dev/null +++ b/backend/apps/webui/internal/migrations/012_add_tools.py @@ -0,0 +1,61 @@ +"""Peewee migrations -- 012_add_tools.py. 
+ +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + @migrator.create_model + class Tool(pw.Model): + id = pw.TextField(unique=True) + user_id = pw.TextField() + + name = pw.TextField() + content = pw.TextField() + specs = pw.TextField() + + meta = pw.TextField() + + created_at = pw.BigIntegerField(null=False) + updated_at = pw.BigIntegerField(null=False) + + class Meta: + table_name = "tool" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model("tool") diff --git a/backend/apps/webui/internal/migrations/011_add_user_oauth_sub.py 
b/backend/apps/webui/internal/migrations/013_add_user_oauth_sub.py similarity index 97% rename from backend/apps/webui/internal/migrations/011_add_user_oauth_sub.py rename to backend/apps/webui/internal/migrations/013_add_user_oauth_sub.py index 70dfeccf0..9bd3f4721 100644 --- a/backend/apps/webui/internal/migrations/011_add_user_oauth_sub.py +++ b/backend/apps/webui/internal/migrations/013_add_user_oauth_sub.py @@ -1,4 +1,4 @@ -"""Peewee migrations -- 011_add_user_oauth_sub.py. +"""Peewee migrations -- 013_add_user_oauth_sub.py. Some examples (model - class or model name):: diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 5b238b12b..a9a187ee5 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -8,6 +8,7 @@ from apps.webui.routers import ( users, chats, documents, + tools, models, prompts, configs, @@ -27,9 +28,9 @@ from config import ( WEBHOOK_URL, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, JWT_EXPIRES_IN, - AppConfig, - ENABLE_COMMUNITY_SHARING, WEBUI_BANNERS, + ENABLE_COMMUNITY_SHARING, + AppConfig, ) app = FastAPI() @@ -40,6 +41,7 @@ app.state.config = AppConfig() app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN +app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS @@ -56,7 +58,7 @@ app.state.config.BANNERS = WEBUI_BANNERS app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING app.state.MODELS = {} -app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER +app.state.TOOLS = {} app.add_middleware( @@ -72,6 +74,7 @@ app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(documents.router, prefix="/documents", tags=["documents"]) +app.include_router(tools.router, prefix="/tools", tags=["tools"]) app.include_router(models.router, prefix="/models", tags=["models"]) 
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) app.include_router(memories.router, prefix="/memories", tags=["memories"]) diff --git a/backend/apps/webui/models/memories.py b/backend/apps/webui/models/memories.py index 70e5577e9..ef63674ab 100644 --- a/backend/apps/webui/models/memories.py +++ b/backend/apps/webui/models/memories.py @@ -65,6 +65,20 @@ class MemoriesTable: else: return None + def update_memory_by_id( + self, + id: str, + content: str, + ) -> Optional[MemoryModel]: + try: + memory = Memory.get(Memory.id == id) + memory.content = content + memory.updated_at = int(time.time()) + memory.save() + return MemoryModel(**model_to_dict(memory)) + except: + return None + def get_memories(self) -> List[MemoryModel]: try: memories = Memory.select() diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py new file mode 100644 index 000000000..e2db1e35f --- /dev/null +++ b/backend/apps/webui/models/tools.py @@ -0,0 +1,132 @@ +from pydantic import BaseModel +from peewee import * +from playhouse.shortcuts import model_to_dict +from typing import List, Union, Optional +import time +import logging +from apps.webui.internal.db import DB, JSONField + +import json + +from config import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + +#################### +# Tools DB Schema +#################### + + +class Tool(Model): + id = CharField(unique=True) + user_id = CharField() + name = TextField() + content = TextField() + specs = JSONField() + meta = JSONField() + updated_at = BigIntegerField() + created_at = BigIntegerField() + + class Meta: + database = DB + + +class ToolMeta(BaseModel): + description: Optional[str] = None + + +class ToolModel(BaseModel): + id: str + user_id: str + name: str + content: str + specs: List[dict] + meta: ToolMeta + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + + +#################### +# Forms +#################### + 
+ +class ToolResponse(BaseModel): + id: str + user_id: str + name: str + meta: ToolMeta + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + + +class ToolForm(BaseModel): + id: str + name: str + content: str + meta: ToolMeta + + +class ToolsTable: + def __init__(self, db): + self.db = db + self.db.create_tables([Tool]) + + def insert_new_tool( + self, user_id: str, form_data: ToolForm, specs: List[dict] + ) -> Optional[ToolModel]: + tool = ToolModel( + **{ + **form_data.model_dump(), + "specs": specs, + "user_id": user_id, + "updated_at": int(time.time()), + "created_at": int(time.time()), + } + ) + + try: + result = Tool.create(**tool.model_dump()) + if result: + return tool + else: + return None + except Exception as e: + print(f"Error creating tool: {e}") + return None + + def get_tool_by_id(self, id: str) -> Optional[ToolModel]: + try: + tool = Tool.get(Tool.id == id) + return ToolModel(**model_to_dict(tool)) + except: + return None + + def get_tools(self) -> List[ToolModel]: + return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()] + + def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: + try: + query = Tool.update( + **updated, + updated_at=int(time.time()), + ).where(Tool.id == id) + query.execute() + + tool = Tool.get(Tool.id == id) + return ToolModel(**model_to_dict(tool)) + except: + return None + + def delete_tool_by_id(self, id: str) -> bool: + try: + query = Tool.delete().where((Tool.id == id)) + query.execute() # Remove the rows, return number of rows removed. 
+ + return True + except: + return False + + +Tools = ToolsTable(DB) diff --git a/backend/apps/webui/routers/chats.py b/backend/apps/webui/routers/chats.py index 49df3d284..9d1cceaa1 100644 --- a/backend/apps/webui/routers/chats.py +++ b/backend/apps/webui/routers/chats.py @@ -161,7 +161,7 @@ async def get_archived_session_user_chat_list( ############################ -@router.post("/archive/all", response_model=List[ChatTitleIdResponse]) +@router.post("/archive/all", response_model=bool) async def archive_all_chats(user=Depends(get_current_user)): return Chats.archive_all_chats_by_user_id(user.id) diff --git a/backend/apps/webui/routers/documents.py b/backend/apps/webui/routers/documents.py index c5447a3fe..311455390 100644 --- a/backend/apps/webui/routers/documents.py +++ b/backend/apps/webui/routers/documents.py @@ -73,7 +73,7 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)): ############################ -@router.get("/name/{name}", response_model=Optional[DocumentResponse]) +@router.get("/doc", response_model=Optional[DocumentResponse]) async def get_doc_by_name(name: str, user=Depends(get_current_user)): doc = Documents.get_doc_by_name(name) @@ -105,7 +105,7 @@ class TagDocumentForm(BaseModel): tags: List[dict] -@router.post("/name/{name}/tags", response_model=Optional[DocumentResponse]) +@router.post("/doc/tags", response_model=Optional[DocumentResponse]) async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)): doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags}) @@ -128,7 +128,7 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_u ############################ -@router.post("/name/{name}/update", response_model=Optional[DocumentResponse]) +@router.post("/doc/update", response_model=Optional[DocumentResponse]) async def update_doc_by_name( name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user) ): @@ -152,7 +152,7 @@ 
async def update_doc_by_name( ############################ -@router.delete("/name/{name}/delete", response_model=bool) +@router.delete("/doc/delete", response_model=bool) async def delete_doc_by_name(name: str, user=Depends(get_admin_user)): result = Documents.delete_doc_by_name(name) return result diff --git a/backend/apps/webui/routers/memories.py b/backend/apps/webui/routers/memories.py index 6448ebe1e..3832fe9a1 100644 --- a/backend/apps/webui/routers/memories.py +++ b/backend/apps/webui/routers/memories.py @@ -44,6 +44,10 @@ class AddMemoryForm(BaseModel): content: str +class MemoryUpdateModel(BaseModel): + content: Optional[str] = None + + @router.post("/add", response_model=Optional[MemoryModel]) async def add_memory( request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user) @@ -62,6 +66,34 @@ async def add_memory( return memory +@router.post("/{memory_id}/update", response_model=Optional[MemoryModel]) +async def update_memory_by_id( + memory_id: str, + request: Request, + form_data: MemoryUpdateModel, + user=Depends(get_verified_user), +): + memory = Memories.update_memory_by_id(memory_id, form_data.content) + if memory is None: + raise HTTPException(status_code=404, detail="Memory not found") + + if form_data.content is not None: + memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content) + collection = CHROMA_CLIENT.get_or_create_collection( + name=f"user-memory-{user.id}" + ) + collection.upsert( + documents=[form_data.content], + ids=[memory.id], + embeddings=[memory_embedding], + metadatas=[ + {"created_at": memory.created_at, "updated_at": memory.updated_at} + ], + ) + + return memory + + ############################ # QueryMemory ############################ diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py new file mode 100644 index 000000000..b68ed32ee --- /dev/null +++ b/backend/apps/webui/routers/tools.py @@ -0,0 +1,183 @@ +from fastapi import Depends, FastAPI, HTTPException, 
status, Request +from datetime import datetime, timedelta +from typing import List, Union, Optional + +from fastapi import APIRouter +from pydantic import BaseModel +import json + +from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse +from apps.webui.utils import load_toolkit_module_by_id + +from utils.utils import get_current_user, get_admin_user +from utils.tools import get_tools_specs +from constants import ERROR_MESSAGES + +from importlib import util +import os + +from config import DATA_DIR + + +TOOLS_DIR = f"{DATA_DIR}/tools" +os.makedirs(TOOLS_DIR, exist_ok=True) + + +router = APIRouter() + +############################ +# GetToolkits +############################ + + +@router.get("/", response_model=List[ToolResponse]) +async def get_toolkits(user=Depends(get_current_user)): + toolkits = [toolkit for toolkit in Tools.get_tools()] + return toolkits + + +############################ +# ExportToolKits +############################ + + +@router.get("/export", response_model=List[ToolModel]) +async def get_toolkits(user=Depends(get_admin_user)): + toolkits = [toolkit for toolkit in Tools.get_tools()] + return toolkits + + +############################ +# CreateNewToolKit +############################ + + +@router.post("/create", response_model=Optional[ToolResponse]) +async def create_new_toolkit( + request: Request, form_data: ToolForm, user=Depends(get_admin_user) +): + if not form_data.id.isidentifier(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Only alphanumeric characters and underscores are allowed in the id", + ) + + form_data.id = form_data.id.lower() + + toolkit = Tools.get_tool_by_id(form_data.id) + if toolkit == None: + toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py") + try: + with open(toolkit_path, "w") as tool_file: + tool_file.write(form_data.content) + + toolkit_module = load_toolkit_module_by_id(form_data.id) + + TOOLS = request.app.state.TOOLS + TOOLS[form_data.id] = 
toolkit_module + + specs = get_tools_specs(TOOLS[form_data.id]) + toolkit = Tools.insert_new_tool(user.id, form_data, specs) + + if toolkit: + return toolkit + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error creating toolkit"), + ) + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ID_TAKEN, + ) + + +############################ +# GetToolkitById +############################ + + +@router.get("/id/{id}", response_model=Optional[ToolModel]) +async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): + toolkit = Tools.get_tool_by_id(id) + + if toolkit: + return toolkit + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateToolkitById +############################ + + +@router.post("/id/{id}/update", response_model=Optional[ToolModel]) +async def update_toolkit_by_id( + request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user) +): + toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py") + + try: + with open(toolkit_path, "w") as tool_file: + tool_file.write(form_data.content) + + toolkit_module = load_toolkit_module_by_id(id) + + TOOLS = request.app.state.TOOLS + TOOLS[id] = toolkit_module + + specs = get_tools_specs(TOOLS[id]) + + updated = { + **form_data.model_dump(exclude={"id"}), + "specs": specs, + } + + print(updated) + toolkit = Tools.update_tool_by_id(id, updated) + + if toolkit: + return toolkit + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"), + ) + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + 
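Note on the id check in the create route above: it relies on `str.isidentifier()`, while the error message promises "only alphanumeric characters and underscores". The two rules differ slightly, since `isidentifier()` also accepts non-ASCII letters and rejects a leading digit. Because the id becomes a file name (`{id}.py`), a stricter ASCII-only check may be preferable. The sketch below is one way to express that; `is_safe_tool_id` and its pattern are hypothetical names, not part of this change.

```python
import re

# Assumption: tool ids should be ASCII letters, digits and underscores,
# and must not start with a digit (they are used as module/file names).
_TOOL_ID_PATTERN = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def is_safe_tool_id(tool_id: str) -> bool:
    """Return True when the whole id matches the ASCII identifier pattern."""
    return _TOOL_ID_PATTERN.fullmatch(tool_id) is not None


if __name__ == "__main__":
    for candidate in ["web_search", "1tool", "my-tool", "café"]:
        # Print both results side by side to show where the two rules differ.
        print(candidate, candidate.isidentifier(), is_safe_tool_id(candidate))
```

With this check the error message and the accepted ids would describe the same rule. Whether non-ASCII ids should be allowed is a design decision the diff does not state.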
+############################ +# DeleteToolkitById +############################ + + +@router.delete("/id/{id}/delete", response_model=bool) +async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)): + result = Tools.delete_tool_by_id(id) + + if result: + TOOLS = request.app.state.TOOLS + if id in TOOLS: + del TOOLS[id] + + # delete the toolkit file + toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py") + os.remove(toolkit_path) + + return result diff --git a/backend/apps/webui/routers/utils.py b/backend/apps/webui/routers/utils.py index 18491130a..8f6d663b4 100644 --- a/backend/apps/webui/routers/utils.py +++ b/backend/apps/webui/routers/utils.py @@ -7,6 +7,8 @@ from pydantic import BaseModel from fpdf import FPDF import markdown +import black + from apps.webui.internal.db import DB from utils.utils import get_admin_user @@ -26,6 +28,21 @@ async def get_gravatar( return get_gravatar_url(email) +class CodeFormatRequest(BaseModel): + code: str + + +@router.post("/code/format") +async def format_code(request: CodeFormatRequest): + try: + formatted_code = black.format_str(request.code, mode=black.Mode()) + return {"code": formatted_code} + except black.NothingChanged: + return {"code": request.code} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + class MarkdownForm(BaseModel): md: str diff --git a/backend/apps/webui/utils.py b/backend/apps/webui/utils.py new file mode 100644 index 000000000..19a8615bc --- /dev/null +++ b/backend/apps/webui/utils.py @@ -0,0 +1,23 @@ +from importlib import util +import os + +from config import TOOLS_DIR + + +def load_toolkit_module_by_id(toolkit_id): + toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py") + spec = util.spec_from_file_location(toolkit_id, toolkit_path) + module = util.module_from_spec(spec) + + try: + spec.loader.exec_module(module) + print(f"Loaded module: {module.__name__}") + if hasattr(module, "Tools"): + return module.Tools() + else: + raise 
Exception("No Tools class found") + except Exception as e: + print(f"Error loading module: {toolkit_id}") + # Move the file to the error folder + os.rename(toolkit_path, f"{toolkit_path}.error") + raise e diff --git a/backend/config.py b/backend/config.py index 8f4aebff2..21287ae72 100644 --- a/backend/config.py +++ b/backend/config.py @@ -435,7 +435,11 @@ STATIC_DIR = Path(os.getenv("STATIC_DIR", BACKEND_DIR / "static")).resolve() frontend_favicon = FRONTEND_BUILD_DIR / "favicon.png" if frontend_favicon.exists(): - shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png") + try: + shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png") + except Exception as e: + logging.error(f"An error occurred: {e}") + else: logging.warning(f"Frontend favicon not found at {frontend_favicon}") @@ -493,6 +497,14 @@ DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs") Path(DOCS_DIR).mkdir(parents=True, exist_ok=True) +#################################### +# Tools DIR +#################################### + +TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools") +Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True) + + #################################### # LITELLM_CONFIG #################################### @@ -542,6 +554,7 @@ OLLAMA_API_BASE_URL = os.environ.get( ) OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "") +AIOHTTP_CLIENT_TIMEOUT = int(os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "300")) K8S_FLAG = os.environ.get("K8S_FLAG", "") USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false") @@ -744,6 +757,78 @@ ADMIN_EMAIL = PersistentConfig( ) +#################################### +# TASKS +#################################### + + +TASK_MODEL = PersistentConfig( + "TASK_MODEL", + "task.model.default", + os.environ.get("TASK_MODEL", ""), +) + +TASK_MODEL_EXTERNAL = PersistentConfig( + "TASK_MODEL_EXTERNAL", + "task.model.external", + os.environ.get("TASK_MODEL_EXTERNAL", ""), +) + +TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig( + 
"TITLE_GENERATION_PROMPT_TEMPLATE", + "task.title.prompt_template", + os.environ.get( + "TITLE_GENERATION_PROMPT_TEMPLATE", + """Here is the query: +{{prompt:middletruncate:8000}} + +Create a concise, 3-5 word phrase with an emoji as a title for the previous query. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT. + +Examples of titles: +📉 Stock Market Trends +🍪 Perfect Chocolate Chip Recipe +Evolution of Music Streaming +Remote Work Productivity Tips +Artificial Intelligence in Healthcare +🎮 Video Game Development Insights""", + ), +) + + +SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig( + "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", + "task.search.prompt_template", + os.environ.get( + "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", + """You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is {{CURRENT_DATE}}. + +Question: +{{prompt:end:4000}}""", + ), +) + +SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig( + "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD", + "task.search.prompt_length_threshold", + int( + os.environ.get( + "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD", + 100, + ) + ), +) + +TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig( + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", + "task.tools.prompt_template", + os.environ.get( + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", + """Tools: {{TOOLS}} +If a function tool doesn't match the query, return an empty string. Else, pick a function tool, fill in the parameters from the function tool's schema, and return it in the format { "name": \"functionName\", "parameters": { "key": "value" } }. Only pick a function if the user asks. Only return the object. 
Do not return any other text.""", + ), +) + + #################################### # WEBUI_SECRET_KEY #################################### @@ -991,6 +1076,17 @@ SERPER_API_KEY = PersistentConfig( os.getenv("SERPER_API_KEY", ""), ) +SERPLY_API_KEY = PersistentConfig( + "SERPLY_API_KEY", + "rag.web.search.serply_api_key", + os.getenv("SERPLY_API_KEY", ""), +) + +TAVILY_API_KEY = PersistentConfig( + "TAVILY_API_KEY", + "rag.web.search.tavily_api_key", + os.getenv("TAVILY_API_KEY", ""), +) RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig( "RAG_WEB_SEARCH_RESULT_COUNT", @@ -1072,25 +1168,59 @@ IMAGE_GENERATION_MODEL = PersistentConfig( # Audio #################################### -AUDIO_OPENAI_API_BASE_URL = PersistentConfig( - "AUDIO_OPENAI_API_BASE_URL", - "audio.openai.api_base_url", - os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), +AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig( + "AUDIO_STT_OPENAI_API_BASE_URL", + "audio.stt.openai.api_base_url", + os.getenv("AUDIO_STT_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), ) -AUDIO_OPENAI_API_KEY = PersistentConfig( - "AUDIO_OPENAI_API_KEY", - "audio.openai.api_key", - os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY), + +AUDIO_STT_OPENAI_API_KEY = PersistentConfig( + "AUDIO_STT_OPENAI_API_KEY", + "audio.stt.openai.api_key", + os.getenv("AUDIO_STT_OPENAI_API_KEY", OPENAI_API_KEY), ) -AUDIO_OPENAI_API_MODEL = PersistentConfig( - "AUDIO_OPENAI_API_MODEL", - "audio.openai.api_model", - os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1"), + +AUDIO_STT_ENGINE = PersistentConfig( + "AUDIO_STT_ENGINE", + "audio.stt.engine", + os.getenv("AUDIO_STT_ENGINE", ""), ) -AUDIO_OPENAI_API_VOICE = PersistentConfig( - "AUDIO_OPENAI_API_VOICE", - "audio.openai.api_voice", - os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"), + +AUDIO_STT_MODEL = PersistentConfig( + "AUDIO_STT_MODEL", + "audio.stt.model", + os.getenv("AUDIO_STT_MODEL", "whisper-1"), +) + +AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig( + "AUDIO_TTS_OPENAI_API_BASE_URL", + 
"audio.tts.openai.api_base_url", + os.getenv("AUDIO_TTS_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), +) +AUDIO_TTS_OPENAI_API_KEY = PersistentConfig( + "AUDIO_TTS_OPENAI_API_KEY", + "audio.tts.openai.api_key", + os.getenv("AUDIO_TTS_OPENAI_API_KEY", OPENAI_API_KEY), +) + + +AUDIO_TTS_ENGINE = PersistentConfig( + "AUDIO_TTS_ENGINE", + "audio.tts.engine", + os.getenv("AUDIO_TTS_ENGINE", ""), +) + + +AUDIO_TTS_MODEL = PersistentConfig( + "AUDIO_TTS_MODEL", + "audio.tts.model", + os.getenv("AUDIO_TTS_MODEL", "tts-1"), +) + +AUDIO_TTS_VOICE = PersistentConfig( + "AUDIO_TTS_VOICE", + "audio.tts.voice", + os.getenv("AUDIO_TTS_VOICE", "alloy"), ) diff --git a/backend/constants.py b/backend/constants.py index 0740fa49d..f1eed43d3 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -32,6 +32,7 @@ class ERROR_MESSAGES(str, Enum): COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string." FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file." + ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string." MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string." NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string." 
diff --git a/backend/main.py b/backend/main.py index bf36f559b..5076e91c5 100644 --- a/backend/main.py +++ b/backend/main.py @@ -13,8 +13,12 @@ import logging import aiohttp import requests import mimetypes +import shutil +import os +import inspect +import asyncio -from fastapi import FastAPI, Request, Depends, status +from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form from fastapi.staticfiles import StaticFiles from fastapi.responses import JSONResponse from fastapi import HTTPException @@ -27,21 +31,33 @@ from starlette.responses import StreamingResponse, Response, RedirectResponse from apps.socket.main import app as socket_app -from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models -from apps.openai.main import app as openai_app, get_all_models as get_openai_models +from apps.ollama.main import ( + app as ollama_app, + OpenAIChatCompletionForm, + get_all_models as get_ollama_models, + generate_openai_chat_completion as generate_ollama_chat_completion, +) +from apps.openai.main import ( + app as openai_app, + get_all_models as get_openai_models, + generate_chat_completion as generate_openai_chat_completion, +) from apps.audio.main import app as audio_app from apps.images.main import app as images_app from apps.rag.main import app as rag_app from apps.webui.main import app as webui_app -import asyncio + from pydantic import BaseModel from typing import List, Optional from apps.webui.models.auths import Auths -from apps.webui.models.models import Models +from apps.webui.models.models import Models, ModelModel +from apps.webui.models.tools import Tools from apps.webui.models.users import Users +from apps.webui.utils import load_toolkit_module_by_id + from utils.misc import parse_duration from utils.utils import ( get_admin_user, @@ -51,7 +67,14 @@ from utils.utils import ( get_password_hash, create_token, ) -from apps.rag.utils import rag_messages +from utils.task import ( + title_generation_template, + 
search_query_generation_template, + tools_function_calling_generation_template, +) +from utils.misc import get_last_user_message, add_or_update_system_message + +from apps.rag.utils import get_rag_context, rag_template from config import ( CONFIG_DATA, @@ -72,14 +95,20 @@ from config import ( SRC_LOG_LEVELS, WEBHOOK_URL, ENABLE_ADMIN_EXPORT, - AppConfig, WEBUI_BUILD_HASH, + TASK_MODEL, + TASK_MODEL_EXTERNAL, + TITLE_GENERATION_PROMPT_TEMPLATE, + SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, + SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, OAUTH_PROVIDERS, ENABLE_OAUTH_SIGNUP, OAUTH_MERGE_ACCOUNTS_BY_EMAIL, WEBUI_SECRET_KEY, WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE, + AppConfig, ) from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from utils.webhook import post_webhook @@ -134,27 +163,133 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST - app.state.config.WEBHOOK_URL = WEBHOOK_URL +app.state.config.TASK_MODEL = TASK_MODEL +app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL +app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE +app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( + SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE +) +app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = ( + SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD +) +app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE +) + app.state.MODELS = {} origins = ["*"] -# Custom middleware to add security headers -# class SecurityHeadersMiddleware(BaseHTTPMiddleware): -# async def dispatch(self, request: Request, call_next): -# response: Response = await call_next(request) -# response.headers["Cross-Origin-Opener-Policy"] = "same-origin" -# response.headers["Cross-Origin-Embedder-Policy"] = "require-corp" -# return response + +async def 
get_function_call_response(messages, tool_id, template, task_model_id, user): + tool = Tools.get_tool_by_id(tool_id) + tools_specs = json.dumps(tool.specs, indent=2) + content = tools_function_calling_generation_template(template, tools_specs) + + user_message = get_last_user_message(messages) + prompt = ( + "History:\n" + + "\n".join( + [ + f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" + for message in messages[::-1][:4] + ] + ) + + f"\nQuery: {user_message}" + ) + + print(prompt) + + payload = { + "model": task_model_id, + "messages": [ + {"role": "system", "content": content}, + {"role": "user", "content": f"Query: {prompt}"}, + ], + "stream": False, + } + + try: + payload = filter_pipeline(payload, user) + except Exception as e: + raise e + + model = app.state.MODELS[task_model_id] + + response = None + try: + if model["owned_by"] == "ollama": + response = await generate_ollama_chat_completion( + OpenAIChatCompletionForm(**payload), user=user + ) + else: + response = await generate_openai_chat_completion(payload, user=user) + + content = None + + if hasattr(response, "body_iterator"): + async for chunk in response.body_iterator: + data = json.loads(chunk.decode("utf-8")) + content = data["choices"][0]["message"]["content"] + + # Cleanup any remaining background tasks if necessary + if response.background is not None: + await response.background() + else: + content = response["choices"][0]["message"]["content"] + + # Parse the function response + if content is not None: + print(f"content: {content}") + result = json.loads(content) + print(result) + + # Call the function + if "name" in result: + if tool_id in webui_app.state.TOOLS: + toolkit_module = webui_app.state.TOOLS[tool_id] + else: + toolkit_module = load_toolkit_module_by_id(tool_id) + webui_app.state.TOOLS[tool_id] = toolkit_module + + function = getattr(toolkit_module, result["name"]) + function_result = None + try: + # Get the signature of the function + sig = 
inspect.signature(function) + # Check if '__user__' is a parameter of the function + if "__user__" in sig.parameters: + # Call the function with the '__user__' parameter included + function_result = function( + **{ + **result["parameters"], + "__user__": { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + } + ) + else: + # Call the function without modifying the parameters + function_result = function(**result["parameters"]) + except Exception as e: + print(e) + + # Add the function result to the system prompt + if function_result: + return function_result + except Exception as e: + print(f"Error: {e}") + + return None -# app.add_middleware(SecurityHeadersMiddleware) - - -class RAGMiddleware(BaseHTTPMiddleware): +class ChatCompletionMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): return_citations = False @@ -171,35 +306,98 @@ class RAGMiddleware(BaseHTTPMiddleware): # Parse string to JSON data = json.loads(body_str) if body_str else {} + user = get_current_user( + get_http_authorization_cred(request.headers.get("Authorization")) + ) + + # Remove the citations from the body return_citations = data.get("citations", False) if "citations" in data: del data["citations"] - # Example: Add a new key-value pair or modify existing ones - # data["modified"] = True # Example modification + # Set the task model + task_model_id = data["model"] + if task_model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + if app.state.MODELS[task_model_id]["owned_by"] == "ollama": + if ( + app.state.config.TASK_MODEL + and app.state.config.TASK_MODEL in app.state.MODELS + ): + task_model_id = app.state.config.TASK_MODEL + else: + if ( + app.state.config.TASK_MODEL_EXTERNAL + and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS + ): 
+ task_model_id = app.state.config.TASK_MODEL_EXTERNAL + + prompt = get_last_user_message(data["messages"]) + context = "" + + # If tool_ids field is present, call the functions + if "tool_ids" in data: + print(data["tool_ids"]) + for tool_id in data["tool_ids"]: + print(tool_id) + try: + response = await get_function_call_response( + messages=data["messages"], + tool_id=tool_id, + template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + task_model_id=task_model_id, + user=user, + ) + + if response: + context += ("\n" if context != "" else "") + response + except Exception as e: + print(f"Error: {e}") + del data["tool_ids"] + + print(f"tool_context: {context}") + + # If docs field is present, generate RAG completions if "docs" in data: data = {**data} - data["messages"], citations = rag_messages( + rag_context, citations = get_rag_context( docs=data["docs"], messages=data["messages"], - template=rag_app.state.config.RAG_TEMPLATE, embedding_function=rag_app.state.EMBEDDING_FUNCTION, k=rag_app.state.config.TOP_K, reranking_function=rag_app.state.sentence_transformer_rf, r=rag_app.state.config.RELEVANCE_THRESHOLD, hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH, ) + + if rag_context: + context += ("\n" if context != "" else "") + rag_context + del data["docs"] - log.debug( - f"data['messages']: {data['messages']}, citations: {citations}" + log.debug(f"rag_context: {rag_context}, citations: {citations}") + + if context != "": + system_prompt = rag_template( + rag_app.state.config.RAG_TEMPLATE, context, prompt + ) + + print(system_prompt) + + data["messages"] = add_or_update_system_message( + f"\n{system_prompt}", data["messages"] ) modified_body_bytes = json.dumps(data).encode("utf-8") # Replace the request body with the modified one request._body = modified_body_bytes - # Set custom header to ensure content-length matches new body length request.headers.__dict__["_list"] = [ (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), @@ 
-242,7 +440,80 @@ class RAGMiddleware(BaseHTTPMiddleware): yield data -app.add_middleware(RAGMiddleware) +app.add_middleware(ChatCompletionMiddleware) + + +def filter_pipeline(payload, user): + user = {"id": user.id, "name": user.name, "role": user.role} + model_id = payload["model"] + filters = [ + model + for model in app.state.MODELS.values() + if "pipeline" in model + and "type" in model["pipeline"] + and model["pipeline"]["type"] == "filter" + and ( + model["pipeline"]["pipelines"] == ["*"] + or any( + model_id == target_model_id + for target_model_id in model["pipeline"]["pipelines"] + ) + ) + ] + sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) + + model = app.state.MODELS[model_id] + + if "pipeline" in model: + sorted_filters.append(model) + + for filter in sorted_filters: + r = None + try: + urlIdx = filter["urlIdx"] + + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + if key != "": + headers = {"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/{filter['id']}/filter/inlet", + headers=headers, + json={ + "user": user, + "body": payload, + }, + ) + + r.raise_for_status() + payload = r.json() + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + if r is not None: + try: + res = r.json() + except: + pass + if "detail" in res: + raise Exception(r.status_code, res["detail"]) + + else: + pass + + if "pipeline" not in app.state.MODELS[model_id]: + if "chat_id" in payload: + del payload["chat_id"] + + if "title" in payload: + del payload["title"] + + if "task" in payload: + del payload["task"] + + return payload class PipelineMiddleware(BaseHTTPMiddleware): @@ -260,85 +531,17 @@ class PipelineMiddleware(BaseHTTPMiddleware): # Parse string to JSON data = json.loads(body_str) if body_str else {} - model_id = data["model"] - filters = [ - model - for model in app.state.MODELS.values() - if "pipeline" in model - and 
"type" in model["pipeline"] - and model["pipeline"]["type"] == "filter" - and ( - model["pipeline"]["pipelines"] == ["*"] - or any( - model_id == target_model_id - for target_model_id in model["pipeline"]["pipelines"] - ) + user = get_current_user( + get_http_authorization_cred(request.headers.get("Authorization")) + ) + + try: + data = filter_pipeline(data, user) + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, ) - ] - sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) - - user = None - if len(sorted_filters) > 0: - try: - user = get_current_user( - get_http_authorization_cred( - request.headers.get("Authorization") - ) - ) - user = {"id": user.id, "name": user.name, "role": user.role} - except: - pass - - model = app.state.MODELS[model_id] - - if "pipeline" in model: - sorted_filters.append(model) - - for filter in sorted_filters: - r = None - try: - urlIdx = filter["urlIdx"] - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - if key != "": - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/inlet", - headers=headers, - json={ - "user": user, - "body": data, - }, - ) - - r.raise_for_status() - data = r.json() - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - if r is not None: - try: - res = r.json() - if "detail" in res: - return JSONResponse( - status_code=r.status_code, - content=res, - ) - except: - pass - - else: - pass - - if "pipeline" not in app.state.MODELS[model_id]: - if "chat_id" in data: - del data["chat_id"] - - if "title" in data: - del data["title"] modified_body_bytes = json.dumps(data).encode("utf-8") # Replace the request body with the modified one @@ -499,6 +702,302 @@ async def get_models(user=Depends(get_verified_user)): return {"data": models} +@app.get("/api/task/config") +async def 
get_task_config(user=Depends(get_verified_user)): + return { + "TASK_MODEL": app.state.config.TASK_MODEL, + "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL, + "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, + "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, + "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + } + + +class TaskConfigForm(BaseModel): + TASK_MODEL: Optional[str] + TASK_MODEL_EXTERNAL: Optional[str] + TITLE_GENERATION_PROMPT_TEMPLATE: str + SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str + SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str + + +@app.post("/api/task/config/update") +async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)): + app.state.config.TASK_MODEL = form_data.TASK_MODEL + app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL + app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = ( + form_data.TITLE_GENERATION_PROMPT_TEMPLATE + ) + app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( + form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE + ) + app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = ( + form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD + ) + app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( + form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + ) + + return { + "TASK_MODEL": app.state.config.TASK_MODEL, + "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL, + "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, + "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, + "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, + 
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + } + + +@app.post("/api/task/title/completions") +async def generate_title(form_data: dict, user=Depends(get_verified_user)): + print("generate_title") + + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + if app.state.MODELS[model_id]["owned_by"] == "ollama": + if app.state.config.TASK_MODEL: + task_model_id = app.state.config.TASK_MODEL + if task_model_id in app.state.MODELS: + model_id = task_model_id + else: + if app.state.config.TASK_MODEL_EXTERNAL: + task_model_id = app.state.config.TASK_MODEL_EXTERNAL + if task_model_id in app.state.MODELS: + model_id = task_model_id + + print(model_id) + model = app.state.MODELS[model_id] + + template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE + + content = title_generation_template( + template, form_data["prompt"], user.model_dump() + ) + + payload = { + "model": model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "max_tokens": 50, + "chat_id": form_data.get("chat_id", None), + "title": True, + } + + print(payload) + + try: + payload = filter_pipeline(payload, user) + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + + if model["owned_by"] == "ollama": + return await generate_ollama_chat_completion( + OpenAIChatCompletionForm(**payload), user=user + ) + else: + return await generate_openai_chat_completion(payload, user=user) + + +@app.post("/api/task/query/completions") +async def generate_search_query(form_data: dict, user=Depends(get_verified_user)): + print("generate_search_query") + + if len(form_data["prompt"]) < app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: + raise HTTPException( + 
status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Skip search query generation for short prompts (< {app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD} characters)", + ) + + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + if app.state.MODELS[model_id]["owned_by"] == "ollama": + if app.state.config.TASK_MODEL: + task_model_id = app.state.config.TASK_MODEL + if task_model_id in app.state.MODELS: + model_id = task_model_id + else: + if app.state.config.TASK_MODEL_EXTERNAL: + task_model_id = app.state.config.TASK_MODEL_EXTERNAL + if task_model_id in app.state.MODELS: + model_id = task_model_id + + print(model_id) + model = app.state.MODELS[model_id] + + template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE + + content = search_query_generation_template( + template, form_data["prompt"], user.model_dump() + ) + + payload = { + "model": model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "max_tokens": 30, + "task": True, + } + + print(payload) + + try: + payload = filter_pipeline(payload, user) + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + + if model["owned_by"] == "ollama": + return await generate_ollama_chat_completion( + OpenAIChatCompletionForm(**payload), user=user + ) + else: + return await generate_openai_chat_completion(payload, user=user) + + +@app.post("/api/task/emoji/completions") +async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): + print("generate_emoji") + + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom 
task model, use that model + if app.state.MODELS[model_id]["owned_by"] == "ollama": + if app.state.config.TASK_MODEL: + task_model_id = app.state.config.TASK_MODEL + if task_model_id in app.state.MODELS: + model_id = task_model_id + else: + if app.state.config.TASK_MODEL_EXTERNAL: + task_model_id = app.state.config.TASK_MODEL_EXTERNAL + if task_model_id in app.state.MODELS: + model_id = task_model_id + + print(model_id) + model = app.state.MODELS[model_id] + + template = ''' +Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). + +Message: """{{prompt}}""" +''' + + content = title_generation_template( + template, form_data["prompt"], user.model_dump() + ) + + payload = { + "model": model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "max_tokens": 4, + "chat_id": form_data.get("chat_id", None), + "task": True, + } + + print(payload) + + try: + payload = filter_pipeline(payload, user) + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + + if model["owned_by"] == "ollama": + return await generate_ollama_chat_completion( + OpenAIChatCompletionForm(**payload), user=user + ) + else: + return await generate_openai_chat_completion(payload, user=user) + + +@app.post("/api/task/tools/completions") +async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)): + print("get_tools_function_calling") + + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + if app.state.MODELS[model_id]["owned_by"] == "ollama": + if app.state.config.TASK_MODEL: + task_model_id = app.state.config.TASK_MODEL + if 
task_model_id in app.state.MODELS: + model_id = task_model_id + else: + if app.state.config.TASK_MODEL_EXTERNAL: + task_model_id = app.state.config.TASK_MODEL_EXTERNAL + if task_model_id in app.state.MODELS: + model_id = task_model_id + + print(model_id) + template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + + try: + context = await get_function_call_response( + form_data["messages"], form_data["tool_id"], template, model_id, user + ) + return context + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + + +@app.post("/api/chat/completions") +async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + model = app.state.MODELS[model_id] + print(model) + + if model["owned_by"] == "ollama": + return await generate_ollama_chat_completion( + OpenAIChatCompletionForm(**form_data), user=user + ) + else: + return await generate_openai_chat_completion(form_data, user=user) + + @app.post("/api/chat/completed") async def chat_completed(form_data: dict, user=Depends(get_verified_user)): data = form_data @@ -591,6 +1090,63 @@ async def get_pipelines_list(user=Depends(get_admin_user)): } +@app.post("/api/pipelines/upload") +async def upload_pipeline( + urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user) +): + print("upload_pipeline", urlIdx, file.filename) + # Check if the uploaded file is a python file + if not file.filename.endswith(".py"): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Only Python (.py) files are allowed.", + ) + + upload_folder = f"{CACHE_DIR}/pipelines" + os.makedirs(upload_folder, exist_ok=True) + file_path = os.path.join(upload_folder, file.filename) + r = None # Response placeholder: keeps the except block valid if the request is never sent + try: + # Save the uploaded file + with open(file_path, "wb") as buffer: +
shutil.copyfileobj(file.file, buffer) + + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + headers = {"Authorization": f"Bearer {key}"} + + with open(file_path, "rb") as f: + files = {"file": f} + r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files) + + r.raise_for_status() + data = r.json() + + return {**data} + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + detail = "Pipeline not found" + if r is not None: + try: + res = r.json() + if "detail" in res: + detail = res["detail"] + except: + pass + + raise HTTPException( + status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + detail=detail, + ) + finally: + # Ensure the file is deleted after the upload is completed or on failure + if os.path.exists(file_path): + os.remove(file_path) + + class AddPipelineForm(BaseModel): url: str urlIdx: int @@ -857,6 +1413,15 @@ async def get_app_config(): "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING, "enable_admin_export": ENABLE_ADMIN_EXPORT, }, + "audio": { + "tts": { + "engine": audio_app.state.config.TTS_ENGINE, + "voice": audio_app.state.config.TTS_VOICE, + }, + "stt": { + "engine": audio_app.state.config.STT_ENGINE, + }, + }, "oauth": { "providers": { name: config.get("name", name) @@ -925,7 +1490,7 @@ async def get_app_changelog(): @app.get("/api/version/updates") async def get_app_latest_release_version(): try: - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(trust_env=True) as session: async with session.get( "https://api.github.com/repos/open-webui/open-webui/releases/latest" ) as response: diff --git a/backend/requirements.txt b/backend/requirements.txt index 8687c7e43..53e21826c 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -57,4 +57,8 @@ authlib==1.3.0 black==24.4.2 langfuse==2.33.0 youtube-transcript-api==0.6.2 
-pytube==15.0.0 \ No newline at end of file +pytube==15.0.0 + +extract_msg +pydub +duckduckgo-search~=6.1.5 \ No newline at end of file diff --git a/backend/start.sh b/backend/start.sh index 15fc568d3..16a004e45 100755 --- a/backend/start.sh +++ b/backend/start.sh @@ -20,12 +20,12 @@ if test "$WEBUI_SECRET_KEY $WEBUI_JWT_SECRET_KEY" = " "; then WEBUI_SECRET_KEY=$(cat "$KEY_FILE") fi -if [ "$USE_OLLAMA_DOCKER" = "true" ]; then +if [[ "${USE_OLLAMA_DOCKER,,}" == "true" ]]; then echo "USE_OLLAMA is set to true, starting ollama serve." ollama serve & fi -if [ "$USE_CUDA_DOCKER" = "true" ]; then +if [[ "${USE_CUDA_DOCKER,,}" == "true" ]]; then echo "CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries." export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.11/site-packages/torch/lib:/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib" fi diff --git a/backend/start_windows.bat b/backend/start_windows.bat index d56c91916..b2498f9c2 100644 --- a/backend/start_windows.bat +++ b/backend/start_windows.bat @@ -8,6 +8,7 @@ cd /d "%SCRIPT_DIR%" || exit /b SET "KEY_FILE=.webui_secret_key" IF "%PORT%"=="" SET PORT=8080 +IF "%HOST%"=="" SET HOST=0.0.0.0 SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%" SET "WEBUI_JWT_SECRET_KEY=%WEBUI_JWT_SECRET_KEY%" @@ -29,4 +30,4 @@ IF "%WEBUI_SECRET_KEY%%WEBUI_JWT_SECRET_KEY%" == " " ( :: Execute uvicorn SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%" -uvicorn main:app --host 0.0.0.0 --port "%PORT%" --forwarded-allow-ips '*' +uvicorn main:app --host "%HOST%" --port "%PORT%" --forwarded-allow-ips '*' diff --git a/backend/utils/misc.py b/backend/utils/misc.py index 9069857b7..c3c65d3f5 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -3,7 +3,48 @@ import hashlib import json import re from datetime import timedelta -from typing import Optional +from typing import Optional, List + + +def get_last_user_message(messages: List[dict]) -> Optional[str]: + for message in reversed(messages): + if message["role"] == 
"user": + if isinstance(message["content"], list): + for item in message["content"]: + if item["type"] == "text": + return item["text"] + return message["content"] + return None + + +def get_last_assistant_message(messages: List[dict]) -> Optional[str]: + for message in reversed(messages): + if message["role"] == "assistant": + if isinstance(message["content"], list): + for item in message["content"]: + if item["type"] == "text": + return item["text"] + return message["content"] + return None + + +def add_or_update_system_message(content: str, messages: List[dict]): + """ + Adds a new system message at the beginning of the messages list + or prepends the content to the existing system message at the beginning. + + :param content: The content to be added or prepended. + :param messages: The list of message dictionaries. + :return: The updated list of message dictionaries. + """ + + if messages and messages[0].get("role") == "system": + messages[0]["content"] = f"{content}\n{messages[0]['content']}" + else: + # Insert at the beginning + messages.insert(0, {"role": "system", "content": content}) + + return messages def get_gravatar_url(email): @@ -193,8 +234,14 @@ def parse_ollama_modelfile(model_text): system_desc_match = re.search( r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE ) + system_desc_match_single = re.search( + r"SYSTEM\s+([^\n]+)", model_text, re.IGNORECASE + ) + if system_desc_match: data["params"]["system"] = system_desc_match.group(1).strip() + elif system_desc_match_single: + data["params"]["system"] = system_desc_match_single.group(1).strip() # Parse messages messages = [] diff --git a/backend/utils/models.py b/backend/utils/models.py deleted file mode 100644 index c4d675d29..000000000 --- a/backend/utils/models.py +++ /dev/null @@ -1,10 +0,0 @@ -from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse - - -def get_model_id_from_custom_model_id(id: str): - model = Models.get_model_by_id(id) - - if model: - return model.id - else: - return 
id diff --git a/backend/utils/task.py b/backend/utils/task.py new file mode 100644 index 000000000..615febcdc --- /dev/null +++ b/backend/utils/task.py @@ -0,0 +1,117 @@ +import re +import math + +from datetime import datetime +from typing import Optional + + +def prompt_template( + template: str, user_name: str = None, current_location: str = None +) -> str: + # Get the current date + current_date = datetime.now() + + # Format the date to YYYY-MM-DD + formatted_date = current_date.strftime("%Y-%m-%d") + + # Replace {{CURRENT_DATE}} in the template with the formatted date + template = template.replace("{{CURRENT_DATE}}", formatted_date) + + if user_name: + # Replace {{USER_NAME}} in the template with the user's name + template = template.replace("{{USER_NAME}}", user_name) + + if current_location: + # Replace {{CURRENT_LOCATION}} in the template with the current location + template = template.replace("{{CURRENT_LOCATION}}", current_location) + + return template + + +def title_generation_template( + template: str, prompt: str, user: Optional[dict] = None +) -> str: + def replacement_function(match): + full_match = match.group(0) + start_length = match.group(1) + end_length = match.group(2) + middle_length = match.group(3) + + if full_match == "{{prompt}}": + return prompt + elif start_length is not None: + return prompt[: int(start_length)] + elif end_length is not None: + return prompt[-int(end_length) :] + elif middle_length is not None: + middle_length = int(middle_length) + if len(prompt) <= middle_length: + return prompt + start = prompt[: math.ceil(middle_length / 2)] + end = prompt[-math.floor(middle_length / 2) :] + return f"{start}...{end}" + return "" + + template = re.sub( + r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}", + replacement_function, + template, + ) + + template = prompt_template( + template, + **( + {"user_name": user.get("name"), "current_location": user.get("location")} + if user + else {} + ), + ) 
+ + return template + + +def search_query_generation_template( + template: str, prompt: str, user: Optional[dict] = None +) -> str: + + def replacement_function(match): + full_match = match.group(0) + start_length = match.group(1) + end_length = match.group(2) + middle_length = match.group(3) + + if full_match == "{{prompt}}": + return prompt + elif start_length is not None: + return prompt[: int(start_length)] + elif end_length is not None: + return prompt[-int(end_length) :] + elif middle_length is not None: + middle_length = int(middle_length) + if len(prompt) <= middle_length: + return prompt + start = prompt[: math.ceil(middle_length / 2)] + end = prompt[-math.floor(middle_length / 2) :] + return f"{start}...{end}" + return "" + + template = re.sub( + r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}", + replacement_function, + template, + ) + + template = prompt_template( + template, + **( + {"user_name": user.get("name"), "current_location": user.get("location")} + if user + else {} + ), + ) + return template + + +def tools_function_calling_generation_template(template: str, tools_specs: str) -> str: + template = template.replace("{{TOOLS}}", tools_specs) + return template diff --git a/backend/utils/tools.py b/backend/utils/tools.py new file mode 100644 index 000000000..5fef2a2b6 --- /dev/null +++ b/backend/utils/tools.py @@ -0,0 +1,73 @@ +import inspect +from typing import get_type_hints, List, Dict, Any + + +def doc_to_dict(docstring): + lines = docstring.split("\n") + description = lines[1].strip() + param_dict = {} + + for line in lines: + if ":param" in line: + line = line.replace(":param", "").strip() + param, desc = line.split(":", 1) + param_dict[param.strip()] = desc.strip() + ret_dict = {"description": description, "params": param_dict} + return ret_dict + + +def get_tools_specs(tools) -> List[dict]: + function_list = [ + {"name": func, "function": getattr(tools, func)} + for func in dir(tools) + if 
callable(getattr(tools, func)) and not func.startswith("__") + ] + + specs = [] + for function_item in function_list: + function_name = function_item["name"] + function = function_item["function"] + + function_doc = doc_to_dict(function.__doc__ or function_name) + specs.append( + { + "name": function_name, + # TODO: multi-line desc? + "description": function_doc.get("description", function_name), + "parameters": { + "type": "object", + "properties": { + param_name: { + "type": param_annotation.__name__.lower(), + **( + { + "enum": ( + str(param_annotation.__args__) + if hasattr(param_annotation, "__args__") + else None + ) + } + if hasattr(param_annotation, "__args__") + else {} + ), + "description": function_doc.get("params", {}).get( + param_name, param_name + ), + } + for param_name, param_annotation in get_type_hints( + function + ).items() + if param_name != "return" and param_name != "__user__" + }, + "required": [ + name + for name, param in inspect.signature( + function + ).parameters.items() + if param.default is param.empty + ], + }, + } + ) + + return specs diff --git a/cypress/e2e/settings.cy.ts b/cypress/e2e/settings.cy.ts index 5db232faa..4ea916980 100644 --- a/cypress/e2e/settings.cy.ts +++ b/cypress/e2e/settings.cy.ts @@ -28,19 +28,6 @@ describe('Settings', () => { }); }); - context('Connections', () => { - it('user can open the Connections modal and hit save', () => { - cy.get('button').contains('Connections').click(); - cy.get('button').contains('Save').click(); - }); - }); - - context('Models', () => { - it('user can open the Models modal', () => { - cy.get('button').contains('Models').click(); - }); - }); - context('Interface', () => { it('user can open the Interface modal and hit save', () => { cy.get('button').contains('Interface').click(); @@ -55,14 +42,6 @@ describe('Settings', () => { }); }); - context('Images', () => { - it('user can open the Images modal and hit save', () => { - cy.get('button').contains('Images').click(); - // Currently 
fails because the backend requires a valid URL - // cy.get('button').contains('Save').click(); - }); - }); - context('Chats', () => { it('user can open the Chats modal', () => { cy.get('button').contains('Chats').click(); diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index 92238d307..325964b1a 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -41,7 +41,7 @@ Looking to contribute? Great! Here's how you can help: We welcome pull requests. Before submitting one, please: -1. Discuss your idea or issue in the [issues section](https://github.com/open-webui/open-webui/issues). +1. Open a discussion regarding your ideas [here](https://github.com/open-webui/open-webui/discussions/new/choose). 2. Follow the project's coding standards and include tests for new features. 3. Update documentation as necessary. 4. Write clear, descriptive commit messages. diff --git a/package-lock.json b/package-lock.json index 7d3b385e2..f5b9d6a78 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,17 +1,21 @@ { "name": "open-webui", - "version": "0.2.5", + "version": "0.3.4", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.2.5", + "version": "0.3.4", "dependencies": { + "@codemirror/lang-javascript": "^6.2.2", + "@codemirror/lang-python": "^6.1.6", + "@codemirror/theme-one-dark": "^6.1.2", "@pyscript/core": "^0.4.32", "@sveltejs/adapter-node": "^1.3.1", "async": "^3.2.5", "bits-ui": "^0.19.7", + "codemirror": "^6.0.1", "dayjs": "^1.11.10", "eventsource-parser": "^1.1.2", "file-saver": "^2.0.5", @@ -108,6 +112,119 @@ "resolved": "https://registry.npmjs.org/@braintree/sanitize-url/-/sanitize-url-6.0.4.tgz", "integrity": "sha512-s3jaWicZd0pkP0jf5ysyHUI/RE7MHos6qlToFcGWXVp+ykHOy77OUMrfbgJ9it2C5bow7OIQwYYaHjk9XlBQ2A==" }, + "node_modules/@codemirror/autocomplete": { + "version": "6.16.2", + "resolved": "https://registry.npmjs.org/@codemirror/autocomplete/-/autocomplete-6.16.2.tgz", + "integrity": 
"sha512-MjfDrHy0gHKlPWsvSsikhO1+BOh+eBHNgfH1OXs1+DAf30IonQldgMM3kxLDTG9ktE7kDLaA1j/l7KMPA4KNfw==", + "dependencies": { + "@codemirror/language": "^6.0.0", + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.17.0", + "@lezer/common": "^1.0.0" + }, + "peerDependencies": { + "@codemirror/language": "^6.0.0", + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.0.0", + "@lezer/common": "^1.0.0" + } + }, + "node_modules/@codemirror/commands": { + "version": "6.6.0", + "resolved": "https://registry.npmjs.org/@codemirror/commands/-/commands-6.6.0.tgz", + "integrity": "sha512-qnY+b7j1UNcTS31Eenuc/5YJB6gQOzkUoNmJQc0rznwqSRpeaWWpjkWy2C/MPTcePpsKJEM26hXrOXl1+nceXg==", + "dependencies": { + "@codemirror/language": "^6.0.0", + "@codemirror/state": "^6.4.0", + "@codemirror/view": "^6.27.0", + "@lezer/common": "^1.1.0" + } + }, + "node_modules/@codemirror/lang-javascript": { + "version": "6.2.2", + "resolved": "https://registry.npmjs.org/@codemirror/lang-javascript/-/lang-javascript-6.2.2.tgz", + "integrity": "sha512-VGQfY+FCc285AhWuwjYxQyUQcYurWlxdKYT4bqwr3Twnd5wP5WSeu52t4tvvuWmljT4EmgEgZCqSieokhtY8hg==", + "dependencies": { + "@codemirror/autocomplete": "^6.0.0", + "@codemirror/language": "^6.6.0", + "@codemirror/lint": "^6.0.0", + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.17.0", + "@lezer/common": "^1.0.0", + "@lezer/javascript": "^1.0.0" + } + }, + "node_modules/@codemirror/lang-python": { + "version": "6.1.6", + "resolved": "https://registry.npmjs.org/@codemirror/lang-python/-/lang-python-6.1.6.tgz", + "integrity": "sha512-ai+01WfZhWqM92UqjnvorkxosZ2aq2u28kHvr+N3gu012XqY2CThD67JPMHnGceRfXPDBmn1HnyqowdpF57bNg==", + "dependencies": { + "@codemirror/autocomplete": "^6.3.2", + "@codemirror/language": "^6.8.0", + "@codemirror/state": "^6.0.0", + "@lezer/common": "^1.2.1", + "@lezer/python": "^1.1.4" + } + }, + "node_modules/@codemirror/language": { + "version": "6.10.2", + "resolved": 
"https://registry.npmjs.org/@codemirror/language/-/language-6.10.2.tgz", + "integrity": "sha512-kgbTYTo0Au6dCSc/TFy7fK3fpJmgHDv1sG1KNQKJXVi+xBTEeBPY/M30YXiU6mMXeH+YIDLsbrT4ZwNRdtF+SA==", + "dependencies": { + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.23.0", + "@lezer/common": "^1.1.0", + "@lezer/highlight": "^1.0.0", + "@lezer/lr": "^1.0.0", + "style-mod": "^4.0.0" + } + }, + "node_modules/@codemirror/lint": { + "version": "6.8.0", + "resolved": "https://registry.npmjs.org/@codemirror/lint/-/lint-6.8.0.tgz", + "integrity": "sha512-lsFofvaw0lnPRJlQylNsC4IRt/1lI4OD/yYslrSGVndOJfStc58v+8p9dgGiD90ktOfL7OhBWns1ZETYgz0EJA==", + "dependencies": { + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.0.0", + "crelt": "^1.0.5" + } + }, + "node_modules/@codemirror/search": { + "version": "6.5.6", + "resolved": "https://registry.npmjs.org/@codemirror/search/-/search-6.5.6.tgz", + "integrity": "sha512-rpMgcsh7o0GuCDUXKPvww+muLA1pDJaFrpq/CCHtpQJYz8xopu4D1hPcKRoDD0YlF8gZaqTNIRa4VRBWyhyy7Q==", + "dependencies": { + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.0.0", + "crelt": "^1.0.5" + } + }, + "node_modules/@codemirror/state": { + "version": "6.4.1", + "resolved": "https://registry.npmjs.org/@codemirror/state/-/state-6.4.1.tgz", + "integrity": "sha512-QkEyUiLhsJoZkbumGZlswmAhA7CBU02Wrz7zvH4SrcifbsqwlXShVXg65f3v/ts57W3dqyamEriMhij1Z3Zz4A==" + }, + "node_modules/@codemirror/theme-one-dark": { + "version": "6.1.2", + "resolved": "https://registry.npmjs.org/@codemirror/theme-one-dark/-/theme-one-dark-6.1.2.tgz", + "integrity": "sha512-F+sH0X16j/qFLMAfbciKTxVOwkdAS336b7AXTKOZhy8BR3eH/RelsnLgLFINrpST63mmN2OuwUt0W2ndUgYwUA==", + "dependencies": { + "@codemirror/language": "^6.0.0", + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.0.0", + "@lezer/highlight": "^1.0.0" + } + }, + "node_modules/@codemirror/view": { + "version": "6.28.0", + "resolved": "https://registry.npmjs.org/@codemirror/view/-/view-6.28.0.tgz", + "integrity": 
"sha512-fo7CelaUDKWIyemw4b+J57cWuRkOu4SWCCPfNDkPvfWkGjM9D5racHQXr4EQeYCD6zEBIBxGCeaKkQo+ysl0gA==", + "dependencies": { + "@codemirror/state": "^6.4.0", + "style-mod": "^4.1.0", + "w3c-keyname": "^2.2.4" + } + }, "node_modules/@colors/colors": { "version": "1.5.0", "resolved": "https://registry.npmjs.org/@colors/colors/-/colors-1.5.0.tgz", @@ -825,6 +942,47 @@ "@jridgewell/sourcemap-codec": "^1.4.14" } }, + "node_modules/@lezer/common": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/@lezer/common/-/common-1.2.1.tgz", + "integrity": "sha512-yemX0ZD2xS/73llMZIK6KplkjIjf2EvAHcinDi/TfJ9hS25G0388+ClHt6/3but0oOxinTcQHJLDXh6w1crzFQ==" + }, + "node_modules/@lezer/highlight": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/@lezer/highlight/-/highlight-1.2.0.tgz", + "integrity": "sha512-WrS5Mw51sGrpqjlh3d4/fOwpEV2Hd3YOkp9DBt4k8XZQcoTHZFB7sx030A6OcahF4J1nDQAa3jXlTVVYH50IFA==", + "dependencies": { + "@lezer/common": "^1.0.0" + } + }, + "node_modules/@lezer/javascript": { + "version": "1.4.16", + "resolved": "https://registry.npmjs.org/@lezer/javascript/-/javascript-1.4.16.tgz", + "integrity": "sha512-84UXR3N7s11MPQHWgMnjb9571fr19MmXnr5zTv2XX0gHXXUvW3uPJ8GCjKrfTXmSdfktjRK0ayKklw+A13rk4g==", + "dependencies": { + "@lezer/common": "^1.2.0", + "@lezer/highlight": "^1.1.3", + "@lezer/lr": "^1.3.0" + } + }, + "node_modules/@lezer/lr": { + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/@lezer/lr/-/lr-1.4.1.tgz", + "integrity": "sha512-CHsKq8DMKBf9b3yXPDIU4DbH+ZJd/sJdYOW2llbW/HudP5u0VS6Bfq1hLYfgU7uAYGFIyGGQIsSOXGPEErZiJw==", + "dependencies": { + "@lezer/common": "^1.0.0" + } + }, + "node_modules/@lezer/python": { + "version": "1.1.14", + "resolved": "https://registry.npmjs.org/@lezer/python/-/python-1.1.14.tgz", + "integrity": "sha512-ykDOb2Ti24n76PJsSa4ZoDF0zH12BSw1LGfQXCYJhJyOGiFTfGaX0Du66Ze72R+u/P35U+O6I9m8TFXov1JzsA==", + "dependencies": { + "@lezer/common": "^1.2.0", + "@lezer/highlight": "^1.0.0", + "@lezer/lr": "^1.0.0" + } + 
}, "node_modules/@melt-ui/svelte": { "version": "0.76.0", "resolved": "https://registry.npmjs.org/@melt-ui/svelte/-/svelte-0.76.0.tgz", @@ -2224,12 +2382,12 @@ } }, "node_modules/braces": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", - "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", "dev": true, "dependencies": { - "fill-range": "^7.0.1" + "fill-range": "^7.1.1" }, "engines": { "node": ">=8" @@ -2769,6 +2927,20 @@ "plain-tag": "^0.1.3" } }, + "node_modules/codemirror": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/codemirror/-/codemirror-6.0.1.tgz", + "integrity": "sha512-J8j+nZ+CdWmIeFIGXEFbFPtpiYacFMDR8GlHK3IyHQJMCaVRfGx9NT+Hxivv1ckLWPvNdZqndbr/7lVhrf/Svg==", + "dependencies": { + "@codemirror/autocomplete": "^6.0.0", + "@codemirror/commands": "^6.0.0", + "@codemirror/language": "^6.0.0", + "@codemirror/lint": "^6.0.0", + "@codemirror/search": "^6.0.0", + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.0.0" + } + }, "node_modules/coincident": { "version": "1.2.3", "resolved": "https://registry.npmjs.org/coincident/-/coincident-1.2.3.tgz", @@ -2891,6 +3063,11 @@ "layout-base": "^1.0.0" } }, + "node_modules/crelt": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/crelt/-/crelt-1.0.6.tgz", + "integrity": "sha512-VQ2MBenTq1fWZUH9DJNGti7kKv6EeAuYr3cLwxUWhIu1baTaXh4Ib5W2CqHVqib4/MqbYGJqiL3Zb8GJZr3l4g==" + }, "node_modules/cross-spawn": { "version": "7.0.3", "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", @@ -4429,9 +4606,9 @@ "integrity": "sha512-P9bmyZ3h/PRG+Nzga+rbdI4OEpNDzAVyy74uVO9ATgzLK6VtAsYybF/+TOCvrc0MO793d6+42lLyZTw7/ArVzA==" }, "node_modules/fill-range": { - "version": 
"7.0.1", - "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", - "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", "dev": true, "dependencies": { "to-regex-range": "^5.0.1" @@ -8278,6 +8455,11 @@ "url": "https://github.com/sponsors/antfu" } }, + "node_modules/style-mod": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/style-mod/-/style-mod-4.1.2.tgz", + "integrity": "sha512-wnD1HyVqpJUI2+eKZ+eo1UwghftP6yuFheBqqe+bWCotBjC2K1YnteJILRMs3SM4V/0dLEW1SC27MWP5y+mwmw==" + }, "node_modules/stylis": { "version": "4.3.2", "resolved": "https://registry.npmjs.org/stylis/-/stylis-4.3.2.tgz", @@ -10022,6 +10204,11 @@ "he": "^1.2.0" } }, + "node_modules/w3c-keyname": { + "version": "2.2.8", + "resolved": "https://registry.npmjs.org/w3c-keyname/-/w3c-keyname-2.2.8.tgz", + "integrity": "sha512-dpojBhNsCNN7T82Tm7k26A6G9ML3NkhDsnw9n/eoxSRlVBB4CEtIQ/KTCLI2Fwf3ataSXRhYFkQi3SlnFwPvPQ==" + }, "node_modules/walk-sync": { "version": "2.2.0", "resolved": "https://registry.npmjs.org/walk-sync/-/walk-sync-2.2.0.tgz", diff --git a/package.json b/package.json index 7ea3bf3c7..bf353ef7f 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.2.5", + "version": "0.3.4", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", @@ -16,7 +16,7 @@ "format:backend": "black . 
--exclude \".venv/|/venv/\"", "i18n:parse": "i18next --config i18next-parser.config.ts && prettier --write \"src/lib/i18n/**/*.{js,json}\"", "cy:open": "cypress open", - "test:frontend": "vitest", + "test:frontend": "vitest --passWithNoTests", "pyodide:fetch": "node scripts/prepare-pyodide.js" }, "devDependencies": { @@ -48,10 +48,14 @@ }, "type": "module", "dependencies": { + "@codemirror/lang-javascript": "^6.2.2", + "@codemirror/lang-python": "^6.1.6", + "@codemirror/theme-one-dark": "^6.1.2", "@pyscript/core": "^0.4.32", "@sveltejs/adapter-node": "^1.3.1", "async": "^3.2.5", "bits-ui": "^0.19.7", + "codemirror": "^6.0.1", "dayjs": "^1.11.10", "eventsource-parser": "^1.1.2", "file-saver": "^2.0.5", diff --git a/pyproject.toml b/pyproject.toml index bb4744136..80893b15b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,8 +26,6 @@ dependencies = [ "PyMySQL==1.1.0", "bcrypt==4.1.3", - "litellm[proxy]==1.37.20", - "boto3==1.34.110", "argon2-cffi==23.1.0", @@ -67,6 +65,10 @@ dependencies = [ "langfuse==2.33.0", "youtube-transcript-api==0.6.2", "pytube==15.0.0", + "extract_msg", + "pydub", + "duckduckgo-search~=6.1.5" + ] readme = "README.md" requires-python = ">= 3.11, < 3.12.0a1" diff --git a/requirements-dev.lock b/requirements-dev.lock index fa4f48e63..f7660eae3 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -12,7 +12,6 @@ aiohttp==3.9.5 # via langchain # via langchain-community - # via litellm # via open-webui aiosignal==1.3.1 # via aiohttp @@ -20,11 +19,9 @@ annotated-types==0.6.0 # via pydantic anyio==4.3.0 # via httpx - # via openai # via starlette # via watchfiles apscheduler==3.10.4 - # via litellm # via open-webui argon2-cffi==23.1.0 # via open-webui @@ -40,7 +37,6 @@ av==11.0.0 # via faster-whisper backoff==2.2.1 # via langfuse - # via litellm # via posthog # via unstructured bcrypt==4.1.3 @@ -48,6 +44,7 @@ bcrypt==4.1.3 # via open-webui # via passlib beautifulsoup4==4.12.3 + # via extract-msg # via unstructured bidict==0.23.1 # 
via python-socketio @@ -85,18 +82,21 @@ chromadb==0.5.0 # via open-webui click==8.1.7 # via black + # via duckduckgo-search # via flask - # via litellm # via nltk # via peewee-migrate - # via rq # via typer # via uvicorn +colorclass==2.2.2 + # via oletools coloredlogs==15.0.1 # via onnxruntime +compressed-rtf==1.0.6 + # via extract-msg cryptography==42.0.7 # via authlib - # via litellm + # via msoffcrypto-tool # via pyjwt ctranslate2==4.2.1 # via faster-whisper @@ -112,33 +112,34 @@ defusedxml==0.7.1 deprecated==1.2.14 # via opentelemetry-api # via opentelemetry-exporter-otlp-proto-grpc -distro==1.9.0 - # via openai dnspython==2.6.1 # via email-validator docx2txt==0.8 # via open-webui +duckduckgo-search==6.1.5 + # via open-webui +easygui==0.98.3 + # via oletools +ebcdic==1.1.1 + # via extract-msg ecdsa==0.19.0 # via python-jose email-validator==2.1.1 # via fastapi - # via pydantic emoji==2.11.1 # via unstructured et-xmlfile==1.1.0 # via openpyxl +extract-msg==0.48.5 + # via open-webui fake-useragent==1.5.1 # via open-webui fastapi==0.111.0 # via chromadb - # via fastapi-sso # via langchain-chroma - # via litellm # via open-webui fastapi-cli==0.0.4 # via fastapi -fastapi-sso==0.10.0 - # via litellm faster-whisper==1.0.2 # via open-webui filelock==3.14.0 @@ -194,8 +195,6 @@ grpcio==1.63.0 # via opentelemetry-exporter-otlp-proto-grpc grpcio-status==1.62.2 # via google-api-core -gunicorn==22.0.0 - # via litellm h11==0.14.0 # via httpcore # via uvicorn @@ -209,9 +208,7 @@ httptools==0.6.1 # via uvicorn httpx==0.27.0 # via fastapi - # via fastapi-sso # via langfuse - # via openai huggingface-hub==0.23.0 # via faster-whisper # via sentence-transformers @@ -228,7 +225,6 @@ idna==3.7 # via unstructured-client # via yarl importlib-metadata==7.0.0 - # via litellm # via opentelemetry-api importlib-resources==6.4.0 # via chromadb @@ -237,7 +233,6 @@ itsdangerous==2.2.0 jinja2==3.1.4 # via fastapi # via flask - # via litellm # via torch jmespath==1.0.1 # via boto3 @@ -275,8 
+270,8 @@ langsmith==0.1.57 # via langchain # via langchain-community # via langchain-core -litellm==1.37.20 - # via open-webui +lark==1.1.8 + # via rtfde lxml==5.2.2 # via unstructured markdown==3.6 @@ -297,6 +292,8 @@ monotonic==1.6 # via posthog mpmath==1.3.0 # via sympy +msoffcrypto-tool==5.4.1 + # via oletools multidict==6.0.5 # via aiohttp # via yarl @@ -328,15 +325,19 @@ numpy==1.26.4 # via transformers # via unstructured oauthlib==3.2.2 - # via fastapi-sso # via kubernetes # via requests-oauthlib +olefile==0.47 + # via extract-msg + # via msoffcrypto-tool + # via oletools +oletools==0.60.1 + # via pcodedmp + # via rtfde onnxruntime==1.17.3 # via chromadb # via faster-whisper # via rapidocr-onnxruntime -openai==1.28.1 - # via litellm opencv-python==4.9.0.80 # via rapidocr-onnxruntime opencv-python-headless==4.9.0.80 @@ -378,15 +379,14 @@ ordered-set==4.1.0 # via deepdiff orjson==3.10.3 # via chromadb + # via duckduckgo-search # via fastapi # via langsmith - # via litellm overrides==7.7.0 # via chromadb packaging==23.2 # via black # via build - # via gunicorn # via huggingface-hub # via langchain-core # via langfuse @@ -398,8 +398,11 @@ pandas==2.2.2 # via open-webui passlib==1.7.4 # via open-webui + # via passlib pathspec==0.12.1 # via black +pcodedmp==1.2.6 + # via oletools peewee==3.17.5 # via open-webui # via peewee-migrate @@ -440,27 +443,28 @@ pycparser==2.22 pydantic==2.7.1 # via chromadb # via fastapi - # via fastapi-sso # via google-generativeai # via langchain # via langchain-core # via langfuse # via langsmith # via open-webui - # via openai pydantic-core==2.18.2 # via pydantic +pydub==0.25.1 + # via open-webui pygments==2.18.0 # via rich pyjwt==2.8.0 - # via litellm # via open-webui + # via pyjwt pymysql==1.1.0 # via open-webui pypandoc==1.13 # via open-webui -pyparsing==3.1.2 +pyparsing==2.4.7 # via httplib2 + # via oletools pypdf==4.2.0 # via open-webui # via unstructured-client @@ -468,6 +472,8 @@ pypika==0.48.9 # via chromadb 
pyproject-hooks==1.1.0 # via build +pyreqwest-impersonate==0.4.7 + # via duckduckgo-search python-dateutil==2.9.0.post0 # via botocore # via kubernetes @@ -475,7 +481,6 @@ python-dateutil==2.9.0.post0 # via posthog # via unstructured-client python-dotenv==1.0.1 - # via litellm # via uvicorn python-engineio==4.9.0 # via python-socketio @@ -487,7 +492,6 @@ python-magic==0.4.27 # via unstructured python-multipart==0.0.9 # via fastapi - # via litellm # via open-webui python-socketio==5.11.2 # via open-webui @@ -506,7 +510,6 @@ pyyaml==6.0.1 # via langchain # via langchain-community # via langchain-core - # via litellm # via rapidocr-onnxruntime # via transformers # via uvicorn @@ -516,11 +519,10 @@ rapidfuzz==3.9.0 # via unstructured rapidocr-onnxruntime==1.3.22 # via open-webui -redis==5.0.4 - # via rq +red-black-tree-mod==1.20 + # via extract-msg regex==2024.5.10 # via nltk - # via tiktoken # via transformers requests==2.32.2 # via chromadb @@ -530,11 +532,9 @@ requests==2.32.2 # via langchain # via langchain-community # via langsmith - # via litellm # via open-webui # via posthog # via requests-oauthlib - # via tiktoken # via transformers # via unstructured # via unstructured-client @@ -543,11 +543,11 @@ requests-oauthlib==2.0.0 # via kubernetes rich==13.7.1 # via typer -rq==1.16.2 - # via litellm rsa==4.9 # via google-auth # via python-jose +rtfde==0.1.1 + # via extract-msg s3transfer==0.10.1 # via boto3 safetensors==0.4.3 @@ -559,9 +559,6 @@ scipy==1.13.0 # via sentence-transformers sentence-transformers==2.7.0 # via open-webui -setuptools==69.5.1 - # via ctranslate2 - # via opentelemetry-instrumentation shapely==2.0.4 # via rapidocr-onnxruntime shellingham==1.5.4 @@ -580,7 +577,6 @@ six==1.16.0 sniffio==1.3.1 # via anyio # via httpx - # via openai soupsieve==2.5 # via beautifulsoup4 sqlalchemy==2.0.30 @@ -600,12 +596,9 @@ tenacity==8.3.0 # via langchain-core threadpoolctl==3.5.0 # via scikit-learn -tiktoken==0.6.0 - # via litellm tokenizers==0.15.2 # via chromadb 
# via faster-whisper - # via litellm # via transformers torch==2.3.0 # via sentence-transformers @@ -614,7 +607,6 @@ tqdm==4.66.4 # via google-generativeai # via huggingface-hub # via nltk - # via openai # via sentence-transformers # via transformers transformers==4.39.3 @@ -627,7 +619,6 @@ typing-extensions==4.11.0 # via fastapi # via google-generativeai # via huggingface-hub - # via openai # via opentelemetry-sdk # via pydantic # via pydantic-core @@ -644,6 +635,7 @@ tzdata==2024.1 # via pandas tzlocal==5.2 # via apscheduler + # via extract-msg ujson==5.10.0 # via fastapi unstructured==0.14.0 @@ -660,8 +652,8 @@ urllib3==2.2.1 uvicorn==0.22.0 # via chromadb # via fastapi - # via litellm # via open-webui + # via uvicorn uvloop==0.19.0 # via uvicorn validators==0.28.1 @@ -689,3 +681,6 @@ youtube-transcript-api==0.6.2 # via open-webui zipp==3.18.1 # via importlib-metadata +setuptools==69.5.1 + # via ctranslate2 + # via opentelemetry-instrumentation diff --git a/requirements.lock b/requirements.lock index fa4f48e63..f7660eae3 100644 --- a/requirements.lock +++ b/requirements.lock @@ -12,7 +12,6 @@ aiohttp==3.9.5 # via langchain # via langchain-community - # via litellm # via open-webui aiosignal==1.3.1 # via aiohttp @@ -20,11 +19,9 @@ annotated-types==0.6.0 # via pydantic anyio==4.3.0 # via httpx - # via openai # via starlette # via watchfiles apscheduler==3.10.4 - # via litellm # via open-webui argon2-cffi==23.1.0 # via open-webui @@ -40,7 +37,6 @@ av==11.0.0 # via faster-whisper backoff==2.2.1 # via langfuse - # via litellm # via posthog # via unstructured bcrypt==4.1.3 @@ -48,6 +44,7 @@ bcrypt==4.1.3 # via open-webui # via passlib beautifulsoup4==4.12.3 + # via extract-msg # via unstructured bidict==0.23.1 # via python-socketio @@ -85,18 +82,21 @@ chromadb==0.5.0 # via open-webui click==8.1.7 # via black + # via duckduckgo-search # via flask - # via litellm # via nltk # via peewee-migrate - # via rq # via typer # via uvicorn +colorclass==2.2.2 + # via oletools 
coloredlogs==15.0.1 # via onnxruntime +compressed-rtf==1.0.6 + # via extract-msg cryptography==42.0.7 # via authlib - # via litellm + # via msoffcrypto-tool # via pyjwt ctranslate2==4.2.1 # via faster-whisper @@ -112,33 +112,34 @@ defusedxml==0.7.1 deprecated==1.2.14 # via opentelemetry-api # via opentelemetry-exporter-otlp-proto-grpc -distro==1.9.0 - # via openai dnspython==2.6.1 # via email-validator docx2txt==0.8 # via open-webui +duckduckgo-search==6.1.5 + # via open-webui +easygui==0.98.3 + # via oletools +ebcdic==1.1.1 + # via extract-msg ecdsa==0.19.0 # via python-jose email-validator==2.1.1 # via fastapi - # via pydantic emoji==2.11.1 # via unstructured et-xmlfile==1.1.0 # via openpyxl +extract-msg==0.48.5 + # via open-webui fake-useragent==1.5.1 # via open-webui fastapi==0.111.0 # via chromadb - # via fastapi-sso # via langchain-chroma - # via litellm # via open-webui fastapi-cli==0.0.4 # via fastapi -fastapi-sso==0.10.0 - # via litellm faster-whisper==1.0.2 # via open-webui filelock==3.14.0 @@ -194,8 +195,6 @@ grpcio==1.63.0 # via opentelemetry-exporter-otlp-proto-grpc grpcio-status==1.62.2 # via google-api-core -gunicorn==22.0.0 - # via litellm h11==0.14.0 # via httpcore # via uvicorn @@ -209,9 +208,7 @@ httptools==0.6.1 # via uvicorn httpx==0.27.0 # via fastapi - # via fastapi-sso # via langfuse - # via openai huggingface-hub==0.23.0 # via faster-whisper # via sentence-transformers @@ -228,7 +225,6 @@ idna==3.7 # via unstructured-client # via yarl importlib-metadata==7.0.0 - # via litellm # via opentelemetry-api importlib-resources==6.4.0 # via chromadb @@ -237,7 +233,6 @@ itsdangerous==2.2.0 jinja2==3.1.4 # via fastapi # via flask - # via litellm # via torch jmespath==1.0.1 # via boto3 @@ -275,8 +270,8 @@ langsmith==0.1.57 # via langchain # via langchain-community # via langchain-core -litellm==1.37.20 - # via open-webui +lark==1.1.8 + # via rtfde lxml==5.2.2 # via unstructured markdown==3.6 @@ -297,6 +292,8 @@ monotonic==1.6 # via posthog 
mpmath==1.3.0 # via sympy +msoffcrypto-tool==5.4.1 + # via oletools multidict==6.0.5 # via aiohttp # via yarl @@ -328,15 +325,19 @@ numpy==1.26.4 # via transformers # via unstructured oauthlib==3.2.2 - # via fastapi-sso # via kubernetes # via requests-oauthlib +olefile==0.47 + # via extract-msg + # via msoffcrypto-tool + # via oletools +oletools==0.60.1 + # via pcodedmp + # via rtfde onnxruntime==1.17.3 # via chromadb # via faster-whisper # via rapidocr-onnxruntime -openai==1.28.1 - # via litellm opencv-python==4.9.0.80 # via rapidocr-onnxruntime opencv-python-headless==4.9.0.80 @@ -378,15 +379,14 @@ ordered-set==4.1.0 # via deepdiff orjson==3.10.3 # via chromadb + # via duckduckgo-search # via fastapi # via langsmith - # via litellm overrides==7.7.0 # via chromadb packaging==23.2 # via black # via build - # via gunicorn # via huggingface-hub # via langchain-core # via langfuse @@ -398,8 +398,11 @@ pandas==2.2.2 # via open-webui passlib==1.7.4 # via open-webui + # via passlib pathspec==0.12.1 # via black +pcodedmp==1.2.6 + # via oletools peewee==3.17.5 # via open-webui # via peewee-migrate @@ -440,27 +443,28 @@ pycparser==2.22 pydantic==2.7.1 # via chromadb # via fastapi - # via fastapi-sso # via google-generativeai # via langchain # via langchain-core # via langfuse # via langsmith # via open-webui - # via openai pydantic-core==2.18.2 # via pydantic +pydub==0.25.1 + # via open-webui pygments==2.18.0 # via rich pyjwt==2.8.0 - # via litellm # via open-webui + # via pyjwt pymysql==1.1.0 # via open-webui pypandoc==1.13 # via open-webui -pyparsing==3.1.2 +pyparsing==2.4.7 # via httplib2 + # via oletools pypdf==4.2.0 # via open-webui # via unstructured-client @@ -468,6 +472,8 @@ pypika==0.48.9 # via chromadb pyproject-hooks==1.1.0 # via build +pyreqwest-impersonate==0.4.7 + # via duckduckgo-search python-dateutil==2.9.0.post0 # via botocore # via kubernetes @@ -475,7 +481,6 @@ python-dateutil==2.9.0.post0 # via posthog # via unstructured-client python-dotenv==1.0.1 - # 
via litellm # via uvicorn python-engineio==4.9.0 # via python-socketio @@ -487,7 +492,6 @@ python-magic==0.4.27 # via unstructured python-multipart==0.0.9 # via fastapi - # via litellm # via open-webui python-socketio==5.11.2 # via open-webui @@ -506,7 +510,6 @@ pyyaml==6.0.1 # via langchain # via langchain-community # via langchain-core - # via litellm # via rapidocr-onnxruntime # via transformers # via uvicorn @@ -516,11 +519,10 @@ rapidfuzz==3.9.0 # via unstructured rapidocr-onnxruntime==1.3.22 # via open-webui -redis==5.0.4 - # via rq +red-black-tree-mod==1.20 + # via extract-msg regex==2024.5.10 # via nltk - # via tiktoken # via transformers requests==2.32.2 # via chromadb @@ -530,11 +532,9 @@ requests==2.32.2 # via langchain # via langchain-community # via langsmith - # via litellm # via open-webui # via posthog # via requests-oauthlib - # via tiktoken # via transformers # via unstructured # via unstructured-client @@ -543,11 +543,11 @@ requests-oauthlib==2.0.0 # via kubernetes rich==13.7.1 # via typer -rq==1.16.2 - # via litellm rsa==4.9 # via google-auth # via python-jose +rtfde==0.1.1 + # via extract-msg s3transfer==0.10.1 # via boto3 safetensors==0.4.3 @@ -559,9 +559,6 @@ scipy==1.13.0 # via sentence-transformers sentence-transformers==2.7.0 # via open-webui -setuptools==69.5.1 - # via ctranslate2 - # via opentelemetry-instrumentation shapely==2.0.4 # via rapidocr-onnxruntime shellingham==1.5.4 @@ -580,7 +577,6 @@ six==1.16.0 sniffio==1.3.1 # via anyio # via httpx - # via openai soupsieve==2.5 # via beautifulsoup4 sqlalchemy==2.0.30 @@ -600,12 +596,9 @@ tenacity==8.3.0 # via langchain-core threadpoolctl==3.5.0 # via scikit-learn -tiktoken==0.6.0 - # via litellm tokenizers==0.15.2 # via chromadb # via faster-whisper - # via litellm # via transformers torch==2.3.0 # via sentence-transformers @@ -614,7 +607,6 @@ tqdm==4.66.4 # via google-generativeai # via huggingface-hub # via nltk - # via openai # via sentence-transformers # via transformers 
transformers==4.39.3 @@ -627,7 +619,6 @@ typing-extensions==4.11.0 # via fastapi # via google-generativeai # via huggingface-hub - # via openai # via opentelemetry-sdk # via pydantic # via pydantic-core @@ -644,6 +635,7 @@ tzdata==2024.1 # via pandas tzlocal==5.2 # via apscheduler + # via extract-msg ujson==5.10.0 # via fastapi unstructured==0.14.0 @@ -660,8 +652,8 @@ urllib3==2.2.1 uvicorn==0.22.0 # via chromadb # via fastapi - # via litellm # via open-webui + # via uvicorn uvloop==0.19.0 # via uvicorn validators==0.28.1 @@ -689,3 +681,6 @@ youtube-transcript-api==0.6.2 # via open-webui zipp==3.18.1 # via importlib-metadata +setuptools==69.5.1 + # via ctranslate2 + # via opentelemetry-instrumentation diff --git a/src/app.css b/src/app.css index f7c14bcbd..baf620845 100644 --- a/src/app.css +++ b/src/app.css @@ -28,6 +28,10 @@ math { @apply rounded-lg; } +.markdown a { + @apply underline; +} + ol > li { counter-increment: list-number; display: block; @@ -92,10 +96,18 @@ select { visibility: hidden; } +.scrollbar-hidden::-webkit-scrollbar-corner { + display: none; +} + .scrollbar-none::-webkit-scrollbar { display: none; /* for Chrome, Safari and Opera */ } +.scrollbar-none::-webkit-scrollbar-corner { + display: none; +} + .scrollbar-none { -ms-overflow-style: none; /* IE and Edge */ scrollbar-width: none; /* Firefox */ @@ -111,3 +123,16 @@ input::-webkit-inner-spin-button { input[type='number'] { -moz-appearance: textfield; /* Firefox */ } + +.cm-editor { + height: 100%; + width: 100%; +} + +.cm-scroller { + @apply scrollbar-hidden; +} + +.cm-editor.cm-focused { + outline: none; +} diff --git a/src/app.html b/src/app.html index 138fb2829..a79343df5 100644 --- a/src/app.html +++ b/src/app.html @@ -32,6 +32,9 @@ } else if (localStorage.theme && localStorage.theme === 'system') { systemTheme = window.matchMedia('(prefers-color-scheme: dark)').matches; document.documentElement.classList.add(systemTheme ? 
'dark' : 'light'); + } else if (localStorage.theme && localStorage.theme === 'her') { + document.documentElement.classList.add('dark'); + document.documentElement.classList.add('her'); } else { document.documentElement.classList.add('dark'); } @@ -59,15 +62,7 @@