Merge branch 'open-webui:dev' into dev

This commit is contained in:
smonux
2024-09-28 06:01:26 +02:00
committed by GitHub
50 changed files with 1113 additions and 1091 deletions

View File

@@ -27,7 +27,6 @@ from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
from starlette.background import BackgroundTask
from open_webui.utils.payload import (
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
@@ -47,7 +46,6 @@ app.add_middleware(
allow_headers=["*"],
)
app.state.config = AppConfig()
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
@@ -407,20 +405,25 @@ async def generate_chat_completion(
url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx]
is_o1 = payload["model"].lower().startswith("o1-")
# Change max_completion_tokens to max_tokens (Backward compatible)
if "api.openai.com" not in url and not payload["model"].lower().startswith("o1-"):
if "api.openai.com" not in url and not is_o1:
if "max_completion_tokens" in payload:
# Remove "max_completion_tokens" from the payload
payload["max_tokens"] = payload["max_completion_tokens"]
del payload["max_completion_tokens"]
else:
if payload["model"].lower().startswith("o1-") and "max_tokens" in payload:
if is_o1 and "max_tokens" in payload:
payload["max_completion_tokens"] = payload["max_tokens"]
del payload["max_tokens"]
if "max_tokens" in payload and "max_completion_tokens" in payload:
del payload["max_tokens"]
# Fix: O1 does not support the "system" parameter, Modify "system" to "user"
if is_o1 and payload["messages"][0]["role"] == "system":
payload["messages"][0]["role"] = "user"
# Convert the modified body back to JSON
payload = json.dumps(payload)

View File

@@ -0,0 +1,190 @@
import requests
import logging
import ftfy
from langchain_community.document_loaders import (
BSHTMLLoader,
CSVLoader,
Docx2txtLoader,
OutlookMessageLoader,
PyPDFLoader,
TextLoader,
UnstructuredEPubLoader,
UnstructuredExcelLoader,
UnstructuredMarkdownLoader,
UnstructuredPowerPointLoader,
UnstructuredRSTLoader,
UnstructuredXMLLoader,
YoutubeLoader,
)
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
known_source_ext = [
"go",
"py",
"java",
"sh",
"bat",
"ps1",
"cmd",
"js",
"ts",
"css",
"cpp",
"hpp",
"h",
"c",
"cs",
"sql",
"log",
"ini",
"pl",
"pm",
"r",
"dart",
"dockerfile",
"env",
"php",
"hs",
"hsc",
"lua",
"nginxconf",
"conf",
"m",
"mm",
"plsql",
"perl",
"rb",
"rs",
"db2",
"scala",
"bash",
"swift",
"vue",
"svelte",
"msg",
"ex",
"exs",
"erl",
"tsx",
"jsx",
"hs",
"lhs",
]
class TikaLoader:
def __init__(self, url, file_path, mime_type=None):
self.url = url
self.file_path = file_path
self.mime_type = mime_type
def load(self) -> list[Document]:
with open(self.file_path, "rb") as f:
data = f.read()
if self.mime_type is not None:
headers = {"Content-Type": self.mime_type}
else:
headers = {}
endpoint = self.url
if not endpoint.endswith("/"):
endpoint += "/"
endpoint += "tika/text"
r = requests.put(endpoint, data=data, headers=headers)
if r.ok:
raw_metadata = r.json()
text = raw_metadata.get("X-TIKA:content", "<No text content found>")
if "Content-Type" in raw_metadata:
headers["Content-Type"] = raw_metadata["Content-Type"]
log.info("Tika extracted text: %s", text)
return [Document(page_content=text, metadata=headers)]
else:
raise Exception(f"Error calling Tika: {r.reason}")
class Loader:
def __init__(self, engine: str = "", **kwargs):
self.engine = engine
self.kwargs = kwargs
def load(
self, filename: str, file_content_type: str, file_path: str
) -> list[Document]:
loader = self._get_loader(filename, file_content_type, file_path)
docs = loader.load()
return [
Document(
page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata
)
for doc in docs
]
def _get_loader(self, filename: str, file_content_type: str, file_path: str):
file_ext = filename.split(".")[-1].lower()
if self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
if file_ext in known_source_ext or (
file_content_type and file_content_type.find("text/") >= 0
):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
loader = TikaLoader(
url=self.kwargs.get("TIKA_SERVER_URL"),
file_path=file_path,
mime_type=file_content_type,
)
else:
if file_ext == "pdf":
loader = PyPDFLoader(
file_path, extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES")
)
elif file_ext == "csv":
loader = CSVLoader(file_path)
elif file_ext == "rst":
loader = UnstructuredRSTLoader(file_path, mode="elements")
elif file_ext == "xml":
loader = UnstructuredXMLLoader(file_path)
elif file_ext in ["htm", "html"]:
loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
elif file_ext == "md":
loader = UnstructuredMarkdownLoader(file_path)
elif file_content_type == "application/epub+zip":
loader = UnstructuredEPubLoader(file_path)
elif (
file_content_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
or file_ext == "docx"
):
loader = Docx2txtLoader(file_path)
elif file_content_type in [
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
] or file_ext in ["xls", "xlsx"]:
loader = UnstructuredExcelLoader(file_path)
elif file_content_type in [
"application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
] or file_ext in ["ppt", "pptx"]:
loader = UnstructuredPowerPointLoader(file_path)
elif file_ext == "msg":
loader = OutlookMessageLoader(file_path)
elif file_ext in known_source_ext or (
file_content_type and file_content_type.find("text/") >= 0
):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
loader = TextLoader(file_path, autodetect_encoding=True)
return loader

View File

@@ -0,0 +1,81 @@
import os
import torch
import numpy as np
from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint
class ColBERT:
def __init__(self, name, **kwargs) -> None:
print("ColBERT: Loading model", name)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
DOCKER = kwargs.get("env") == "docker"
if DOCKER:
# This is a workaround for the issue with the docker container
# where the torch extension is not loaded properly
# and the following error is thrown:
# /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory
lock_file = (
"/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock"
)
if os.path.exists(lock_file):
os.remove(lock_file)
self.ckpt = Checkpoint(
name,
colbert_config=ColBERTConfig(model_name=name),
).to(self.device)
pass
def calculate_similarity_scores(self, query_embeddings, document_embeddings):
query_embeddings = query_embeddings.to(self.device)
document_embeddings = document_embeddings.to(self.device)
# Validate dimensions to ensure compatibility
if query_embeddings.dim() != 3:
raise ValueError(
f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
)
if document_embeddings.dim() != 3:
raise ValueError(
f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
)
if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
raise ValueError(
"There should be either one query or queries equal to the number of documents."
)
# Transpose the query embeddings to align for matrix multiplication
transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
# Compute similarity scores using batch matrix multiplication
computed_scores = torch.matmul(document_embeddings, transposed_query_embeddings)
# Apply max pooling to extract the highest semantic similarity across each document's sequence
maximum_scores = torch.max(computed_scores, dim=1).values
# Sum up the maximum scores across features to get the overall document relevance scores
final_scores = maximum_scores.sum(dim=1)
normalized_scores = torch.softmax(final_scores, dim=0)
return normalized_scores.detach().cpu().numpy().astype(np.float32)
def predict(self, sentences):
query = sentences[0][0]
docs = [i[1] for i in sentences]
# Embedding the documents
embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0]
# Embedding the queries
embedded_queries = self.ckpt.queryFromText([query], bsize=32)
embedded_query = embedded_queries[0]
# Calculate retrieval scores for the query against all documents
scores = self.calculate_similarity_scores(
embedded_query.unsqueeze(0), embedded_docs
)
return scores

