2024-10-16 14:32:57 +00:00
import asyncio
2024-08-27 22:10:27 +00:00
import inspect
2024-02-23 08:30:26 +00:00
import json
2024-03-20 23:11:36 +00:00
import logging
2024-05-21 22:04:00 +00:00
import mimetypes
2024-08-27 22:10:27 +00:00
import os
2024-06-05 20:57:48 +00:00
import shutil
2024-08-27 22:10:27 +00:00
import sys
import time
2024-10-22 10:16:48 +00:00
import random
2024-12-12 02:36:59 +00:00
from typing import AsyncGenerator , Generator , Iterator
2024-08-27 22:10:27 +00:00
from contextlib import asynccontextmanager
2024-12-10 08:54:13 +00:00
from urllib . parse import urlencode , parse_qs , urlparse
from pydantic import BaseModel
from sqlalchemy import text
2024-02-23 08:30:26 +00:00
2024-12-10 08:54:13 +00:00
from typing import Optional
2024-11-16 12:41:07 +00:00
from aiocache import cached
2024-08-27 22:10:27 +00:00
import aiohttp
import requests
2024-10-16 14:58:03 +00:00
from fastapi import (
Depends ,
FastAPI ,
File ,
Form ,
HTTPException ,
Request ,
UploadFile ,
status ,
)
from fastapi . middleware . cors import CORSMiddleware
2024-10-22 10:16:48 +00:00
from fastapi . responses import JSONResponse , RedirectResponse
2024-10-16 14:58:03 +00:00
from fastapi . staticfiles import StaticFiles
2024-12-10 08:54:13 +00:00
2024-10-16 14:58:03 +00:00
from starlette . exceptions import HTTPException as StarletteHTTPException
from starlette . middleware . base import BaseHTTPMiddleware
from starlette . middleware . sessions import SessionMiddleware
from starlette . responses import Response , StreamingResponse
2024-09-04 14:54:48 +00:00
2024-12-10 08:54:13 +00:00
from open_webui . routers import (
audio ,
images ,
ollama ,
openai ,
retrieval ,
pipelines ,
tasks ,
2024-12-11 10:41:25 +00:00
auths ,
chats ,
folders ,
configs ,
groups ,
files ,
functions ,
memories ,
models ,
knowledge ,
prompts ,
evaluations ,
tools ,
users ,
utils ,
2024-09-04 14:54:48 +00:00
)
2024-12-10 08:54:13 +00:00
from open_webui . retrieval . utils import get_sources_from_files
2024-12-12 02:08:55 +00:00
from open_webui . routers . retrieval import (
get_embedding_function ,
update_embedding_model ,
update_reranking_model ,
)
2024-11-25 02:49:56 +00:00
2024-12-10 08:54:13 +00:00
from open_webui . socket . main import (
2024-09-27 23:27:46 +00:00
app as socket_app ,
periodic_usage_pool_cleanup ,
get_event_call ,
get_event_emitter ,
)
2024-12-10 08:54:13 +00:00
from open_webui . internal . db import Session
2024-12-12 02:08:55 +00:00
from open_webui . routers . webui import (
2024-09-27 23:27:46 +00:00
app as webui_app ,
2024-09-04 14:54:48 +00:00
generate_function_chat_completion ,
2024-10-21 11:14:49 +00:00
get_all_models as get_open_webui_models ,
2024-09-04 14:54:48 +00:00
)
2024-12-12 02:08:55 +00:00
2024-12-10 08:54:13 +00:00
from open_webui . models . functions import Functions
from open_webui . models . models import Models
from open_webui . models . users import UserModel , Users
2024-12-12 02:36:59 +00:00
from open_webui . utils . plugin import load_function_module_by_id
2024-12-10 08:54:13 +00:00
from open_webui . constants import TASKS
2024-09-04 14:54:48 +00:00
from open_webui . config import (
2024-12-10 08:54:13 +00:00
# Ollama
2024-08-27 22:10:27 +00:00
ENABLE_OLLAMA_API ,
2024-12-10 08:54:13 +00:00
OLLAMA_BASE_URLS ,
OLLAMA_API_CONFIGS ,
# OpenAI
2024-08-27 22:10:27 +00:00
ENABLE_OPENAI_API ,
2024-12-10 08:54:13 +00:00
OPENAI_API_BASE_URLS ,
OPENAI_API_KEYS ,
OPENAI_API_CONFIGS ,
# Image
AUTOMATIC1111_API_AUTH ,
AUTOMATIC1111_BASE_URL ,
AUTOMATIC1111_CFG_SCALE ,
AUTOMATIC1111_SAMPLER ,
AUTOMATIC1111_SCHEDULER ,
COMFYUI_BASE_URL ,
COMFYUI_WORKFLOW ,
COMFYUI_WORKFLOW_NODES ,
ENABLE_IMAGE_GENERATION ,
IMAGE_GENERATION_ENGINE ,
IMAGE_GENERATION_MODEL ,
IMAGE_SIZE ,
IMAGE_STEPS ,
IMAGES_OPENAI_API_BASE_URL ,
IMAGES_OPENAI_API_KEY ,
# Audio
AUDIO_STT_ENGINE ,
AUDIO_STT_MODEL ,
AUDIO_STT_OPENAI_API_BASE_URL ,
AUDIO_STT_OPENAI_API_KEY ,
AUDIO_TTS_API_KEY ,
AUDIO_TTS_ENGINE ,
AUDIO_TTS_MODEL ,
AUDIO_TTS_OPENAI_API_BASE_URL ,
AUDIO_TTS_OPENAI_API_KEY ,
AUDIO_TTS_SPLIT_ON ,
AUDIO_TTS_VOICE ,
AUDIO_TTS_AZURE_SPEECH_REGION ,
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT ,
WHISPER_MODEL ,
WHISPER_MODEL_AUTO_UPDATE ,
WHISPER_MODEL_DIR ,
2024-12-11 10:41:25 +00:00
# Retrieval
RAG_TEMPLATE ,
DEFAULT_RAG_TEMPLATE ,
RAG_EMBEDDING_MODEL ,
RAG_EMBEDDING_MODEL_AUTO_UPDATE ,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE ,
RAG_RERANKING_MODEL ,
RAG_RERANKING_MODEL_AUTO_UPDATE ,
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE ,
RAG_EMBEDDING_ENGINE ,
RAG_EMBEDDING_BATCH_SIZE ,
RAG_RELEVANCE_THRESHOLD ,
RAG_FILE_MAX_COUNT ,
RAG_FILE_MAX_SIZE ,
RAG_OPENAI_API_BASE_URL ,
RAG_OPENAI_API_KEY ,
RAG_OLLAMA_BASE_URL ,
RAG_OLLAMA_API_KEY ,
CHUNK_OVERLAP ,
CHUNK_SIZE ,
CONTENT_EXTRACTION_ENGINE ,
TIKA_SERVER_URL ,
RAG_TOP_K ,
RAG_TEXT_SPLITTER ,
TIKTOKEN_ENCODING_NAME ,
PDF_EXTRACT_IMAGES ,
YOUTUBE_LOADER_LANGUAGE ,
YOUTUBE_LOADER_PROXY_URL ,
# Retrieval (Web Search)
RAG_WEB_SEARCH_ENGINE ,
RAG_WEB_SEARCH_RESULT_COUNT ,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS ,
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST ,
JINA_API_KEY ,
SEARCHAPI_API_KEY ,
SEARCHAPI_ENGINE ,
SEARXNG_QUERY_URL ,
SERPER_API_KEY ,
SERPLY_API_KEY ,
SERPSTACK_API_KEY ,
SERPSTACK_HTTPS ,
TAVILY_API_KEY ,
BING_SEARCH_V7_ENDPOINT ,
BING_SEARCH_V7_SUBSCRIPTION_KEY ,
BRAVE_SEARCH_API_KEY ,
KAGI_SEARCH_API_KEY ,
MOJEEK_SEARCH_API_KEY ,
GOOGLE_PSE_API_KEY ,
GOOGLE_PSE_ENGINE_ID ,
ENABLE_RAG_HYBRID_SEARCH ,
ENABLE_RAG_LOCAL_WEB_FETCH ,
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION ,
ENABLE_RAG_WEB_SEARCH ,
UPLOAD_DIR ,
2024-12-10 08:54:13 +00:00
# WebUI
WEBUI_AUTH ,
WEBUI_NAME ,
WEBUI_BANNERS ,
WEBHOOK_URL ,
ADMIN_EMAIL ,
SHOW_ADMIN_DETAILS ,
JWT_EXPIRES_IN ,
ENABLE_SIGNUP ,
ENABLE_LOGIN_FORM ,
ENABLE_API_KEY ,
ENABLE_COMMUNITY_SHARING ,
ENABLE_MESSAGE_RATING ,
ENABLE_EVALUATION_ARENA_MODELS ,
USER_PERMISSIONS ,
DEFAULT_USER_ROLE ,
DEFAULT_PROMPT_SUGGESTIONS ,
DEFAULT_MODELS ,
DEFAULT_ARENA_MODEL ,
MODEL_ORDER_LIST ,
EVALUATION_ARENA_MODELS ,
# WebUI (OAuth)
ENABLE_OAUTH_ROLE_MANAGEMENT ,
OAUTH_ROLES_CLAIM ,
OAUTH_EMAIL_CLAIM ,
OAUTH_PICTURE_CLAIM ,
OAUTH_USERNAME_CLAIM ,
OAUTH_ALLOWED_ROLES ,
OAUTH_ADMIN_ROLES ,
# WebUI (LDAP)
ENABLE_LDAP ,
LDAP_SERVER_LABEL ,
LDAP_SERVER_HOST ,
LDAP_SERVER_PORT ,
LDAP_ATTRIBUTE_FOR_USERNAME ,
LDAP_SEARCH_FILTERS ,
LDAP_SEARCH_BASE ,
LDAP_APP_DN ,
LDAP_APP_PASSWORD ,
LDAP_USE_TLS ,
LDAP_CA_CERT_FILE ,
LDAP_CIPHERS ,
# Misc
2024-08-27 22:10:27 +00:00
ENV ,
2024-12-10 08:54:13 +00:00
CACHE_DIR ,
STATIC_DIR ,
2024-08-27 22:10:27 +00:00
FRONTEND_BUILD_DIR ,
2024-12-10 08:54:13 +00:00
CORS_ALLOW_ORIGIN ,
DEFAULT_LOCALE ,
2024-08-27 22:10:27 +00:00
OAUTH_PROVIDERS ,
2024-12-10 08:54:13 +00:00
# Admin
ENABLE_ADMIN_CHAT_ACCESS ,
ENABLE_ADMIN_EXPORT ,
# Tasks
2024-06-09 21:53:10 +00:00
TASK_MODEL ,
TASK_MODEL_EXTERNAL ,
2024-12-10 08:54:13 +00:00
ENABLE_TAGS_GENERATION ,
2024-11-19 10:24:32 +00:00
ENABLE_SEARCH_QUERY_GENERATION ,
ENABLE_RETRIEVAL_QUERY_GENERATION ,
2024-12-10 08:54:13 +00:00
ENABLE_AUTOCOMPLETE_GENERATION ,
2024-06-09 21:25:31 +00:00
TITLE_GENERATION_PROMPT_TEMPLATE ,
2024-10-20 04:27:10 +00:00
TAGS_GENERATION_PROMPT_TEMPLATE ,
2024-06-11 06:40:27 +00:00
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE ,
2024-12-10 08:54:13 +00:00
QUERY_GENERATION_PROMPT_TEMPLATE ,
AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE ,
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH ,
2024-08-27 22:10:27 +00:00
AppConfig ,
2024-09-24 23:06:11 +00:00
reset_config ,
2024-08-27 22:10:27 +00:00
)
2024-09-04 14:54:48 +00:00
from open_webui . env import (
2024-08-27 22:10:27 +00:00
CHANGELOG ,
GLOBAL_LOG_LEVEL ,
2024-06-24 02:28:33 +00:00
SAFE_MODE ,
2024-08-27 22:10:27 +00:00
SRC_LOG_LEVELS ,
VERSION ,
2024-12-10 08:54:13 +00:00
WEBUI_URL ,
2024-08-27 22:10:27 +00:00
WEBUI_BUILD_HASH ,
2024-05-27 17:07:38 +00:00
WEBUI_SECRET_KEY ,
2024-06-05 18:21:42 +00:00
WEBUI_SESSION_COOKIE_SAME_SITE ,
2024-06-07 08:13:42 +00:00
WEBUI_SESSION_COOKIE_SECURE ,
2024-12-10 08:54:13 +00:00
WEBUI_AUTH_TRUSTED_EMAIL_HEADER ,
WEBUI_AUTH_TRUSTED_NAME_HEADER ,
2024-12-02 02:25:44 +00:00
BYPASS_MODEL_ACCESS_CONTROL ,
2024-09-24 23:06:11 +00:00
RESET_CONFIG_ON_START ,
2024-10-08 05:13:49 +00:00
OFFLINE_MODE ,
2024-08-27 22:10:27 +00:00
)
2024-12-10 08:54:13 +00:00
2024-09-04 14:54:48 +00:00
from open_webui . utils . misc import (
2024-08-27 22:10:27 +00:00
add_or_update_system_message ,
get_last_user_message ,
prepend_to_first_user_message_content ,
2024-12-12 02:36:59 +00:00
openai_chat_chunk_message_template ,
openai_chat_completion_message_template ,
)
from open_webui . utils . payload import (
apply_model_params_to_body_openai ,
apply_model_system_prompt_to_body ,
2024-08-27 22:10:27 +00:00
)
2024-12-10 08:54:13 +00:00
2024-10-16 14:32:57 +00:00
from open_webui . utils . payload import convert_payload_openai_to_ollama
from open_webui . utils . response import (
convert_response_ollama_to_openai ,
convert_streaming_response_ollama_to_openai ,
)
2024-12-10 08:54:13 +00:00
2024-09-04 14:54:48 +00:00
from open_webui . utils . task import (
2024-11-25 02:49:56 +00:00
rag_template ,
2024-08-27 22:10:27 +00:00
tools_function_calling_generation_template ,
)
2024-09-04 14:54:48 +00:00
from open_webui . utils . tools import get_tools
2024-12-10 08:54:13 +00:00
from open_webui . utils . access_control import has_access
2024-12-09 00:01:56 +00:00
from open_webui . utils . auth import (
2024-08-27 22:10:27 +00:00
decode_token ,
get_admin_user ,
get_current_user ,
get_http_authorization_cred ,
get_verified_user ,
2024-03-10 05:47:01 +00:00
)
2024-12-10 08:54:13 +00:00
from open_webui . utils . oauth import oauth_manager
from open_webui . utils . security_headers import SecurityHeadersMiddleware
2024-09-20 22:30:13 +00:00
2024-06-24 02:28:33 +00:00
# In safe mode every stored function is deactivated before the app can use it.
if SAFE_MODE:
    print("SAFE MODE ENABLED")
    Functions.deactivate_all_functions()

# Root logging goes to stdout at the global level; this module's logger is
# then narrowed to the level configured under SRC_LOG_LEVELS["MAIN"].
logging.basicConfig(level=GLOBAL_LOG_LEVEL, stream=sys.stdout)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
2023-11-15 00:28:51 +00:00
2024-03-28 09:45:56 +00:00
2023-11-15 00:28:51 +00:00
class SPAStaticFiles(StaticFiles):
    """StaticFiles variant for serving a single-page application.

    Any request the parent class answers with a 404 is served ``index.html``
    instead, so client-side routes resolve to the SPA entry point. Every other
    HTTP error is re-raised unchanged.
    """

    async def get_response(self, path: str, scope):
        try:
            return await super().get_response(path, scope)
        except (HTTPException, StarletteHTTPException) as exc:
            if exc.status_code != 404:
                raise exc
            return await super().get_response("index.html", scope)
2024-04-02 10:03:55 +00:00
# Startup banner: ASCII-art logo, the running VERSION, the commit hash (omitted when WEBUI_BUILD_HASH == "dev-build") and the project URL.
print (
2024-05-03 21:23:38 +00:00
rf """
2024-10-08 05:13:49 +00:00
___ __ __ _ _ _ ___
2024-04-02 10:03:55 +00:00
/ _ \ _ __ ___ _ __ \ \ / / __ | | __ | | | | _ _ |
2024-10-08 05:13:49 +00:00
| | | | ' _ \ / _ \ ' _ \ \ \ / \ / / _ \ ' _ \ | | | || |
| | _ | | | _ ) | __ / | | | \ V V / __ / | _ ) | | _ | | | |
2024-04-02 10:03:55 +00:00
\___ / | . __ / \___ | _ | | _ | \_ / \_ / \___ | _ . __ / \___ / | ___ |
2024-10-08 05:13:49 +00:00
| _ |
2024-04-02 10:03:55 +00:00
2024-05-22 19:22:38 +00:00
v { VERSION } - building the best open - source AI user interface .
2024-05-26 07:49:30 +00:00
{ f " Commit: { WEBUI_BUILD_HASH } " if WEBUI_BUILD_HASH != " dev-build " else " " }
2024-04-02 10:03:55 +00:00
https : / / github . com / open - webui / open - webui
"""
)
2023-11-15 00:28:51 +00:00
2024-05-09 04:00:03 +00:00
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: startup work runs before ``yield``, shutdown after.

    Startup resets the config when RESET_CONFIG_ON_START is set and launches
    ``periodic_usage_pool_cleanup()`` as a background task. On shutdown that
    task is cancelled and awaited so it does not outlive the application.
    """
    if RESET_CONFIG_ON_START:
        reset_config()

    # The event loop keeps only weak references to tasks. This local lives in
    # the suspended generator frame for the whole app lifetime and is the
    # strong reference that stops the task being garbage-collected mid-run.
    usage_pool_cleanup_task = asyncio.create_task(periodic_usage_pool_cleanup())

    try:
        yield
    finally:
        usage_pool_cleanup_task.cancel()
        # return_exceptions=True turns the task's own CancelledError (or any
        # error it died with) into a result. A cancellation of this shutdown
        # coroutine itself still propagates.
        (outcome,) = await asyncio.gather(
            usage_pool_cleanup_task, return_exceptions=True
        )
        if isinstance(outcome, Exception):
            log.error(f"usage pool cleanup task failed: {outcome!r}")
# The interactive docs UI and the OpenAPI schema are published only when
# ENV == "dev"; ReDoc is always disabled.
app = FastAPI(
    lifespan=lifespan,
    redoc_url=None,
    docs_url="/docs" if ENV == "dev" else None,
    openapi_url="/openapi.json" if ENV == "dev" else None,
)

# Holder for the configurable settings assigned in the sections below.
app.state.config = AppConfig()
2024-05-24 08:40:48 +00:00
2024-12-10 08:54:13 +00:00
########################################
#
# OLLAMA
#
########################################
2024-05-24 08:40:48 +00:00
# Plain app.state attribute (not part of app.state.config); presumably filled
# later with models discovered on the Ollama backends.
app.state.OLLAMA_MODELS = {}

# Ollama connection settings.
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS
2024-12-10 08:54:13 +00:00
########################################
#
# OPENAI
#
########################################
# Plain app.state attribute (not part of app.state.config); presumably filled
# later with models discovered on the OpenAI-compatible backends.
app.state.OPENAI_MODELS = {}

# OpenAI-compatible connection settings (parallel lists of URLs and keys).
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS
2024-12-10 08:54:13 +00:00
########################################
#
# WEBUI
#
########################################
# Plain app.state attributes (not part of app.state.config).
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
app.state.TOOLS = {}
app.state.FUNCTIONS = {}

# Authentication entry points and token lifetime.
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
app.state.config.ENABLE_API_KEY = ENABLE_API_KEY
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN

# Admin contact details and defaults applied to users.
app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS

# Webhook, banners, model ordering and community/evaluation features.
# Note the config key is BANNERS while the source constant is WEBUI_BANNERS.
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING
app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS
app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS

# OAuth claim mapping and role management.
app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES

# LDAP server, bind credentials, search and TLS settings.
app.state.config.ENABLE_LDAP = ENABLE_LDAP
app.state.config.LDAP_SERVER_LABEL = LDAP_SERVER_LABEL
app.state.config.LDAP_SERVER_HOST = LDAP_SERVER_HOST
app.state.config.LDAP_SERVER_PORT = LDAP_SERVER_PORT
app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = LDAP_ATTRIBUTE_FOR_USERNAME
app.state.config.LDAP_APP_DN = LDAP_APP_DN
app.state.config.LDAP_APP_PASSWORD = LDAP_APP_PASSWORD
app.state.config.LDAP_SEARCH_BASE = LDAP_SEARCH_BASE
app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS
app.state.config.LDAP_USE_TLS = LDAP_USE_TLS
app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE
app.state.config.LDAP_CIPHERS = LDAP_CIPHERS
########################################
#
# RETRIEVAL
#
########################################
2024-12-11 10:41:25 +00:00
# Plain app.state placeholders (not part of app.state.config). The embedding
# function is built further down in this module; the rest stay None here.
app.state.EMBEDDING_FUNCTION = None
app.state.sentence_transformer_ef = None
app.state.sentence_transformer_rf = None
app.state.YOUTUBE_LOADER_TRANSLATION = None

# Retrieval scoring, upload limits and hybrid search.
app.state.config.TOP_K = RAG_TOP_K
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)

# Content extraction and chunking.
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP

# Embedding / reranking model selection and the RAG prompt template.
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE

# RAG-specific OpenAI / Ollama connection settings.
app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY

# Document loaders.
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL

# Web search: toggle, engine, limits and domain filter.
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST

# Web search: per-provider endpoints and credentials.
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY
app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
app.state.config.JINA_API_KEY = JINA_API_KEY
app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
2024-12-12 02:05:42 +00:00
2024-12-11 10:41:25 +00:00
2024-12-12 02:08:55 +00:00
# Build the embedding callable for the configured RAG engine.
#
# The connection settings must be the RAG-specific values. app.state.config
# only ever receives the plural OPENAI_API_BASE_URLS / OPENAI_API_KEYS /
# OLLAMA_BASE_URLS lists and the RAG_OPENAI_* / RAG_OLLAMA_* values; singular
# OPENAI_API_BASE_URL / OLLAMA_BASE_URL / *_API_KEY keys are never assigned.
app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.config.RAG_EMBEDDING_ENGINE,
    app.state.config.RAG_EMBEDDING_MODEL,
    # NOTE(review): app.state.sentence_transformer_ef is initialised to None
    # earlier in this module and update_embedding_model() only runs after this
    # call. Confirm the local (non-openai/ollama) engine path does not need a
    # loaded model object at this point.
    app.state.sentence_transformer_ef,
    (
        app.state.config.RAG_OPENAI_API_BASE_URL
        if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
        else app.state.config.RAG_OLLAMA_BASE_URL
    ),
    (
        app.state.config.RAG_OPENAI_API_KEY
        if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
        else app.state.config.RAG_OLLAMA_API_KEY
    ),
    app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)

# Load / refresh the embedding and reranking models, honouring the
# *_AUTO_UPDATE flags.
update_embedding_model(
    app.state.config.RAG_EMBEDDING_MODEL,
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
    app.state.config.RAG_RERANKING_MODEL,
    RAG_RERANKING_MODEL_AUTO_UPDATE,
)
2024-12-10 08:54:13 +00:00
########################################
#
# IMAGES
#
########################################
# Image generation: feature toggle, engine choice and shared parameters.
app.state.config.ENABLE_IMAGE_GENERATION = ENABLE_IMAGE_GENERATION
app.state.config.IMAGE_GENERATION_ENGINE = IMAGE_GENERATION_ENGINE
app.state.config.IMAGE_GENERATION_MODEL = IMAGE_GENERATION_MODEL
app.state.config.IMAGE_SIZE = IMAGE_SIZE
app.state.config.IMAGE_STEPS = IMAGE_STEPS

# OpenAI images backend.
app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY

# AUTOMATIC1111 backend.
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE
app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER
app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER

# ComfyUI backend.
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
########################################
#
# AUDIO
#
########################################
# Plain app.state placeholders for audio models (not part of app.state.config).
app.state.faster_whisper_model = None
app.state.speech_synthesiser = None
app.state.speech_speaker_embeddings_dataset = None

# Speech-to-text.
app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL
app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
app.state.config.WHISPER_MODEL = WHISPER_MODEL

# Text-to-speech.
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT
########################################
#
# TASKS
#
########################################
2024-06-09 21:53:10 +00:00
# Task models: get_task_model_id() picks TASK_MODEL when the chat model is
# owned by "ollama" and TASK_MODEL_EXTERNAL otherwise.
app.state.config.TASK_MODEL = TASK_MODEL
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL

# Per-task feature toggles.
app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION
app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION
app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION
app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION

# Prompt templates and limits for the background tasks.
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE
app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = (
    AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
)
app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
)


########################################
#
# WEBUI
#
########################################

# Plain app.state attribute (not part of app.state.config).
app.state.MODELS = {}
2024-06-20 08:51:39 +00:00
##################################
#
# ChatCompletion Middleware
#
##################################
2024-07-02 02:37:54 +00:00
def get_filter_function_ids(model):
    """Return the ids of the active filter functions that apply to ``model``.

    Global filter functions are combined with the model's own
    ``info.meta.filterIds``, de-duplicated (first occurrence wins), restricted
    to filters returned by ``Functions.get_functions_by_type("filter",
    active_only=True)``, and sorted ascending by each filter's ``priority``
    valve (missing -> 0).

    Args:
        model: model dict; ``model["info"]["meta"]["filterIds"]`` is read
            when present. Missing or None levels are treated as empty.

    Returns:
        List of filter function ids in execution order.
    """

    def get_priority(function_id):
        function = Functions.get_function_by_id(function_id)
        if function is not None and hasattr(function, "valves"):
            # TODO: Fix FunctionModel
            return (function.valves if function.valves else {}).get("priority", 0)
        return 0

    filter_ids = [function.id for function in Functions.get_global_filter_functions()]

    # Tolerate a missing or None "info" / "meta" / "filterIds" on the model.
    meta = (model.get("info") or {}).get("meta") or {}
    filter_ids.extend(meta.get("filterIds") or [])

    # De-duplicate while keeping first-seen order. A set() here would make the
    # relative order of equal-priority filters depend on str-hash
    # randomisation, so it would change between process restarts.
    filter_ids = list(dict.fromkeys(filter_ids))

    enabled_filter_ids = {
        function.id
        for function in Functions.get_functions_by_type("filter", active_only=True)
    }
    filter_ids = [
        filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
    ]

    # list.sort is stable: equal priorities keep the deterministic order above.
    filter_ids.sort(key=get_priority)
    return filter_ids
2024-08-17 14:00:18 +00:00
async def chat_completion_filter_functions_handler(body, model, extra_params):
    """Run the ``inlet`` hook of every applicable filter function over ``body``.

    Filters come from ``get_filter_function_ids(model)`` and run in that order.
    Each inlet receives the body returned by the previous one. Filters whose
    function record is missing, or whose module has no ``inlet``, are skipped.
    Loaded modules are cached in ``webui_app.state.FUNCTIONS``.

    Each inlet is called with ``body`` plus whichever of ``extra_params``,
    ``__model__`` and ``__id__`` appear in its signature. If it takes
    ``__user__`` and the module defines ``UserValves``, it gets a private copy
    of ``__user__`` carrying that filter's user valves.

    Args:
        body: chat-completion request dict; may be replaced by each inlet.
        model: model dict passed to filters as ``__model__``.
        extra_params: extra keyword arguments offered to each inlet.

    Returns:
        ``(body, {})``: the filtered body and an empty dict.

    Raises:
        Any exception raised while preparing or calling an inlet is logged and
        re-raised.
    """
    skip_files = None

    for filter_id in get_filter_function_ids(model):
        function = Functions.get_function_by_id(filter_id)
        if not function:
            continue

        if filter_id in webui_app.state.FUNCTIONS:
            function_module = webui_app.state.FUNCTIONS[filter_id]
        else:
            function_module, _, _ = load_function_module_by_id(filter_id)
            webui_app.state.FUNCTIONS[filter_id] = function_module

        # A module-level `file_handler` value controls whether attached files
        # are removed from body["metadata"] after all inlets have run.
        if hasattr(function_module, "file_handler"):
            skip_files = function_module.file_handler

        # Refresh the module's admin valves from the stored values on every call.
        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
            valves = Functions.get_function_valves_by_id(filter_id)
            function_module.valves = function_module.Valves(
                **(valves if valves else {})
            )

        if not hasattr(function_module, "inlet"):
            continue

        try:
            inlet = function_module.inlet
            sig = inspect.signature(inlet)
            params = {"body": body} | {
                k: v
                for k, v in {
                    **extra_params,
                    "__model__": model,
                    "__id__": filter_id,
                }.items()
                if k in sig.parameters
            }

            if "__user__" in params and hasattr(function_module, "UserValves"):
                # Give this filter its own shallow copy of __user__ so its
                # valves do not leak into later filters or into the caller's
                # extra_params["__user__"] dict.
                params["__user__"] = {**params["__user__"]}
                try:
                    user_valves = Functions.get_user_valves_by_id_and_user_id(
                        filter_id, params["__user__"]["id"]
                    )
                    params["__user__"]["valves"] = function_module.UserValves(
                        **(user_valves if user_valves else {})
                    )
                except Exception as e:
                    log.exception(
                        f"Failed to load user valves for filter {filter_id}: {e}"
                    )

            if inspect.iscoroutinefunction(inlet):
                body = await inlet(**params)
            else:
                body = inlet(**params)
        except Exception as e:
            log.exception(f"Error: {e}")
            raise

    # body["metadata"] may be absent or None; only an existing dict is edited.
    metadata = body.get("metadata") or {}
    if skip_files and "files" in metadata:
        del metadata["files"]

    return body, {}
2024-06-11 08:10:24 +00:00
2024-06-20 09:06:10 +00:00
2024-08-17 14:41:34 +00:00
def get_tools_function_calling_payload(messages, task_model_id, content):
    """Build the chat-completion request that asks ``task_model_id`` to pick a tool.

    Args:
        messages: chat messages; each is indexed with ``["role"]`` and
            ``["content"]``.
        task_model_id: model id placed in the payload's ``"model"`` field.
        content: system prompt (the rendered tool-calling template).

    Returns:
        A non-streaming payload dict whose metadata tags it with the
        FUNCTION_CALLING task.
    """
    query = get_last_user_message(messages)

    # Up to four of the most recent messages, newest first (messages[::-1][:4]).
    transcript = [
        f"{entry['role'].upper()}: \"\"\"{entry['content']}\"\"\""
        for entry in messages[::-1][:4]
    ]
    history = "\n".join(transcript)
    prompt = f"History:\n{history}\nQuery: {query}"

    # NOTE(review): the user turn adds a "Query: " prefix although `prompt`
    # already ends with its own "Query: ..." line, so the label appears twice.
    # Left as-is because changing it alters the prompt sent to the model.
    system_turn = {"role": "system", "content": content}
    user_turn = {"role": "user", "content": f"Query: {prompt}"}

    return {
        "model": task_model_id,
        "messages": [system_turn, user_turn],
        "stream": False,
        "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
    }
2024-08-17 14:24:11 +00:00
async def get_content_from_response(response) -> Optional[str]:
    """Extract the assistant message text from a chat-completion response.

    Accepts either a plain dict in chat-completion shape
    (``response["choices"][0]["message"]["content"]``) or a response object
    exposing an async ``body_iterator`` (such as Starlette's
    StreamingResponse). For the latter, each non-empty chunk is parsed as a
    complete chat-completion JSON object and the last one wins.

    Returns:
        The message content, or None if the iterator yielded no JSON chunk.

    Raises:
        json.JSONDecodeError / KeyError / IndexError on malformed payloads.
        Any background task attached to the response is still run first.
    """
    if not hasattr(response, "body_iterator"):
        return response["choices"][0]["message"]["content"]

    content = None
    try:
        async for chunk in response.body_iterator:
            # Chunks may arrive as str, bytes or memoryview; json.loads accepts
            # str and bytes directly.
            if isinstance(chunk, memoryview):
                chunk = chunk.tobytes()
            # Skip empty / whitespace-only keep-alive chunks.
            if not chunk or not chunk.strip():
                continue
            data = json.loads(chunk)
            content = data["choices"][0]["message"]["content"]
    finally:
        # Run any background cleanup attached to the response even when
        # parsing fails, so it is never skipped.
        if response.background is not None:
            await response.background()

    return content
2024-11-16 12:41:07 +00:00
def get_task_model_id(
    default_model_id: str, task_model: str, task_model_external: str, models
) -> str:
    """Choose the model id used for background tasks.

    If the default (chat) model is owned by ``"ollama"``, ``task_model`` is the
    candidate; otherwise ``task_model_external`` is. The candidate is used only
    when it is non-empty and present in ``models``; otherwise the default model
    id is returned.

    Raises:
        KeyError: if ``default_model_id`` is not a key of ``models``.
    """
    owned_by_ollama = models[default_model_id]["owned_by"] == "ollama"
    candidate = task_model if owned_by_ollama else task_model_external

    if not candidate or candidate not in models:
        return default_model_id
    return candidate
2024-08-11 14:16:57 +00:00
async def chat_completion_tools_handler (
2024-11-16 12:41:07 +00:00
body : dict , user : UserModel , models , extra_params : dict
2024-08-11 14:16:57 +00:00
) - > tuple [ dict , dict ] :
2024-08-19 10:11:00 +00:00
# If tool_ids field is present, call the functions
2024-08-20 14:41:49 +00:00
metadata = body . get ( " metadata " , { } )
2024-08-22 14:02:29 +00:00
2024-08-20 14:41:49 +00:00
tool_ids = metadata . get ( " tool_ids " , None )
2024-08-22 14:02:29 +00:00
log . debug ( f " { tool_ids =} " )
2024-08-19 10:11:00 +00:00
if not tool_ids :
return body , { }
2024-08-12 13:48:57 +00:00
skip_files = False
2024-11-22 03:46:09 +00:00
sources = [ ]
2024-07-02 02:33:58 +00:00
2024-11-16 12:41:07 +00:00
task_model_id = get_task_model_id (
body [ " model " ] ,
app . state . config . TASK_MODEL ,
app . state . config . TASK_MODEL_EXTERNAL ,
models ,
)
2024-08-22 14:02:29 +00:00
tools = get_tools (
webui_app ,
tool_ids ,
user ,
{
* * extra_params ,
2024-11-16 12:41:07 +00:00
" __model__ " : models [ task_model_id ] ,
2024-08-22 14:02:29 +00:00
" __messages__ " : body [ " messages " ] ,
" __files__ " : metadata . get ( " files " , [ ] ) ,
} ,
)
2024-08-17 14:24:11 +00:00
log . info ( f " { tools =} " )
2024-08-12 13:48:57 +00:00
2024-08-17 14:24:11 +00:00
specs = [ tool [ " spec " ] for tool in tools . values ( ) ]
2024-08-12 14:53:47 +00:00
tools_specs = json . dumps ( specs )
2024-08-17 14:27:11 +00:00
2024-09-07 03:50:29 +00:00
if app . state . config . TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != " " :
template = app . state . config . TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
else :
template = """ Available Tools: {{ TOOLS}} \n Return an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format { \" name \" : \" functionName \" , \" parameters \" : { \" requiredFunctionParamKey \" : \" requiredFunctionParamValue \" }} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text. """
2024-08-17 15:01:35 +00:00
tools_function_calling_prompt = tools_function_calling_generation_template (
2024-09-07 03:50:29 +00:00
template , tools_specs
2024-08-17 14:32:39 +00:00
)
2024-08-17 15:01:35 +00:00
log . info ( f " { tools_function_calling_prompt =} " )
2024-08-17 14:32:39 +00:00
payload = get_tools_function_calling_payload (
2024-08-17 15:01:35 +00:00
body [ " messages " ] , task_model_id , tools_function_calling_prompt
2024-08-17 14:32:39 +00:00
)
2024-08-17 14:24:11 +00:00
2024-08-12 14:53:47 +00:00
try :
2024-11-16 12:41:07 +00:00
payload = filter_pipeline ( payload , user , models )
2024-08-12 14:53:47 +00:00
except Exception as e :
raise e
2024-07-02 02:33:58 +00:00
2024-08-12 14:53:47 +00:00
try :
response = await generate_chat_completions ( form_data = payload , user = user )
log . debug ( f " { response =} " )
content = await get_content_from_response ( response )
log . debug ( f " { content =} " )
2024-08-17 14:27:11 +00:00
2024-08-19 16:04:57 +00:00
if not content :
2024-08-12 14:53:47 +00:00
return body , { }
2024-07-02 02:33:58 +00:00
2024-08-12 14:53:47 +00:00
try :
2024-09-28 17:51:28 +00:00
content = content [ content . find ( " { " ) : content . rfind ( " } " ) + 1 ]
if not content :
raise Exception ( " No JSON object found in the response " )
result = json . loads ( content )
tool_function_name = result . get ( " name " , None )
if tool_function_name not in tools :
return body , { }
tool_function_params = result . get ( " parameters " , { } )
try :
2024-10-26 19:21:05 +00:00
required_params = (
tools [ tool_function_name ]
. get ( " spec " , { } )
. get ( " parameters " , { } )
. get ( " required " , [ ] )
)
2024-10-25 21:36:44 +00:00
tool_function = tools [ tool_function_name ] [ " callable " ]
tool_function_params = {
2024-10-26 19:21:05 +00:00
k : v
for k , v in tool_function_params . items ( )
if k in required_params
2024-10-25 21:36:44 +00:00
}
tool_output = await tool_function ( * * tool_function_params )
2024-10-26 05:18:48 +00:00
2024-09-28 17:51:28 +00:00
except Exception as e :
tool_output = str ( e )
2024-11-22 03:46:09 +00:00
if isinstance ( tool_output , str ) :
if tools [ tool_function_name ] [ " citation " ] :
sources . append (
{
" source " : {
" name " : f " TOOL: { tools [ tool_function_name ] [ ' toolkit_id ' ] } / { tool_function_name } "
} ,
" document " : [ tool_output ] ,
" metadata " : [
{
" source " : f " TOOL: { tools [ tool_function_name ] [ ' toolkit_id ' ] } / { tool_function_name } "
}
] ,
}
)
else :
sources . append (
{
" source " : { } ,
" document " : [ tool_output ] ,
" metadata " : [
{
" source " : f " TOOL: { tools [ tool_function_name ] [ ' toolkit_id ' ] } / { tool_function_name } "
}
] ,
}
)
2024-11-22 02:26:38 +00:00
2024-11-22 03:46:09 +00:00
if tools [ tool_function_name ] [ " file_handler " ] :
skip_files = True
2024-09-28 17:51:28 +00:00
2024-08-10 10:58:18 +00:00
except Exception as e :
2024-09-28 17:51:28 +00:00
log . exception ( f " Error: { e } " )
content = None
2024-08-12 14:53:47 +00:00
except Exception as e :
2024-08-19 10:03:55 +00:00
log . exception ( f " Error: { e } " )
2024-08-12 14:53:47 +00:00
content = None
2024-08-10 10:58:18 +00:00
2024-11-22 03:46:09 +00:00
log . debug ( f " tool_contexts: { sources } " )
2024-07-02 02:33:58 +00:00
2024-08-20 14:41:49 +00:00
if skip_files and " files " in body . get ( " metadata " , { } ) :
del body [ " metadata " ] [ " files " ]
2024-07-02 02:33:58 +00:00
2024-11-22 03:46:09 +00:00
return body , { " sources " : sources }
2024-07-02 02:33:58 +00:00
2024-11-19 10:24:32 +00:00
async def chat_completion_files_handler(
    body: dict, user: UserModel
) -> tuple[dict, dict[str, list]]:
    """Retrieve document context for the files attached to a chat request.

    When ``body["metadata"]["files"]`` is non-empty, retrieval queries are
    requested from ``generate_queries``. The reply is expected to contain a
    JSON object with a ``"queries"`` list. If it does not, the whole reply is
    used as a single query. If no usable text query remains, the last user
    message is used instead. The queries are then run through
    ``get_sources_from_files`` with the retrieval app's embedding, top-k,
    reranking, relevance-threshold and hybrid-search settings.

    Args:
        body: Chat request body; read for ``"model"``, ``"messages"`` and
            ``"metadata"``. It is returned unchanged.
        user: The requesting user, forwarded to ``generate_queries``.

    Returns:
        ``(body, {"sources": sources})``; ``sources`` is empty when no files
        are attached.
    """
    sources = []

    if files := body.get("metadata", {}).get("files", None):
        try:
            queries_response = await generate_queries(
                {
                    "model": body["model"],
                    "messages": body["messages"],
                    "type": "retrieval",
                },
                user,
            )
            queries_content = queries_response["choices"][0]["message"]["content"]

            try:
                bracket_start = queries_content.find("{")
                bracket_end = queries_content.rfind("}")
                # Require a complete "{...}" span. The previous check compared
                # `rfind("}") + 1` with -1, which can never match. A reply with
                # "{" but no closing "}" was therefore sliced to "" and used as
                # an empty query.
                if bracket_start == -1 or bracket_end < bracket_start:
                    raise ValueError("No JSON object found in the response")
                queries_response = json.loads(
                    queries_content[bracket_start : bracket_end + 1]
                )
            except Exception:
                # Not a JSON object: treat the whole reply as one query.
                queries_response = {"queries": [queries_content]}

            queries = queries_response.get("queries", [])
            if not isinstance(queries, list):
                queries = [queries]
        except Exception as e:
            # Query generation is best effort; fall back to the user message.
            log.debug(f"Retrieval query generation failed: {e}")
            queries = []

        # Drop empty or non-text queries so the fallback below can apply.
        queries = [q for q in queries if isinstance(q, str) and q.strip()]
        if len(queries) == 0:
            queries = [get_last_user_message(body["messages"])]

        sources = get_sources_from_files(
            files=files,
            queries=queries,
            embedding_function=retrieval_app.state.EMBEDDING_FUNCTION,
            k=retrieval_app.state.config.TOP_K,
            reranking_function=retrieval_app.state.sentence_transformer_rf,
            r=retrieval_app.state.config.RELEVANCE_THRESHOLD,
            hybrid_search=retrieval_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
        )

    log.debug(f"rag_contexts:sources: {sources}")
    return body, {"sources": sources}
2024-07-02 02:33:58 +00:00
2024-11-16 12:41:07 +00:00
async def get_body_and_model_and_user(request, models):
    """Parse the request's JSON body and resolve its target model and user.

    Args:
        request: The incoming request; its raw body and ``Authorization``
            header are read.
        models: Mapping of model id to model dict.

    Returns:
        ``(body, model, user)``. An empty body parses to ``{}``.

    Raises:
        KeyError: the body has no ``"model"`` key.
        Exception: ``"Model not found"`` when the id is not in *models*.
    """
    raw_body = await request.body()
    decoded_body = raw_body.decode("utf-8")
    body = {} if not decoded_body else json.loads(decoded_body)

    requested_model_id = body["model"]
    if requested_model_id not in models:
        raise Exception("Model not found")
    resolved_model = models[requested_model_id]

    auth_credentials = get_http_authorization_cred(
        request.headers.get("Authorization")
    )
    current_user = get_current_user(request, auth_credentials)
    return body, resolved_model, current_user
2024-07-02 02:33:58 +00:00
class ChatCompletionMiddleware(BaseHTTPMiddleware):
    """Enrich chat-completion requests with filter, tool and retrieval output.

    Only POST requests whose path contains ``/ollama/api/chat`` or
    ``/chat/completions`` are processed; every other request is forwarded
    untouched. For processed requests the middleware does the following:

    - resolves the model and user;
    - enforces model read access for ``role == "user"``;
    - runs the filter-function, tool and file handlers;
    - injects the collected sources into the messages via ``RAG_TEMPLATE``;
    - rewrites the request body.

    When the downstream response is an SSE (``text/event-stream``) or NDJSON
    (``application/x-ndjson``) stream, the collected data items (currently
    the named citation sources) are emitted before the model's own chunks.
    """

    @staticmethod
    def _build_context_string(sources: list) -> str:
        """Render each source document as a ``<source>`` block for the prompt."""
        parts = []
        for source in sources:
            source_id = source.get("source", {}).get("name", "")
            if "document" not in source:
                continue

            # Per-document metadata; the optional "source" key overrides the
            # source-level name for that document.
            doc_metadata = source.get("metadata")
            for doc_idx, doc_context in enumerate(source["document"]):
                doc_source_id = None
                # Tolerate metadata lists shorter than the document list
                # instead of raising IndexError.
                if doc_metadata and doc_idx < len(doc_metadata):
                    doc_source_id = doc_metadata[doc_idx].get("source", source_id)

                if source_id:
                    parts.append(
                        f"<source><source_id>{doc_source_id if doc_source_id is not None else source_id}</source_id><source_context>{doc_context}</source_context></source>\n"
                    )
                else:
                    # If there is no source_id, then do not include the source_id tag
                    parts.append(
                        f"<source><source_context>{doc_context}</source_context></source>\n"
                    )
        return "".join(parts).strip()

    async def dispatch(self, request: Request, call_next):
        # Negate the whole predicate. `not POST and any(path)` negated only the
        # method check: it processed every POST to any endpoint and every
        # non-chat GET, while passing through only non-POST chat requests.
        is_chat_completion_request = request.method == "POST" and any(
            endpoint in request.url.path
            for endpoint in ["/ollama/api/chat", "/chat/completions"]
        )
        if not is_chat_completion_request:
            return await call_next(request)
        log.debug(f"request.url.path: {request.url.path}")

        model_list = await get_all_models()
        models = {model["id"]: model for model in model_list}

        try:
            body, model, user = await get_body_and_model_and_user(request, models)
        except Exception as e:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": str(e)},
            )

        model_info = Models.get_model_by_id(model["id"])
        if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
            if model.get("arena"):
                if not has_access(
                    user.id,
                    type="read",
                    access_control=model.get("info", {})
                    .get("meta", {})
                    .get("access_control", {}),
                ):
                    # Return a response instead of raising: an HTTPException
                    # raised in a BaseHTTPMiddleware is not handled by the
                    # app's exception handlers and surfaces as a 500.
                    return JSONResponse(
                        status_code=status.HTTP_403_FORBIDDEN,
                        content={"detail": "Model not found"},
                    )
            elif not model_info:
                return JSONResponse(
                    status_code=status.HTTP_404_NOT_FOUND,
                    content={"detail": "Model not found"},
                )
            elif not (
                user.id == model_info.user_id
                or has_access(
                    user.id, type="read", access_control=model_info.access_control
                )
            ):
                return JSONResponse(
                    status_code=status.HTTP_403_FORBIDDEN,
                    content={"detail": "User does not have access to the model"},
                )

        metadata = {
            "chat_id": body.pop("chat_id", None),
            "message_id": body.pop("id", None),
            "session_id": body.pop("session_id", None),
            "tool_ids": body.get("tool_ids", None),
            "files": body.get("files", None),
        }
        body["metadata"] = metadata

        extra_params = {
            "__event_emitter__": get_event_emitter(metadata),
            "__event_call__": get_event_call(metadata),
            "__user__": {
                "id": user.id,
                "email": user.email,
                "name": user.name,
                "role": user.role,
            },
            "__metadata__": metadata,
        }

        # Items sent to the client ahead of the model's streamed output.
        data_items = []
        # Sources gathered from tools and file retrieval.
        sources = []

        try:
            body, flags = await chat_completion_filter_functions_handler(
                body, model, extra_params
            )
        except Exception as e:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": str(e)},
            )

        # Move tool_ids/files (as left by the filters) from the top level of
        # the body into its metadata.
        tool_ids = body.pop("tool_ids", None)
        files = body.pop("files", None)
        metadata = {
            **metadata,
            "tool_ids": tool_ids,
            "files": files,
        }
        body["metadata"] = metadata

        # Tool and file handlers are best effort: failures are logged only.
        try:
            body, flags = await chat_completion_tools_handler(
                body, user, models, extra_params
            )
            sources.extend(flags.get("sources", []))
        except Exception as e:
            log.exception(e)

        try:
            body, flags = await chat_completion_files_handler(body, user)
            sources.extend(flags.get("sources", []))
        except Exception as e:
            log.exception(e)

        # If context is not empty, insert it into the messages
        if len(sources) > 0:
            context_string = self._build_context_string(sources)

            prompt = get_last_user_message(body["messages"])
            if prompt is None:
                # Returned rather than raised so the client gets a 400, not a 500.
                return JSONResponse(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    content={"detail": "No user message found"},
                )

            if (
                retrieval_app.state.config.RELEVANCE_THRESHOLD == 0
                and context_string.strip() == ""
            ):
                log.debug(
                    f"With a 0 relevancy threshold for RAG, the context cannot be empty"
                )

            rag_prompt = rag_template(
                retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt
            )
            # Workaround for Ollama 2.0+ system prompt issue
            # TODO: replace with add_or_update_system_message
            if model["owned_by"] == "ollama":
                body["messages"] = prepend_to_first_user_message_content(
                    rag_prompt, body["messages"]
                )
            else:
                body["messages"] = add_or_update_system_message(
                    rag_prompt, body["messages"]
                )

        # Only sources with a display name are reported to the client.
        sources = [
            source for source in sources if source.get("source", {}).get("name", "")
        ]
        if len(sources) > 0:
            data_items.append({"sources": sources})

        modified_body_bytes = json.dumps(body).encode("utf-8")
        # Replace the request body with the modified one
        request._body = modified_body_bytes
        # Set custom header to ensure content-length matches new body length
        request.headers.__dict__["_list"] = [
            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
        ]

        response = await call_next(request)
        if not isinstance(response, StreamingResponse):
            return response

        # .get: a streaming response without Content-Type would otherwise
        # raise KeyError here.
        content_type = response.headers.get("Content-Type", "")
        is_openai = "text/event-stream" in content_type
        is_ollama = "application/x-ndjson" in content_type
        if not is_openai and not is_ollama:
            return response

        def wrap_item(item):
            return f"data: {item}\n\n" if is_openai else f"{item}\n"

        async def stream_wrapper(original_generator, data_items):
            for item in data_items:
                yield wrap_item(json.dumps(item))

            async for data in original_generator:
                yield data

        return StreamingResponse(
            stream_wrapper(response.body_iterator, data_items),
            headers=dict(response.headers),
        )

    async def _receive(self, body: bytes):
        return {"type": "http.request", "body": body, "more_body": False}