View File

@@ -15,7 +15,7 @@ from open_webui.apps.ollama.main import (
GenerateEmbeddingsForm,
generate_ollama_embeddings,
)
from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
from open_webui.env import SRC_LOG_LEVELS

View File

@@ -1,5 +1,5 @@
from open_webui.apps.rag.vector.dbs.chroma import ChromaClient
from open_webui.apps.rag.vector.dbs.milvus import MilvusClient
from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient
from open_webui.config import VECTOR_DB

View File

@@ -4,7 +4,7 @@ from chromadb.utils.batch_utils import create_batches
from typing import Optional
from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
CHROMA_DATA_PATH,
CHROMA_HTTP_HOST,

View File

@@ -4,7 +4,7 @@ import json
from typing import Optional
from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
MILVUS_URI,
)
@@ -98,7 +98,10 @@ class MilvusClient:
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="vector", index_type="HNSW", metric_type="COSINE", params={}
field_name="vector",
index_type="HNSW",
metric_type="COSINE",
params={"M": 16, "efConstruction": 100},
)
self.client.create_collection(

View File

@@ -2,7 +2,7 @@ import logging
from typing import Optional
import requests
from open_webui.apps.rag.search.main import SearchResult, get_filtered_results
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@@ -1,7 +1,7 @@
import logging
from typing import Optional
from open_webui.apps.rag.search.main import SearchResult, get_filtered_results
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from duckduckgo_search import DDGS
from open_webui.env import SRC_LOG_LEVELS

View File

@@ -2,7 +2,7 @@ import logging
from typing import Optional
import requests
from open_webui.apps.rag.search.main import SearchResult, get_filtered_results
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@@ -1,7 +1,7 @@
import logging
import requests
from open_webui.apps.rag.search.main import SearchResult
from open_webui.apps.retrieval.web.main import SearchResult
from open_webui.env import SRC_LOG_LEVELS
from yarl import URL

View File

@@ -3,7 +3,7 @@ from typing import Optional
from urllib.parse import urlencode
import requests
from open_webui.apps.rag.search.main import SearchResult, get_filtered_results
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@@ -2,7 +2,7 @@ import logging
from typing import Optional
import requests
from open_webui.apps.rag.search.main import SearchResult, get_filtered_results
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@@ -3,7 +3,7 @@ import logging
from typing import Optional
import requests
from open_webui.apps.rag.search.main import SearchResult, get_filtered_results
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@@ -3,7 +3,7 @@ from typing import Optional
from urllib.parse import urlencode
import requests
from open_webui.apps.rag.search.main import SearchResult, get_filtered_results
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@@ -2,7 +2,7 @@ import logging
from typing import Optional
import requests
from open_webui.apps.rag.search.main import SearchResult, get_filtered_results
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@@ -1,7 +1,7 @@
import logging
import requests
from open_webui.apps.rag.search.main import SearchResult
from open_webui.apps.retrieval.web.main import SearchResult
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@@ -0,0 +1,97 @@
import socket
import urllib.parse
import validators
from typing import Union, Sequence, Iterator
from langchain_community.document_loaders import (
WebBaseLoader,
)
from langchain_core.documents import Document
from open_webui.constants import ERROR_MESSAGES
from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH
from open_webui.env import SRC_LOG_LEVELS
import logging
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def validate_url(url: Union[str, Sequence[str]]):
if isinstance(url, str):
if isinstance(validators.url(url), validators.ValidationError):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
if not ENABLE_RAG_LOCAL_WEB_FETCH:
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
parsed_url = urllib.parse.urlparse(url)
# Get IPv4 and IPv6 addresses
ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
# Check if any of the resolved addresses are private
# This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
for ip in ipv4_addresses:
if validators.ipv4(ip, private=True):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
for ip in ipv6_addresses:
if validators.ipv6(ip, private=True):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
return True
elif isinstance(url, Sequence):
return all(validate_url(u) for u in url)
else:
return False
def resolve_hostname(hostname):
# Get address information
addr_info = socket.getaddrinfo(hostname, None)
# Extract IP addresses from address information
ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
return ipv4_addresses, ipv6_addresses
class SafeWebBaseLoader(WebBaseLoader):
"""WebBaseLoader with enhanced error handling for URLs."""
def lazy_load(self) -> Iterator[Document]:
"""Lazy load text from the url(s) in web_path with error handling."""
for path in self.web_paths:
try:
soup = self._scrape(path, bs_kwargs=self.bs_kwargs)
text = soup.get_text(**self.bs_get_text_kwargs)
# Build metadata
metadata = {"source": path}
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get(
"content", "No description found."
)
if html := soup.find("html"):
metadata["language"] = html.get("lang", "No language found.")
yield Document(page_content=text, metadata=metadata)
except Exception as e:
# Log the error and continue with the next URL
log.error(f"Error loading {path}: {e}")
def get_web_loader(
url: Union[str, Sequence[str]],
verify_ssl: bool = True,
requests_per_second: int = 2,
):
# Check if the URL is valid
if not validate_url(url):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
return SafeWebBaseLoader(
url,
verify_ssl=verify_ssl,
requests_per_second=requests_per_second,
continue_on_failure=True,
)

View File

@@ -97,6 +97,17 @@ class FilesTable:
for file in db.query(File).filter_by(user_id=user_id).all()
]
def update_files_metadata_by_id(self, id: str, meta: dict) -> Optional[FileModel]:
with get_db() as db:
try:
file = db.query(File).filter_by(id=id).first()
file.meta = {**file.meta, **meta}
db.commit()
return FileModel.model_validate(file)
except Exception:
return None
def delete_file_by_id(self, id: str) -> bool:
with get_db() as db:
try:

View File

@@ -171,6 +171,19 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
)
@router.get("/{id}/content/text")
async def get_file_text_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
return {"text": file.meta.get("content", {}).get("text", None)}
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
@router.get("/{id}/content/{file_name}", response_model=Optional[FileModel])
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)

View File

@@ -4,7 +4,7 @@ import logging
from typing import Optional
from open_webui.apps.webui.models.memories import Memories, MemoryModel
from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.utils import get_verified_user
from open_webui.env import SRC_LOG_LEVELS

View File

@@ -921,7 +921,7 @@ CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
####################################
# RAG
# Information Retrieval (RAG)
####################################
# RAG Content Extraction

View File

@@ -16,37 +16,45 @@ from typing import Optional
import aiohttp
import requests
from open_webui.apps.audio.main import app as audio_app
from open_webui.apps.images.main import app as images_app
from open_webui.apps.ollama.main import app as ollama_app
from open_webui.apps.ollama.main import (
GenerateChatCompletionForm,
app as ollama_app,
get_all_models as get_ollama_models,
generate_chat_completion as generate_ollama_chat_completion,
generate_openai_chat_completion as generate_ollama_openai_chat_completion,
GenerateChatCompletionForm,
)
from open_webui.apps.ollama.main import get_all_models as get_ollama_models
from open_webui.apps.openai.main import app as openai_app
from open_webui.apps.openai.main import (
app as openai_app,
generate_chat_completion as generate_openai_chat_completion,
get_all_models as get_openai_models,
)
from open_webui.apps.openai.main import get_all_models as get_openai_models
from open_webui.apps.rag.main import app as rag_app
from open_webui.apps.rag.utils import get_rag_context, rag_template
from open_webui.apps.socket.main import app as socket_app, periodic_usage_pool_cleanup
from open_webui.apps.socket.main import get_event_call, get_event_emitter
from open_webui.apps.webui.internal.db import Session
from open_webui.apps.webui.main import app as webui_app
from open_webui.apps.retrieval.main import app as retrieval_app
from open_webui.apps.retrieval.utils import get_rag_context, rag_template
from open_webui.apps.socket.main import (
app as socket_app,
periodic_usage_pool_cleanup,
get_event_call,
get_event_emitter,
)
from open_webui.apps.webui.main import (
app as webui_app,
generate_function_chat_completion,
get_pipe_models,
)
from open_webui.apps.webui.internal.db import Session
from open_webui.apps.webui.models.auths import Auths
from open_webui.apps.webui.models.functions import Functions
from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.models.users import UserModel, Users
from open_webui.apps.webui.utils import load_function_module_by_id
from open_webui.apps.audio.main import app as audio_app
from open_webui.apps.images.main import app as images_app
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
@@ -492,11 +500,11 @@ async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
contexts, citations = get_rag_context(
files=files,
messages=body["messages"],
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.config.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
embedding_function=retrieval_app.state.EMBEDDING_FUNCTION,
k=retrieval_app.state.config.TOP_K,
reranking_function=retrieval_app.state.sentence_transformer_rf,
r=retrieval_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=retrieval_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
log.debug(f"rag_contexts: {contexts}, citations: {citations}")
@@ -609,7 +617,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if prompt is None:
raise Exception("No user message found")
if (
rag_app.state.config.RELEVANCE_THRESHOLD == 0
retrieval_app.state.config.RELEVANCE_THRESHOLD == 0
and context_string.strip() == ""
):
log.debug(
@@ -621,14 +629,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if model["owned_by"] == "ollama":
body["messages"] = prepend_to_first_user_message_content(
rag_template(
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt
),
body["messages"],
)
else:
body["messages"] = add_or_update_system_message(
rag_template(
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt
),
body["messages"],
)
@@ -762,10 +770,22 @@ class PipelineMiddleware(BaseHTTPMiddleware):
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
user = get_current_user(
request,
get_http_authorization_cred(request.headers["Authorization"]),
)
try:
user = get_current_user(
request,
get_http_authorization_cred(request.headers["Authorization"]),
)
except KeyError as e:
if len(e.args) > 1:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
else:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"detail": "Not authenticated"},
)
try:
data = filter_pipeline(data, user)
@@ -838,7 +858,7 @@ async def check_url(request: Request, call_next):
async def update_embedding_function(request: Request, call_next):
response = await call_next(request)
if "/embedding/update" in request.url.path:
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION
return response
@@ -866,11 +886,12 @@ app.mount("/openai", openai_app)
app.mount("/images/api/v1", images_app)
app.mount("/audio/api/v1", audio_app)
app.mount("/rag/api/v1", rag_app)
app.mount("/retrieval/api/v1", retrieval_app)
app.mount("/api/v1", webui_app)
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION
async def get_all_models():
@@ -2055,7 +2076,7 @@ async def get_app_config(request: Request):
"enable_login_form": webui_app.state.config.ENABLE_LOGIN_FORM,
**(
{
"enable_web_search": rag_app.state.config.ENABLE_RAG_WEB_SEARCH,
"enable_web_search": retrieval_app.state.config.ENABLE_RAG_WEB_SEARCH,
"enable_image_generation": images_app.state.config.ENABLED,
"enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING,
"enable_message_rating": webui_app.state.config.ENABLE_MESSAGE_RATING,
@@ -2081,8 +2102,8 @@ async def get_app_config(request: Request):
},
},
"file": {
"max_size": rag_app.state.config.FILE_MAX_SIZE,
"max_count": rag_app.state.config.FILE_MAX_COUNT,
"max_size": retrieval_app.state.config.FILE_MAX_SIZE,
"max_count": retrieval_app.state.config.FILE_MAX_COUNT,
},
"permissions": {**webui_app.state.config.USER_PERMISSIONS},
}
@@ -2154,7 +2175,8 @@ async def get_app_changelog():
@app.get("/api/version/updates")
async def get_app_latest_release_version():
try:
async with aiohttp.ClientSession(trust_env=True) as session:
timeout = aiohttp.ClientTimeout(total=1)
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(
"https://api.github.com/repos/open-webui/open-webui/releases/latest"
) as response:
@@ -2164,10 +2186,7 @@ async def get_app_latest_release_version():
return {"current": VERSION, "latest": latest_version[1:]}
except aiohttp.ClientError:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED,
)
return {"current": VERSION, "latest": VERSION}
############################