app.add_middleware(ChatCompletionMiddleware)
2024-03-09 06:34:47 +00:00
2024-10-10 21:00:05 +00:00
2024-06-20 08:51:39 +00:00
##################################
#
# Pipeline Middleware
#
##################################
2024-03-09 06:34:47 +00:00
2024-11-16 12:41:07 +00:00
def get_sorted_filters(model_id, models):
    """Return the filter pipelines that apply to *model_id*, lowest priority first.

    A model entry is a filter when ``entry["pipeline"]["type"] == "filter"``.
    It applies when its ``pipeline["pipelines"]`` target list is exactly
    ``["*"]`` or contains *model_id*.

    Args:
        model_id: Id of the model the request targets.
        models: Mapping of model id to model dict.

    Returns:
        A new list of the matching model dicts, sorted by
        ``pipeline["priority"]``. The sort is stable, so ties keep mapping
        order. An entry without ``"pipelines"`` matches nothing, and an entry
        without ``"priority"`` sorts as priority 0; previously both raised
        KeyError and aborted the whole filter chain.
    """

    def applies_to_model(entry) -> bool:
        pipeline = entry.get("pipeline")
        if not isinstance(pipeline, dict) or pipeline.get("type") != "filter":
            return False
        targets = pipeline.get("pipelines", [])
        return targets == ["*"] or any(
            model_id == target_model_id for target_model_id in targets
        )

    filters = [entry for entry in models.values() if applies_to_model(entry)]
    return sorted(filters, key=lambda entry: entry["pipeline"].get("priority", 0))
2024-11-16 12:41:07 +00:00
def filter_pipeline(payload, user, models):
    """Pass *payload* through every pipeline inlet filter that applies to its model.

    Filters come from ``get_sorted_filters``. If the target model itself has a
    ``"pipeline"`` entry, it runs last. For each filter, the inlet endpoint
    ``{url}/{filter id}/filter/inlet`` of the OpenAI base URL at index
    ``urlIdx`` is POSTed with the user and the current payload, and the JSON
    reply replaces the payload. Filters whose API key is empty are skipped.

    A failing filter is logged and skipped (best effort), unless the server
    replied with a JSON error body containing ``"detail"``.

    Args:
        payload: Chat request body; must contain ``"model"``.
        user: Calling user; ``id``, ``email``, ``name`` and ``role`` are sent.
        models: Mapping of model id to model dict; must contain the payload's
            model.

    Returns:
        The payload after all filters have been applied.

    Raises:
        Exception: ``Exception(status_code, detail)`` for a filter error reply
            that carries ``"detail"``.
    """
    user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
    model_id = payload["model"]

    sorted_filters = get_sorted_filters(model_id, models)
    model = models[model_id]

    if "pipeline" in model:
        sorted_filters.append(model)

    # NOTE(review): this uses blocking `requests` and is called from async
    # middleware, so it blocks the event loop while a filter responds.
    for pipeline_filter in sorted_filters:
        r = None
        try:
            url_idx = pipeline_filter["urlIdx"]
            url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
            key = app.state.config.OPENAI_API_KEYS[url_idx]

            if key == "":
                continue

            headers = {"Authorization": f"Bearer {key}"}
            r = requests.post(
                f"{url}/{pipeline_filter['id']}/filter/inlet",
                headers=headers,
                json={
                    "user": user,
                    "body": payload,
                },
            )

            r.raise_for_status()
            payload = r.json()
        except Exception as e:
            # Handle connection error here
            log.exception(f"Connection error: {e}")
            if r is not None:
                # The error body may not be JSON. Re-parsing it unguarded used
                # to raise JSONDecodeError and mask the original failure.
                try:
                    res = r.json()
                except ValueError:
                    res = None
                if isinstance(res, dict) and "detail" in res:
                    raise Exception(r.status_code, res["detail"]) from e

    return payload
2024-05-28 02:03:26 +00:00
class PipelineMiddleware(BaseHTTPMiddleware):
    """Run pipeline inlet filters over chat-completion request bodies.

    POST requests whose path contains ``/ollama/api/chat`` or
    ``/chat/completions`` are handled as follows:

    - the user is authenticated from the ``Authorization`` header;
    - the JSON body is passed through ``filter_pipeline``;
    - the request is forwarded with the filtered body.

    Every other request passes through unchanged.
    """

    async def dispatch(self, request: Request, call_next):
        # Negate the whole predicate. `not POST and any(path)` negated only the
        # method check and intercepted every POST, including unauthenticated
        # ones such as sign-in, which then got a 401.
        is_chat_completion_request = request.method == "POST" and any(
            endpoint in request.url.path
            for endpoint in ["/ollama/api/chat", "/chat/completions"]
        )
        if not is_chat_completion_request:
            return await call_next(request)

        log.debug(f"request.url.path: {request.url.path}")

        # Read and parse the original request body. Malformed input is a
        # client error, not a 500.
        body = await request.body()
        try:
            body_str = body.decode("utf-8")
            data = json.loads(body_str) if body_str else {}
        except ValueError as e:  # covers UnicodeDecodeError and JSONDecodeError
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": f"Invalid JSON body: {e}"},
            )

        try:
            user = get_current_user(
                request,
                get_http_authorization_cred(request.headers["Authorization"]),
            )
        except KeyError as e:
            if len(e.args) > 1:
                return JSONResponse(
                    status_code=e.args[0],
                    content={"detail": e.args[1]},
                )
            else:
                return JSONResponse(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    content={"detail": "Not authenticated"},
                )
        except HTTPException as e:
            return JSONResponse(
                status_code=e.status_code,
                content={"detail": e.detail},
            )

        model_list = await get_all_models()
        models = {model["id"]: model for model in model_list}

        try:
            data = filter_pipeline(data, user, models)
        except Exception as e:
            # filter_pipeline raises Exception(status_code, detail) for
            # filter-server errors; anything else is reported as a 400.
            if len(e.args) > 1:
                return JSONResponse(
                    status_code=e.args[0],
                    content={"detail": e.args[1]},
                )
            else:
                return JSONResponse(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    content={"detail": str(e)},
                )

        modified_body_bytes = json.dumps(data).encode("utf-8")
        # Replace the request body with the modified one
        request._body = modified_body_bytes
        # Set custom header to ensure content-length matches new body length
        request.headers.__dict__["_list"] = [
            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
        ]

        response = await call_next(request)
        return response

    async def _receive(self, body: bytes):
        return {"type": "http.request", "body": body, "more_body": False}


app.add_middleware(PipelineMiddleware)
2024-10-08 01:19:13 +00:00
class RedirectMiddleware(BaseHTTPMiddleware):
    """Redirect YouTube-style ``.../watch?v=<id>`` GET requests to ``/?youtube=<id>``.

    Only the first ``v`` value is used. Every other request, including
    non-GET requests, is forwarded unchanged.
    """

    async def dispatch(self, request: Request, call_next):
        if request.method != "GET":
            return await call_next(request)

        request_path = request.url.path
        parsed_params = parse_qs(urlparse(str(request.url)).query)
        if request_path.endswith("/watch") and "v" in parsed_params:
            redirect_query = urlencode({"youtube": parsed_params["v"][0]})
            return RedirectResponse(url=f"/?{redirect_query}")

        return await call_next(request)


# Register the redirect handling, then the security-header middleware.
app.add_middleware(RedirectMiddleware)
app.add_middleware(SecurityHeadersMiddleware)
2024-05-28 16:50:17 +00:00
2024-06-24 11:06:15 +00:00
@app.middleware("http")
async def commit_session_after_request(request: Request, call_next):
    """Commit the shared database ``Session`` once the request has been handled.

    The commit runs only after ``call_next`` returns. If the downstream
    handler raises, the exception propagates before the commit is reached.
    """
    handled_response = await call_next(request)
    Session.commit()
    return handled_response
2024-05-28 16:50:17 +00:00
2023-11-15 00:28:51 +00:00
@app.middleware("http")
async def check_url(request: Request, call_next):
    """Publish ENABLE_API_KEY on ``request.state`` and report the handling time.

    The ``X-Process-Time`` response header is the difference between two
    whole-second ``time.time()`` readings taken around ``call_next``.
    """
    request_started = int(time.time())
    request.state.enable_api_key = webui_app.state.config.ENABLE_API_KEY

    response = await call_next(request)
    response.headers["X-Process-Time"] = str(int(time.time()) - request_started)
    return response
2024-09-08 10:54:56 +00:00
@app.middleware("http")
async def inspect_websocket(request: Request, call_next):
    """Reject Socket.IO websocket-transport requests that are not real upgrades.

    A request to a ``/ws/socket.io`` path with ``transport=websocket`` must
    send ``Upgrade: websocket`` and list ``upgrade`` among its ``Connection``
    tokens; otherwise a 400 JSON response is returned. All other requests
    pass through.
    """
    if (
        "/ws/socket.io" in request.url.path
        and request.query_params.get("transport") == "websocket"
    ):
        upgrade = (request.headers.get("Upgrade") or "").strip().lower()
        # Connection is a comma-separated token list that may contain optional
        # whitespace (e.g. "keep-alive, Upgrade"). Strip each token so such
        # valid handshakes are not rejected as " upgrade" != "upgrade".
        connection = [
            token.strip()
            for token in (request.headers.get("Connection") or "").lower().split(",")
        ]
        # Check that there's the correct headers for an upgrade, else reject the connection
        # This is to work around this upstream issue: https://github.com/miguelgrinberg/python-engineio/issues/367
        if upgrade != "websocket" or "upgrade" not in connection:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": "Invalid WebSocket upgrade request"},
            )
    return await call_next(request)
2024-12-10 08:54:13 +00:00
# CORS: origins come from CORS_ALLOW_ORIGIN; credentials and every method and
# header are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGIN,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount `socket_app` under /ws.
app.mount("/ws", socket_app)

# (router module, URL prefix, OpenAPI tag), registered in exactly this order.
for _router_module, _router_prefix, _router_tag in (
    (ollama, "/ollama", "ollama"),
    (openai, "/openai", "openai"),
    (pipelines, "/pipelines", "pipelines"),
    (tasks, "/tasks", "tasks"),
    (images, "/api/v1/images", "images"),
    (audio, "/api/v1/audio", "audio"),
    (retrieval, "/api/v1/retrieval", "retrieval"),
    (configs, "/api/v1/configs", "configs"),
    (auths, "/api/v1/auths", "auths"),
    (users, "/api/v1/users", "users"),
    (chats, "/api/v1/chats", "chats"),
    (models, "/api/v1/models", "models"),
    (knowledge, "/api/v1/knowledge", "knowledge"),
    (prompts, "/api/v1/prompts", "prompts"),
    (tools, "/api/v1/tools", "tools"),
    (memories, "/api/v1/memories", "memories"),
    (folders, "/api/v1/folders", "folders"),
    (groups, "/api/v1/groups", "groups"),
    (files, "/api/v1/files", "files"),
    (functions, "/api/v1/functions", "functions"),
    (evaluations, "/api/v1/evaluations", "evaluations"),
    (utils, "/api/v1/utils", "utils"),
):
    app.include_router(
        _router_module.router, prefix=_router_prefix, tags=[_router_tag]
    )
# Keep the loop variables out of the module namespace.
del _router_module, _router_prefix, _router_tag
2024-05-19 15:00:07 +00:00
2024-03-31 20:59:39 +00:00
2024-12-12 02:36:59 +00:00
##################################
#
# Chat Endpoints
#
##################################
def get_function_module(pipe_id: str):
    """Return the function module for *pipe_id*, loading and caching it once.

    Modules are cached in ``app.state.FUNCTIONS``. When the module exposes
    both ``valves`` and a ``Valves`` class, its valves are rebuilt from the
    stored settings on every call, so they are not cached with the module.
    """
    if pipe_id in app.state.FUNCTIONS:
        function_module = app.state.FUNCTIONS[pipe_id]
    else:
        function_module, _, _ = load_function_module_by_id(pipe_id)
        app.state.FUNCTIONS[pipe_id] = function_module

    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
        stored_valves = Functions.get_function_valves_by_id(pipe_id)
        function_module.valves = function_module.Valves(**(stored_valves or {}))

    return function_module
async def get_function_models():
    """List every active "pipe" function as OpenAI-style model entries.

    A function module with a ``pipes`` attribute is a manifold. ``pipes`` is
    either a list of dicts with ``"id"`` and ``"name"`` or a callable
    returning one. Each entry becomes a model with id ``"<function id>.<sub
    id>"``, prefixed in name by the module's ``name`` when it has one. Any
    other module becomes a single model named after the function.
    """
    model_entries = []
    for pipe in Functions.get_functions_by_type("pipe", active_only=True):
        module = get_function_module(pipe.id)

        if not hasattr(module, "pipes"):
            log.debug(
                f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
            )
            model_entries.append(
                {
                    "id": pipe.id,
                    "name": pipe.name,
                    "object": "model",
                    "created": pipe.created_at,
                    "owned_by": "openai",
                    "pipe": {"type": "pipe"},
                }
            )
            continue

        # A failure while resolving `pipes` is logged and treated as "no
        # sub-pipes".
        try:
            sub_pipes = module.pipes() if callable(module.pipes) else module.pipes
        except Exception as e:
            log.exception(e)
            sub_pipes = []

        log.debug(
            f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
        )

        for sub_pipe in sub_pipes:
            entry_id = f'{pipe.id}.{sub_pipe["id"]}'
            entry_name = sub_pipe["name"]
            if hasattr(module, "name"):
                entry_name = f"{module.name}{entry_name}"
            model_entries.append(
                {
                    "id": entry_id,
                    "name": entry_name,
                    "object": "model",
                    "created": pipe.created_at,
                    "owned_by": "openai",
                    "pipe": {"type": pipe.type},
                }
            )

    return model_entries
async def generate_function_chat_completion(form_data, user, models: dict | None = None):
    """Run a chat completion through a "pipe" function module.

    ``form_data["model"]`` names the pipe: ``"<pipe_id>"``, or
    ``"<pipe_id>.<sub_id>"`` for a manifold sub-pipe. Stored model presets
    for that id are applied to the body. The pipe is then called with
    ``body`` plus whichever ``__dunder__`` extras its signature declares, and
    its result is adapted to OpenAI response shapes.

    Args:
        form_data: Chat request body. Its ``"metadata"`` key is popped here.
        user: Calling user; ``id``, ``email``, ``name`` and ``role`` are read.
        models: Mapping of model id to model dict, used for the tools'
            ``__model__``. ``None`` (the default) means an empty mapping.

    Returns:
        If ``form_data["stream"]`` is truthy, a ``StreamingResponse`` of SSE
        chunks. Otherwise one of the following:

        - the pipe's dict or ``StreamingResponse`` unchanged;
        - a dumped pydantic model;
        - ``{"error": {"detail": ...}}`` when the pipe raises;
        - an OpenAI completion message wrapping the pipe's text.
    """
    # Avoid a shared mutable default argument.
    if models is None:
        models = {}

    async def execute_pipe(pipe, params):
        if inspect.iscoroutinefunction(pipe):
            return await pipe(**params)
        else:
            return pipe(**params)

    async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
        if isinstance(res, str):
            return res
        # Any synchronous iterator is joined, not only generators. This
        # matches the streaming branch, which accepts any Iterator.
        if isinstance(res, Iterator):
            return "".join(map(str, res))
        if isinstance(res, AsyncGenerator):
            return "".join([str(stream) async for stream in res])

    def process_line(form_data: dict, line):
        # Normalize one yielded item into a complete SSE "data:" event.
        if isinstance(line, BaseModel):
            line = line.model_dump_json()
            line = f"data: {line}"
        if isinstance(line, dict):
            line = f"data: {json.dumps(line)}"

        try:
            line = line.decode("utf-8")
        except Exception:
            pass

        if line.startswith("data:"):
            return f"{line}\n\n"
        else:
            line = openai_chat_chunk_message_template(form_data["model"], line)
            return f"data: {json.dumps(line)}\n\n"

    def get_pipe_id(form_data: dict) -> str:
        # Manifold sub-pipes are addressed as "<pipe_id>.<sub_id>".
        pipe_id = form_data["model"]
        if "." in pipe_id:
            pipe_id, _ = pipe_id.split(".", 1)
        return pipe_id

    def get_function_params(function_module, form_data, user, extra_params=None):
        if extra_params is None:
            extra_params = {}

        pipe_id = get_pipe_id(form_data)

        # Only pass the extras the pipe's signature asks for.
        sig = inspect.signature(function_module.pipe)
        params = {"body": form_data} | {
            k: v for k, v in extra_params.items() if k in sig.parameters
        }

        if "__user__" in params and hasattr(function_module, "UserValves"):
            user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
            try:
                params["__user__"]["valves"] = function_module.UserValves(**user_valves)
            except Exception as e:
                log.exception(e)
                params["__user__"]["valves"] = function_module.UserValves()

        return params

    model_id = form_data.get("model")
    model_info = Models.get_model_by_id(model_id)

    # "metadata", "files" and "tool_ids" may be absent or explicitly None.
    # ChatCompletionMiddleware stores None for "files" and "tool_ids" when
    # the request had none.
    metadata = form_data.pop("metadata", None) or {}
    files = metadata.get("files") or []
    tool_ids = metadata.get("tool_ids") or []

    __event_emitter__ = None
    __event_call__ = None
    __task__ = None
    __task_body__ = None

    if metadata:
        if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
            __event_emitter__ = get_event_emitter(metadata)
            __event_call__ = get_event_call(metadata)
        __task__ = metadata.get("task", None)
        __task_body__ = metadata.get("task_body", None)

    extra_params = {
        "__event_emitter__": __event_emitter__,
        "__event_call__": __event_call__,
        "__task__": __task__,
        "__task_body__": __task_body__,
        "__files__": files,
        "__user__": {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
        },
        "__metadata__": metadata,
    }
    extra_params["__tools__"] = get_tools(
        app,
        tool_ids,
        user,
        {
            **extra_params,
            "__model__": models.get(form_data["model"], None),
            "__messages__": form_data["messages"],
            "__files__": files,
        },
    )

    if model_info:
        if model_info.base_model_id:
            form_data["model"] = model_info.base_model_id

        params = model_info.params.model_dump()
        form_data = apply_model_params_to_body_openai(params, form_data)
        form_data = apply_model_system_prompt_to_body(params, form_data, user)

    pipe_id = get_pipe_id(form_data)
    function_module = get_function_module(pipe_id)

    pipe = function_module.pipe
    params = get_function_params(function_module, form_data, user, extra_params)

    if form_data.get("stream", False):

        async def stream_content():
            try:
                res = await execute_pipe(pipe, params)

                # Directly return if the response is a StreamingResponse
                if isinstance(res, StreamingResponse):
                    async for data in res.body_iterator:
                        yield data
                    return
                if isinstance(res, dict):
                    yield f"data: {json.dumps(res)}\n\n"
                    return

            except Exception as e:
                log.error(f"Error: {e}")
                yield f"data: {json.dumps({'error': {'detail': str(e)}})}\n\n"
                return

            # Errors raised while the pipe's output is being consumed used to
            # escape and abort the stream; report them like start-up errors.
            try:
                if isinstance(res, str):
                    message = openai_chat_chunk_message_template(form_data["model"], res)
                    yield f"data: {json.dumps(message)}\n\n"

                if isinstance(res, Iterator):
                    for line in res:
                        yield process_line(form_data, line)

                if isinstance(res, AsyncGenerator):
                    async for line in res:
                        yield process_line(form_data, line)
            except Exception as e:
                log.error(f"Error: {e}")
                yield f"data: {json.dumps({'error': {'detail': str(e)}})}\n\n"
                return

            if isinstance(res, str) or isinstance(res, Generator):
                finish_message = openai_chat_chunk_message_template(
                    form_data["model"], ""
                )
                finish_message["choices"][0]["finish_reason"] = "stop"
                yield f"data: {json.dumps(finish_message)}\n\n"
                # An SSE event is dispatched only after its terminating blank
                # line. Without "\n\n", compliant parsers discard [DONE].
                yield "data: [DONE]\n\n"

        return StreamingResponse(stream_content(), media_type="text/event-stream")
    else:
        try:
            res = await execute_pipe(pipe, params)

        except Exception as e:
            log.error(f"Error: {e}")
            return {"error": {"detail": str(e)}}

        if isinstance(res, StreamingResponse) or isinstance(res, dict):
            return res
        if isinstance(res, BaseModel):
            return res.model_dump()

        message = await get_message_content(res)
        return openai_chat_completion_message_template(form_data["model"], message)
2024-11-16 03:14:24 +00:00
async def get_all_base_models():
    """Collect the raw model entries from every enabled backend.

    Backends are queried in this order: OpenAI-compatible APIs (when
    ``ENABLE_OPENAI_API``), Ollama (when ``ENABLE_OLLAMA_API``; each entry is
    reshaped into an OpenAI-style model record), then function models. When
    ``ENABLE_EVALUATION_ARENA_MODELS`` is set, arena entries are appended: the
    configured ``EVALUATION_ARENA_MODELS``, or ``DEFAULT_ARENA_MODEL`` if that
    list is empty.

    Returns:
        list[dict]: function models, then OpenAI models, then Ollama models,
        then any arena models.
    """
    config = app.state.config

    openai_models = []
    if config.ENABLE_OPENAI_API:
        openai_response = await openai.get_all_models()
        openai_models = openai_response["data"]

    ollama_models = []
    if config.ENABLE_OLLAMA_API:
        ollama_response = await ollama.get_all_models()
        ollama_models = [
            {
                "id": entry["model"],
                "name": entry["name"],
                "object": "model",
                "created": int(time.time()),
                "owned_by": "ollama",
                "ollama": entry,
            }
            for entry in ollama_response["models"]
        ]

    function_models = await get_function_models()
    models = function_models + openai_models + ollama_models

    if config.ENABLE_EVALUATION_ARENA_MODELS:
        arena_sources = config.EVALUATION_ARENA_MODELS
        if len(arena_sources) == 0:
            # No arena models configured: fall back to the built-in default one.
            arena_sources = [DEFAULT_ARENA_MODEL]
        models = models + [
            {
                "id": arena["id"],
                "name": arena["name"],
                "info": {"meta": arena["meta"]},
                "object": "model",
                "created": int(time.time()),
                "owned_by": "arena",
                "arena": True,
            }
            for arena in arena_sources
        ]

    return models
2024-11-21 20:16:50 +00:00
@cached(ttl=3)
async def get_all_models():
    """Return every model visible to the UI, merged with custom-model data and actions.

    Starts from ``get_all_base_models()`` and then:

    * applies custom models from ``Models.get_all_models()``:
        - a custom model without ``base_model_id`` overrides the base model(s)
          whose id equals it, or whose id matches it once everything from the
          first ``":"`` is stripped. When active, it renames and annotates them.
          When inactive, it removes them.
        - an active custom model with a ``base_model_id`` (and an id not already
          listed) is appended as a ``"preset"`` entry. The entry inherits
          ``owned_by`` (default ``"openai"``) and ``pipe`` from its base model.
    * resolves each model's enabled action functions (model-specific plus
      global) into ``model["actions"]``.

    The result is also stored in ``app.state.MODELS`` as ``{model_id: model}``.
    Results are cached for 3 seconds (``aiocache.cached``).

    Returns:
        list[dict]: the model entries, or ``[]`` when only arena models exist.

    Raises:
        Exception: if an enabled action id has no matching function record.
    """
    models = await get_all_base_models()

    # If there are no real models (arena entries alone do not count), return an empty list
    if not any(not model.get("arena", False) for model in models):
        return []

    global_action_ids = [
        function.id for function in Functions.get_global_action_functions()
    ]
    enabled_action_ids = {
        function.id
        for function in Functions.get_functions_by_type("action", active_only=True)
    }

    def refers_to(reference_id, model):
        # A reference matches a model by full id, or by the id with everything
        # from the first ":" stripped.
        return reference_id == model["id"] or reference_id == model["id"].split(":")[0]

    for custom_model in Models.get_all_models():
        if custom_model.base_model_id is None:
            if custom_model.is_active:
                for model in models:
                    if refers_to(custom_model.id, model):
                        model["name"] = custom_model.name
                        model["info"] = custom_model.model_dump()
                        meta = model["info"].get("meta") or {}
                        model["action_ids"] = list(meta.get("actionIds") or [])
            else:
                # Inactive override: hide every matching base model. Rebuild the
                # list instead of calling list.remove() while iterating, which
                # skipped the entry following each removed one.
                models = [
                    model for model in models if not refers_to(custom_model.id, model)
                ]
        elif custom_model.is_active and custom_model.id not in {
            model["id"] for model in models
        }:
            owned_by = "openai"
            pipe = None
            for model in models:
                if refers_to(custom_model.base_model_id, model):
                    owned_by = model["owned_by"]
                    if "pipe" in model:
                        pipe = model["pipe"]
                    break

            action_ids = []
            if custom_model.meta:
                meta = custom_model.meta.model_dump()
                action_ids.extend(meta.get("actionIds") or [])

            models.append(
                {
                    "id": f"{custom_model.id}",
                    "name": custom_model.name,
                    "object": "model",
                    "created": custom_model.created_at,
                    "owned_by": owned_by,
                    "info": custom_model.model_dump(),
                    "preset": True,
                    **({"pipe": pipe} if pipe is not None else {}),
                    "action_ids": action_ids,
                }
            )

    def get_action_items_from_module(function, module):
        """Expand an action function into the action entries attached to a model.

        A module with an ``actions`` list yields one entry per sub-action, with
        id ``"<function_id>.<action_id>"``. Any other module yields a single
        entry for the function itself.
        """
        if hasattr(module, "actions"):
            return [
                {
                    "id": f"{function.id}.{action['id']}",
                    "name": action.get("name", f"{function.name} ({action['id']})"),
                    "description": function.meta.description,
                    "icon_url": action.get(
                        "icon_url", function.meta.manifest.get("icon_url", None)
                    ),
                }
                for action in module.actions
            ]
        return [
            {
                "id": function.id,
                "name": function.name,
                "description": function.meta.description,
                "icon_url": function.meta.manifest.get("icon_url", None),
            }
        ]

    def get_function_module_by_id(function_id):
        """Return the loaded module for *function_id*, caching it in ``webui_app.state.FUNCTIONS``."""
        if function_id in webui_app.state.FUNCTIONS:
            function_module = webui_app.state.FUNCTIONS[function_id]
        else:
            function_module, _, _ = load_function_module_by_id(function_id)
            webui_app.state.FUNCTIONS[function_id] = function_module
        # The module must be returned: the caller inspects it for an `actions` list.
        return function_module

    for model in models:
        # Merge model-specific and global action ids, de-duplicating while keeping
        # a stable order (a set made the action order vary between runs), and keep
        # only enabled actions.
        action_ids = [
            action_id
            for action_id in dict.fromkeys(
                model.pop("action_ids", []) + global_action_ids
            )
            if action_id in enabled_action_ids
        ]
        model["actions"] = []
        for action_id in action_ids:
            action_function = Functions.get_function_by_id(action_id)
            if action_function is None:
                raise Exception(f"Action not found: {action_id}")

            function_module = get_function_module_by_id(action_id)
            model["actions"].extend(
                get_action_items_from_module(action_function, function_module)
            )

    log.debug(f"get_all_models() returned {len(models)} models")

    app.state.MODELS = {model["id"]: model for model in models}
    return models
@app.get("/api/models")
async def get_models(user=Depends(get_verified_user)):
    """List the models the requesting user may use.

    Filter pipelines are hidden. When ``MODEL_ORDER_LIST`` is set, models are
    ordered by their position in it; unlisted ids sort last, and ties are
    broken by name. For users with role ``"user"`` (unless
    ``BYPASS_MODEL_ACCESS_CONTROL`` is set), only models the user owns or has
    read access to are kept.
    """

    def is_filter_pipeline(entry):
        return "pipeline" in entry and entry["pipeline"].get("type", None) == "filter"

    all_models = await get_all_models()
    models = [entry for entry in all_models if not is_filter_pipeline(entry)]

    order_list = webui_app.state.config.MODEL_ORDER_LIST
    if order_list:
        rank = {model_id: position for position, model_id in enumerate(order_list)}
        models.sort(
            key=lambda entry: (rank.get(entry["id"], float("inf")), entry["name"])
        )

    if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:

        def user_can_read(entry):
            # Arena models carry their access control in their own metadata.
            if entry.get("arena"):
                arena_access = (
                    entry.get("info", {}).get("meta", {}).get("access_control", {})
                )
                return has_access(user.id, type="read", access_control=arena_access)

            # Everything else must have a DB record the user owns or can read.
            model_info = Models.get_model_by_id(entry["id"])
            if not model_info:
                return False
            return user.id == model_info.user_id or has_access(
                user.id, type="read", access_control=model_info.access_control
            )

        models = [entry for entry in models if user_can_read(entry)]

    log.debug(
        f"/api/models returned filtered models accessible to the user: {json.dumps([model['id'] for model in models])}"
    )

    return {"data": models}
2024-11-16 03:14:24 +00:00
@app.get("/api/models/base")
async def get_base_models(user=Depends(get_admin_user)):
    """Return the raw backend models without arena entries (admin only)."""
    base_models = await get_all_base_models()
    non_arena = [entry for entry in base_models if not entry.get("arena", False)]
    return {"data": non_arena}
2024-12-12 02:05:42 +00:00
@app.post("/api/chat/completions")
async def generate_chat_completions(
    form_data: dict,
    user=Depends(get_verified_user),
    bypass_filter: bool = False,
):
    """Route an OpenAI-style chat completion request to the backend owning the model.

    Args:
        form_data: the request body. ``form_data["model"]`` selects the model.
        user: the authenticated user.
        bypass_filter: skip the per-user model access check. It is forced on
            when ``BYPASS_MODEL_ACCESS_CONTROL`` is set, and is used for the
            recursive call that dispatches an arena's randomly selected model.

    Returns:
        The backend response: a ``StreamingResponse`` for streamed requests,
        otherwise the completion payload. Arena responses also carry
        ``selected_model_id``. When streaming, it is sent as a leading SSE event.

    Raises:
        HTTPException: 404 if the model is unknown, or if an arena has no
            candidate model. 403 if the user may not read the model.
    """
    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True

    # app.state.MODELS is already a {model_id: model} mapping built by
    # get_all_models(). Iterating it as if it were a list yields id strings, not
    # model dicts, so use it directly. Build it on demand if it is not populated yet.
    models = getattr(app.state, "MODELS", None)
    if not models:
        models = {model["id"]: model for model in await get_all_models()}

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )
    model = models[model_id]

    # Check if user has access to the model
    if not bypass_filter and user.role == "user":
        if model.get("arena"):
            if not has_access(
                user.id,
                type="read",
                access_control=model.get("info", {})
                .get("meta", {})
                .get("access_control", {}),
            ):
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )
        else:
            model_info = Models.get_model_by_id(model_id)
            if not model_info:
                raise HTTPException(
                    status_code=404,
                    detail="Model not found",
                )
            elif not (
                user.id == model_info.user_id
                or has_access(
                    user.id, type="read", access_control=model_info.access_control
                )
            ):
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )

    if model["owned_by"] == "arena":
        arena_meta = model.get("info", {}).get("meta", {})
        model_ids = arena_meta.get("model_ids")
        filter_mode = arena_meta.get("filter_mode")
        if model_ids and filter_mode == "exclude":
            model_ids = [
                candidate["id"]
                for candidate in await get_all_models()
                if candidate.get("owned_by") != "arena"
                and candidate["id"] not in model_ids
            ]

        if not (isinstance(model_ids, list) and model_ids):
            # No usable explicit selection: choose among all non-arena models.
            model_ids = [
                candidate["id"]
                for candidate in await get_all_models()
                if candidate.get("owned_by") != "arena"
            ]
        if not model_ids:
            # random.choice([]) would raise IndexError (a 500); report it as not found.
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Model not found",
            )
        selected_model_id = random.choice(model_ids)

        form_data["model"] = selected_model_id

        if form_data.get("stream") == True:

            async def stream_wrapper(stream):
                # Announce which model the arena picked before relaying its stream.
                yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
                async for chunk in stream:
                    yield chunk

            response = await generate_chat_completions(
                form_data, user, bypass_filter=True
            )
            return StreamingResponse(
                stream_wrapper(response.body_iterator), media_type="text/event-stream"
            )
        else:
            return {
                **(
                    await generate_chat_completions(form_data, user, bypass_filter=True)
                ),
                "selected_model_id": selected_model_id,
            }

    if model.get("pipe"):
        # No bypass_filter is passed: this route is the function's only caller and
        # has already performed the access check above.
        return await generate_function_chat_completion(
            form_data, user=user, models=models
        )

    if model["owned_by"] == "ollama":
        # Using /ollama/api/chat endpoint
        form_data = convert_payload_openai_to_ollama(form_data)
        form_data = GenerateChatCompletionForm(**form_data)
        response = await generate_ollama_chat_completion(
            form_data=form_data, user=user, bypass_filter=bypass_filter
        )
        if form_data.stream:
            response.headers["content-type"] = "text/event-stream"
            return StreamingResponse(
                convert_streaming_response_ollama_to_openai(response),
                headers=dict(response.headers),
            )
        else:
            return convert_response_ollama_to_openai(response)
    else:
        return await generate_openai_chat_completion(
            form_data, user=user, bypass_filter=bypass_filter
        )
@app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
    """Run the "outlet" post-processing stage on a finished chat message.

    Stage 1 sends the body through the pipeline filters returned by
    ``get_sorted_filters``, preceded by the model itself when it is a pipeline.
    Each is called via ``POST {url}/{filter_id}/filter/outlet``, and only for
    connections with a non-empty API key.

    Stage 2 runs every enabled filter function (global ones plus the model's
    ``filterIds``) that defines an ``outlet`` hook, in ascending ``priority``
    order. Each step receives the previous step's output as ``body``.

    Returns:
        The final body. A ``JSONResponse`` is returned instead when a pipeline
        answers with an error ``detail``, and a 400 when a filter function raises.

    Raises:
        HTTPException: 404 if ``form_data["model"]`` is unknown.
    """
    model_list = await get_all_models()
    models = {model["id"]: model for model in model_list}

    data = form_data
    model_id = data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )
    model = models[model_id]

    # --- Stage 1: pipeline filters served over HTTP -------------------------
    sorted_filters = get_sorted_filters(model_id, models)
    if "pipeline" in model:
        sorted_filters = [model] + sorted_filters

    for pipeline_filter in sorted_filters:
        r = None
        try:
            url_idx = pipeline_filter["urlIdx"]
            url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
            key = app.state.config.OPENAI_API_KEYS[url_idx]

            if key != "":
                headers = {"Authorization": f"Bearer {key}"}
                # NOTE(review): requests.post is synchronous and blocks the event
                # loop until the pipeline answers.
                r = requests.post(
                    f"{url}/{pipeline_filter['id']}/filter/outlet",
                    headers=headers,
                    json={
                        "user": {
                            "id": user.id,
                            "name": user.name,
                            "email": user.email,
                            "role": user.role,
                        },
                        "body": data,
                    },
                )
                r.raise_for_status()
                data = r.json()
        except Exception as e:
            # Connection/HTTP error: surface the pipeline's own error detail if it
            # sent one, otherwise skip this filter and continue.
            log.error(f"Connection error: {e}")
            if r is not None:
                try:
                    res = r.json()
                    if "detail" in res:
                        return JSONResponse(
                            status_code=r.status_code,
                            content=res,
                        )
                except Exception:
                    # Error body was not JSON; nothing more to report.
                    pass

    # --- Stage 2: in-process filter functions with an `outlet` hook ---------
    __event_emitter__ = get_event_emitter(
        {
            "chat_id": data["chat_id"],
            "message_id": data["id"],
            "session_id": data["session_id"],
        }
    )
    __event_call__ = get_event_call(
        {
            "chat_id": data["chat_id"],
            "message_id": data["id"],
            "session_id": data["session_id"],
        }
    )

    def get_priority(function_id):
        function = Functions.get_function_by_id(function_id)
        if function is not None and hasattr(function, "valves"):
            # TODO: Fix FunctionModel to include valves
            return (function.valves if function.valves else {}).get("priority", 0)
        return 0

    filter_ids = [function.id for function in Functions.get_global_filter_functions()]
    model_meta = (model.get("info") or {}).get("meta") or {}
    filter_ids.extend(model_meta.get("filterIds") or [])

    enabled_filter_ids = {
        function.id
        for function in Functions.get_functions_by_type("filter", active_only=True)
    }
    # De-duplicate while keeping a stable order, so filters with equal priority
    # always run in the same order; then keep only enabled filters.
    filter_ids = [
        filter_id
        for filter_id in dict.fromkeys(filter_ids)
        if filter_id in enabled_filter_ids
    ]
    # Sort filter_ids by priority, using the get_priority function
    filter_ids.sort(key=get_priority)

    for filter_id in filter_ids:
        filter_function = Functions.get_function_by_id(filter_id)
        if not filter_function:
            continue

        if filter_id in webui_app.state.FUNCTIONS:
            function_module = webui_app.state.FUNCTIONS[filter_id]
        else:
            function_module, _, _ = load_function_module_by_id(filter_id)
            webui_app.state.FUNCTIONS[filter_id] = function_module

        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
            valves = Functions.get_function_valves_by_id(filter_id)
            function_module.valves = function_module.Valves(
                **(valves if valves else {})
            )

        if not hasattr(function_module, "outlet"):
            continue

        try:
            outlet = function_module.outlet

            # Only pass the extra parameters the hook's signature declares.
            sig = inspect.signature(outlet)
            params = {"body": data}
            extra_params = {
                "__model__": model,
                "__id__": filter_id,
                "__event_emitter__": __event_emitter__,
                "__event_call__": __event_call__,
            }
            for key, value in extra_params.items():
                if key in sig.parameters:
                    params[key] = value

            if "__user__" in sig.parameters:
                __user__ = {
                    "id": user.id,
                    "email": user.email,
                    "name": user.name,
                    "role": user.role,
                }
                try:
                    if hasattr(function_module, "UserValves"):
                        __user__["valves"] = function_module.UserValves(
                            **Functions.get_user_valves_by_id_and_user_id(
                                filter_id, user.id
                            )
                        )
                except Exception as e:
                    # Best effort: run the hook without user valves.
                    log.error(f"Error loading user valves for {filter_id}: {e}")
                params = {**params, "__user__": __user__}

            if inspect.iscoroutinefunction(outlet):
                data = await outlet(**params)
            else:
                data = outlet(**params)
        except Exception as e:
            log.error(f"Error: {e}")
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": str(e)},
            )

    return data
@app.post("/api/chat/actions/{action_id}")
async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified_user)):
    """Invoke an action function's ``action`` hook on a chat message.

    ``action_id`` is either a function id or ``"<function_id>.<sub_action_id>"``
    (the form produced for modules exposing several ``actions``). The
    sub-action id, when present, is passed to the hook as ``__id__``.

    Returns:
        Whatever the hook returns, or a 400 ``JSONResponse`` if it raises.

    Raises:
        HTTPException: 404 if the action function or ``form_data["model"]`` is unknown.
    """
    # Split on the first "." only. Everything after it belongs to the sub-action
    # id; a plain split() raised ValueError when the sub-action id contained a dot.
    function_id, separator, sub_id = action_id.partition(".")
    action_id = function_id
    sub_action_id = sub_id if separator else None

    action_record = Functions.get_function_by_id(action_id)
    if not action_record:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Action not found",
        )

    model_list = await get_all_models()
    models = {model["id"]: model for model in model_list}

    data = form_data
    model_id = data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )
    model = models[model_id]

    __event_emitter__ = get_event_emitter(
        {
            "chat_id": data["chat_id"],
            "message_id": data["id"],
            "session_id": data["session_id"],
        }
    )
    __event_call__ = get_event_call(
        {
            "chat_id": data["chat_id"],
            "message_id": data["id"],
            "session_id": data["session_id"],
        }
    )

    if action_id in webui_app.state.FUNCTIONS:
        function_module = webui_app.state.FUNCTIONS[action_id]
    else:
        function_module, _, _ = load_function_module_by_id(action_id)
        webui_app.state.FUNCTIONS[action_id] = function_module

    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
        valves = Functions.get_function_valves_by_id(action_id)
        function_module.valves = function_module.Valves(**(valves if valves else {}))

    if hasattr(function_module, "action"):
        try:
            action_fn = function_module.action

            # Only pass the extra parameters the hook's signature declares.
            sig = inspect.signature(action_fn)
            params = {"body": data}
            extra_params = {
                "__model__": model,
                "__id__": sub_action_id if sub_action_id is not None else action_id,
                "__event_emitter__": __event_emitter__,
                "__event_call__": __event_call__,
            }
            for key, value in extra_params.items():
                if key in sig.parameters:
                    params[key] = value

            if "__user__" in sig.parameters:
                __user__ = {
                    "id": user.id,
                    "email": user.email,
                    "name": user.name,
                    "role": user.role,
                }
                try:
                    if hasattr(function_module, "UserValves"):
                        __user__["valves"] = function_module.UserValves(
                            **Functions.get_user_valves_by_id_and_user_id(
                                action_id, user.id
                            )
                        )
                except Exception as e:
                    # Best effort: run the hook without user valves.
                    log.error(f"Error loading user valves for {action_id}: {e}")
                params = {**params, "__user__": __user__}

            if inspect.iscoroutinefunction(action_fn):
                data = await action_fn(**params)
            else:
                data = action_fn(**params)

        except Exception as e:
            log.error(f"Error: {e}")
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": str(e)},
            )

    return data
2024-06-20 08:51:39 +00:00
##################################
#
# Config Endpoints
#
##################################
2024-02-22 02:12:01 +00:00
@app.get("/api/config")
async def get_app_config(request: Request):
    """Return the configuration the web client needs.

    Every caller receives status, name, version, default locale, OAuth provider
    display names and the auth-related feature flags, plus ``"onboarding": True``
    while no user exists yet. When the ``token`` cookie resolves to a user, the
    response also includes the remaining feature toggles, default models and
    prompt suggestions, audio and file settings, and user permissions.

    Raises:
        HTTPException: 401 if a ``token`` cookie is present but cannot be decoded.
    """
    user = None
    if "token" in request.cookies:
        try:
            token_data = decode_token(request.cookies.get("token"))
        except Exception as e:
            log.debug(e)
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token",
            )
        if token_data is not None and "id" in token_data:
            user = Users.get_user_by_id(token_data["id"])

    # Onboarding is only reported to anonymous callers of an instance with no users.
    onboarding = user is None and Users.get_num_users() == 0
    signed_in = user is not None

    oauth_providers = {
        provider: settings.get("name", provider)
        for provider, settings in OAUTH_PROVIDERS.items()
    }

    features = {
        "auth": WEBUI_AUTH,
        "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
        "enable_ldap": webui_app.state.config.ENABLE_LDAP,
        "enable_api_key": webui_app.state.config.ENABLE_API_KEY,
        "enable_signup": webui_app.state.config.ENABLE_SIGNUP,
        "enable_login_form": webui_app.state.config.ENABLE_LOGIN_FORM,
    }
    if signed_in:
        features.update(
            {
                "enable_web_search": retrieval_app.state.config.ENABLE_RAG_WEB_SEARCH,
                "enable_image_generation": images_app.state.config.ENABLED,
                "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING,
                "enable_message_rating": webui_app.state.config.ENABLE_MESSAGE_RATING,
                "enable_admin_export": ENABLE_ADMIN_EXPORT,
                "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS,
            }
        )

    app_config = {"onboarding": True} if onboarding else {}
    app_config.update(
        {
            "status": True,
            "name": WEBUI_NAME,
            "version": VERSION,
            "default_locale": str(DEFAULT_LOCALE),
            "oauth": {"providers": oauth_providers},
            "features": features,
        }
    )

    if signed_in:
        app_config.update(
            {
                "default_models": webui_app.state.config.DEFAULT_MODELS,
                "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
                "audio": {
                    "tts": {
                        "engine": audio_app.state.config.TTS_ENGINE,
                        "voice": audio_app.state.config.TTS_VOICE,
                        "split_on": audio_app.state.config.TTS_SPLIT_ON,
                    },
                    "stt": {
                        "engine": audio_app.state.config.STT_ENGINE,
                    },
                },
                "file": {
                    "max_size": retrieval_app.state.config.FILE_MAX_SIZE,
                    "max_count": retrieval_app.state.config.FILE_MAX_COUNT,
                },
                "permissions": {**webui_app.state.config.USER_PERMISSIONS},
            }
        )

    return app_config
2024-12-10 08:54:13 +00:00
class UrlForm(BaseModel):
    """Request body carrying a single URL."""

    url: str  # the URL to store (e.g. the webhook URL)
2024-06-20 08:51:39 +00:00
2024-03-21 01:35:02 +00:00
@app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)):
    """Return the configured webhook URL (admin only)."""
    webhook_url = app.state.config.WEBHOOK_URL
    return {"url": webhook_url}
@app.post("/api/webhook")
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
    """Store a new webhook URL (admin only) and mirror it onto ``webui_app``.

    Returns the value read back from the config after the update.
    """
    config = app.state.config
    config.WEBHOOK_URL = form_data.url
    webui_app.state.WEBHOOK_URL = config.WEBHOOK_URL
    return {"url": config.WEBHOOK_URL}
2024-05-26 16:23:24 +00:00
2024-03-05 08:59:35 +00:00
@app.get("/api/version")
async def get_app_version():
    """Return the running Open WebUI version."""
    return dict(version=VERSION)
2024-02-25 19:26:58 +00:00
@app.get("/api/version/updates")
async def get_app_latest_release_version():
    """Report the running version and the latest published GitHub release.

    The current version is reported as the latest one in three cases: when
    ``OFFLINE_MODE`` is enabled, when GitHub does not answer within one second,
    and when the reply is not usable.

    Returns:
        dict: ``{"current": VERSION, "latest": <version without a leading "v">}``.
    """
    if OFFLINE_MODE:
        log.debug(
            "Offline mode is enabled, returning current version as latest version"
        )
        return {"current": VERSION, "latest": VERSION}

    try:
        timeout = aiohttp.ClientTimeout(total=1)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.get(
                "https://api.github.com/repos/open-webui/open-webui/releases/latest"
            ) as response:
                response.raise_for_status()
                data = await response.json()
                latest_version = data["tag_name"]
                # Tags are normally "vX.Y.Z". Strip the "v" only when present, so
                # a bare "X.Y.Z" tag does not lose its first digit.
                if latest_version.startswith("v"):
                    latest_version = latest_version[1:]
                return {"current": VERSION, "latest": latest_version}
    except Exception as e:
        log.debug(e)
        return {"current": VERSION, "latest": VERSION}
2024-02-25 19:26:58 +00:00
2024-04-10 08:27:19 +00:00
2024-12-10 08:54:13 +00:00
@app.get("/api/changelog")
async def get_app_changelog():
    """Return the first five ``CHANGELOG`` entries, in the mapping's iteration order."""
    leading_keys = list(CHANGELOG)[:5]
    return {key: CHANGELOG[key] for key in leading_keys}
2024-05-27 17:07:38 +00:00
############################
# OAuth Login & Callback
############################
# authlib uses the session during the OAuth flow, so SessionMiddleware is only
# installed when at least one OAuth provider is configured.
if OAUTH_PROVIDERS:
    app.add_middleware(
        SessionMiddleware,
        secret_key=WEBUI_SECRET_KEY,
        session_cookie="oui-session",
        same_site=WEBUI_SESSION_COOKIE_SAME_SITE,
        https_only=WEBUI_SESSION_COOKIE_SECURE,
    )
@app.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
    """Start the OAuth login flow for *provider*; delegated to ``oauth_manager``."""
    login_response = await oauth_manager.handle_login(provider, request)
    return login_response
2024-05-27 17:07:38 +00:00
2024-06-21 17:25:19 +00:00
# OAuth login logic is as follows:
# 1. Attempt to find a user with a matching subject ID, tied to the provider.
# 2. If OAUTH_MERGE_ACCOUNTS_BY_EMAIL is true, find a user with the email address provided via OAuth.
#    - This is considered insecure in general, as OAuth providers do not always verify email addresses.
# 3. If there is no user and ENABLE_OAUTH_SIGNUP is true, create a user.
#    - Email addresses are considered unique, so registration fails if the email address is already taken.
@app.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request, response: Response):
    """Finish the OAuth flow for *provider*.

    The steps listed in the comment above are carried out by
    ``oauth_manager.handle_callback``.
    """
    callback_result = await oauth_manager.handle_callback(provider, request, response)
    return callback_result
2024-05-27 17:07:38 +00:00
2024-04-02 18:55:00 +00:00
@app.get("/manifest.json")
async def get_manifest_json():
    """Serve the PWA web app manifest, branded with ``WEBUI_NAME``."""
    # The same 500x500 logo is declared once per icon purpose.
    icons = [
        {
            "src": "/static/logo.png",
            "type": "image/png",
            "sizes": "500x500",
            "purpose": purpose,
        }
        for purpose in ("any", "maskable")
    ]
    return {
        "name": WEBUI_NAME,
        "short_name": WEBUI_NAME,
        "description": "Open WebUI is an open, extensible, user-friendly interface for AI that adapts to your workflow.",
        "start_url": "/",
        "display": "standalone",
        "background_color": "#343541",
        "orientation": "natural",
        "icons": icons,
    }
2024-04-10 08:27:19 +00:00
2024-05-07 00:29:16 +00:00
@app.get("/opensearch.xml")
async def get_opensearch_xml():
    """Serve an OpenSearch 1.1 description document for this instance.

    The search URL template is ``{WEBUI_URL}/?q={searchTerms}``. Inside the raw
    f-string, the expression ``{"{searchTerms}"}`` emits the literal
    ``{searchTerms}`` placeholder rather than interpolating a variable.

    Returns:
        Response: the XML document with media type ``application/xml``.
    """
    # NOTE(review): WEBUI_NAME / WEBUI_URL are interpolated without XML escaping;
    # a value containing "&" or "<" would produce malformed XML.
    xml_content = rf"""
    <OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/">
    <ShortName>{WEBUI_NAME}</ShortName>
    <Description>Search {WEBUI_NAME}</Description>
    <InputEncoding>UTF-8</InputEncoding>
    <Image width="16" height="16" type="image/x-icon">{WEBUI_URL}/static/favicon.png</Image>
    <Url type="text/html" method="get" template="{WEBUI_URL}/?q={"{searchTerms}"}"/>
    <moz:SearchForm>{WEBUI_URL}</moz:SearchForm>
    </OpenSearchDescription>
    """
    return Response(content=xml_content, media_type="application/xml")
2024-05-15 18:17:18 +00:00
@app.get("/health")
async def healthcheck():
    """Liveness probe; always reports healthy."""
    return dict(status=True)
2024-06-18 13:03:31 +00:00
@app.get("/health/db")
async def healthcheck_with_db():
    """Health probe that also runs a trivial query against the database.

    Any database error propagates, so the endpoint fails when the DB is unreachable.
    """
    probe_query = text("SELECT 1;")
    Session.execute(probe_query).all()
    return dict(status=True)
2024-04-09 10:32:28 +00:00
# Static assets and the cache directory are always served.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")

# The built SPA frontend is mounted last, at "/", only when it exists.
if not os.path.exists(FRONTEND_BUILD_DIR):
    log.warning(
        f"Frontend build directory not found at '{FRONTEND_BUILD_DIR}'. Serving API only."
    )
else:
    # Register .js explicitly so it is served as text/javascript regardless of
    # the host's MIME table (presumably guarding against mis-typed JS on some
    # platforms — confirm).
    mimetypes.add_type("text/javascript", ".js")
    app.mount(
        "/",
        SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),
        name="spa-static-files",
    )