Merge branch 'dev' of https://github.com/open-webui/open-webui into Dev-Individual-RAG-Config
@@ -5,6 +5,7 @@ import os
 import pkgutil
 import sys
 import shutil
+from uuid import uuid4
 from pathlib import Path

 import markdown
@@ -130,6 +131,7 @@ else:
     PACKAGE_DATA = {"version": "0.0.0"}

 VERSION = PACKAGE_DATA["version"]
+INSTANCE_ID = os.environ.get("INSTANCE_ID", str(uuid4()))


 # Function to parse each section
@@ -8,6 +8,8 @@ import shutil
 import sys
 import time
 import random
+from uuid import uuid4
+

 from contextlib import asynccontextmanager
 from urllib.parse import urlencode, parse_qs, urlparse
@@ -19,6 +21,7 @@ from aiocache import cached
 import aiohttp
 import anyio.to_thread
 import requests
+from redis import Redis


 from fastapi import (
@@ -231,6 +234,9 @@ from open_webui.config import (
     DOCLING_OCR_ENGINE,
     DOCLING_OCR_LANG,
     DOCLING_DO_PICTURE_DESCRIPTION,
+    DOCLING_PICTURE_DESCRIPTION_MODE,
+    DOCLING_PICTURE_DESCRIPTION_LOCAL,
+    DOCLING_PICTURE_DESCRIPTION_API,
     DOCUMENT_INTELLIGENCE_ENDPOINT,
     DOCUMENT_INTELLIGENCE_KEY,
     MISTRAL_OCR_API_KEY,
@@ -393,6 +399,7 @@ from open_webui.env import (
     SAFE_MODE,
     SRC_LOG_LEVELS,
     VERSION,
+    INSTANCE_ID,
     WEBUI_BUILD_HASH,
     WEBUI_SECRET_KEY,
     WEBUI_SESSION_COOKIE_SAME_SITE,
@@ -434,8 +441,10 @@ from open_webui.utils.auth import (
 from open_webui.utils.plugin import install_tool_and_function_dependencies
 from open_webui.utils.oauth import OAuthManager
 from open_webui.utils.security_headers import SecurityHeadersMiddleware
+from open_webui.utils.redis import get_redis_connection

 from open_webui.tasks import (
+    redis_task_command_listener,
     list_task_ids_by_chat_id,
     stop_task,
     list_tasks,
@@ -487,7 +496,9 @@ https://github.com/open-webui/open-webui

 @asynccontextmanager
 async def lifespan(app: FastAPI):
+    app.state.instance_id = INSTANCE_ID
+
     start_logger()

     if RESET_CONFIG_ON_START:
         reset_config()
@@ -499,6 +510,19 @@ async def lifespan(app: FastAPI):
     log.info("Installing external dependencies of functions and tools...")
     install_tool_and_function_dependencies()

+    app.state.redis = get_redis_connection(
+        redis_url=REDIS_URL,
+        redis_sentinels=get_sentinels_from_env(
+            REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT
+        ),
+        async_mode=True,
+    )
+
+    if app.state.redis is not None:
+        app.state.redis_task_command_listener = asyncio.create_task(
+            redis_task_command_listener(app)
+        )
+
     if THREAD_POOL_SIZE and THREAD_POOL_SIZE > 0:
         limiter = anyio.to_thread.current_default_thread_limiter()
         limiter.total_tokens = THREAD_POOL_SIZE
@@ -507,6 +531,9 @@ async def lifespan(app: FastAPI):

     yield

+    if hasattr(app.state, "redis_task_command_listener"):
+        app.state.redis_task_command_listener.cancel()
+

 app = FastAPI(
     title="Open WebUI",
@@ -518,10 +545,12 @@ app = FastAPI(

 oauth_manager = OAuthManager(app)

+app.state.instance_id = None
 app.state.config = AppConfig(
     redis_url=REDIS_URL,
     redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT),
 )
+app.state.redis = None

 app.state.WEBUI_NAME = WEBUI_NAME
 app.state.LICENSE_METADATA = None
@@ -706,6 +735,9 @@ app.state.config.DOCLING_SERVER_URL = DOCLING_SERVER_URL
 app.state.config.DOCLING_OCR_ENGINE = DOCLING_OCR_ENGINE
 app.state.config.DOCLING_OCR_LANG = DOCLING_OCR_LANG
 app.state.config.DOCLING_DO_PICTURE_DESCRIPTION = DOCLING_DO_PICTURE_DESCRIPTION
+app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE = DOCLING_PICTURE_DESCRIPTION_MODE
+app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL = DOCLING_PICTURE_DESCRIPTION_LOCAL
+app.state.config.DOCLING_PICTURE_DESCRIPTION_API = DOCLING_PICTURE_DESCRIPTION_API
 app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT
 app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY
 app.state.config.MISTRAL_OCR_API_KEY = MISTRAL_OCR_API_KEY
@@ -1389,26 +1421,30 @@ async def chat_action(


 @app.post("/api/tasks/stop/{task_id}")
-async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)):
+async def stop_task_endpoint(
+    request: Request, task_id: str, user=Depends(get_verified_user)
+):
     try:
-        result = await stop_task(task_id)
+        result = await stop_task(request, task_id)
         return result
     except ValueError as e:
         raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))


 @app.get("/api/tasks")
-async def list_tasks_endpoint(user=Depends(get_verified_user)):
-    return {"tasks": list_tasks()}
+async def list_tasks_endpoint(request: Request, user=Depends(get_verified_user)):
+    return {"tasks": await list_tasks(request)}


 @app.get("/api/tasks/chat/{chat_id}")
-async def list_tasks_by_chat_id_endpoint(chat_id: str, user=Depends(get_verified_user)):
+async def list_tasks_by_chat_id_endpoint(
+    request: Request, chat_id: str, user=Depends(get_verified_user)
+):
     chat = Chats.get_chat_by_id(chat_id)
     if chat is None or chat.user_id != user.id:
         return {"task_ids": []}

-    task_ids = list_task_ids_by_chat_id(chat_id)
+    task_ids = await list_task_ids_by_chat_id(request, chat_id)

     print(f"Task IDs for chat {chat_id}: {task_ids}")
     return {"task_ids": task_ids}
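Note: the three task endpoints above now take the FastAPI Request so the task registry can be looked up in Redis when one is configured (see the tasks.py hunks further down). From a client's perspective the API is unchanged; a minimal sketch, with a placeholder host and token:

    import requests

    BASE = "http://localhost:8080"                 # placeholder
    HEADERS = {"Authorization": "Bearer <token>"}  # placeholder

    # List active task IDs, then ask the cluster to stop the first one.
    task_ids = requests.get(f"{BASE}/api/tasks", headers=HEADERS).json()["tasks"]
    if task_ids:
        requests.post(f"{BASE}/api/tasks/stop/{task_ids[0]}", headers=HEADERS)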
@@ -245,8 +245,9 @@ class KnowledgeTable:
         from sqlalchemy import func
         with get_db() as db:
             if db.bind.dialect.name == "sqlite":
-                func.json_extract(Knowledge.rag_config, f'$.{model_type}') == model,
+                query = db.query(Knowledge).filter(
+                    func.json_extract(Knowledge.rag_config, f'$.{model_type}') == model
+                )
             elif db.bind.dialect.name == "postgresql":
                 query = db.query(Knowledge).filter(
                     Knowledge.rag_config.op("->>")(model_type) == model,
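Note: the branch above exists because SQLite and PostgreSQL spell JSON access differently. A minimal sketch of the two equivalent filter clauses (the key and value below are made up for illustration):

    from sqlalchemy import func

    # SQLite: json_extract(rag_config, '$.embedding_model') = 'bge-m3'
    sqlite_clause = (
        func.json_extract(Knowledge.rag_config, "$.embedding_model") == "bge-m3"
    )

    # PostgreSQL: rag_config ->> 'embedding_model' = 'bge-m3' (->> yields text)
    pg_clause = Knowledge.rag_config.op("->>")("embedding_model") == "bge-m3"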
@@ -2,6 +2,7 @@ import requests
 import logging
 import ftfy
 import sys
+import json

 from langchain_community.document_loaders import (
     AzureAIDocumentIntelligenceLoader,
@@ -154,6 +155,24 @@ class DoclingLoader:
                 "do_picture_description"
             )

+            picture_description_mode = self.params.get(
+                "picture_description_mode", ""
+            ).lower()
+
+            if picture_description_mode == "local" and self.params.get(
+                "picture_description_local", {}
+            ):
+                params["picture_description_local"] = self.params.get(
+                    "picture_description_local", {}
+                )
+
+            elif picture_description_mode == "api" and self.params.get(
+                "picture_description_api", {}
+            ):
+                params["picture_description_api"] = self.params.get(
+                    "picture_description_api", {}
+                )
+
         if self.params.get("ocr_engine") and self.params.get("ocr_lang"):
             params["ocr_engine"] = self.params.get("ocr_engine")
             params["ocr_lang"] = [
@@ -281,17 +300,20 @@ class Loader:
             if self._is_text_file(file_ext, file_content_type):
                 loader = TextLoader(file_path, autodetect_encoding=True)
             else:
+                # Build params for DoclingLoader
+                params = self.kwargs.get("DOCLING_PARAMS", {})
+                if not isinstance(params, dict):
+                    try:
+                        params = json.loads(params)
+                    except json.JSONDecodeError:
+                        log.error("Invalid DOCLING_PARAMS format, expected JSON object")
+                        params = {}
+
                 loader = DoclingLoader(
                     url=self.kwargs.get("DOCLING_SERVER_URL"),
                     file_path=file_path,
                     mime_type=file_content_type,
-                    params={
-                        "ocr_engine": self.kwargs.get("DOCLING_OCR_ENGINE"),
-                        "ocr_lang": self.kwargs.get("DOCLING_OCR_LANG"),
-                        "do_picture_description": self.kwargs.get(
-                            "DOCLING_DO_PICTURE_DESCRIPTION"
-                        ),
-                    },
+                    params=params,
                 )
         elif (
             self.engine == "document_intelligence"
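Note: the Docling options now travel as a single DOCLING_PARAMS dict instead of separate kwargs. A hedged sketch of what the Loader receives after this change (all values below are invented for illustration):

    docling_params = {
        "ocr_engine": "tesseract",          # any engine the Docling server supports
        "ocr_lang": "en,de",                # split into a list by DoclingLoader
        "do_picture_description": True,
        "picture_description_mode": "api",  # "local" or "api"; anything else skips both branches
        "picture_description_api": {"url": "http://docling-vision.local/describe"},  # made-up endpoint
    }
    loader = Loader(
        engine="docling",
        DOCLING_SERVER_URL="http://docling:5001",  # made-up address
        DOCLING_PARAMS=docling_params,
    )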
@@ -1,4 +1,5 @@
 import logging
+from xml.etree.ElementTree import ParseError

 from typing import Any, Dict, Generator, List, Optional, Sequence, Union
 from urllib.parse import parse_qs, urlparse
@@ -93,7 +94,6 @@ class YoutubeLoader:
                 "http": self.proxy_url,
                 "https": self.proxy_url,
             }
             # Don't log complete URL because it might contain secrets
-            log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
         else:
             youtube_proxies = None
@@ -110,11 +110,37 @@ class YoutubeLoader:
         for lang in self.language:
             try:
                 transcript = transcript_list.find_transcript([lang])
+                if transcript.is_generated:
+                    log.debug(f"Found generated transcript for language '{lang}'")
+                    try:
+                        transcript = transcript_list.find_manually_created_transcript(
+                            [lang]
+                        )
+                        log.debug(f"Found manual transcript for language '{lang}'")
+                    except NoTranscriptFound:
+                        log.debug(
+                            f"No manual transcript found for language '{lang}', using generated"
+                        )
+                        pass

-                log.debug(f"Found transcript for language '{lang}'")
-                transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
+                try:
+                    transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
+                except ParseError:
+                    log.debug(f"Empty or invalid transcript for language '{lang}'")
+                    continue
+
+                if not transcript_pieces:
+                    log.debug(f"Empty transcript for language '{lang}'")
+                    continue
+
                 transcript_text = " ".join(
                     map(
-                        lambda transcript_piece: transcript_piece.text.strip(" "),
+                        lambda transcript_piece: (
+                            transcript_piece.text.strip(" ")
+                            if hasattr(transcript_piece, "text")
+                            else ""
+                        ),
                         transcript_pieces,
                     )
                 )
@@ -131,6 +157,4 @@ class YoutubeLoader:
         log.warning(
             f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
         )
-        raise NoTranscriptFound(
-            f"No transcript found for any supported language. Verify if the video has transcripts, add more languages if needed."
-        )
+        raise NoTranscriptFound(self.video_id, self.language, list(transcript_list))
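Note: the loop above now prefers human-written captions and only falls back to auto-generated ones. The standalone pattern with youtube-transcript-api looks roughly like this (video ID and language are examples):

    from youtube_transcript_api import YouTubeTranscriptApi, NoTranscriptFound

    transcript_list = YouTubeTranscriptApi.list_transcripts("dQw4w9WgXcQ")
    transcript = transcript_list.find_transcript(["en"])
    if transcript.is_generated:
        try:
            # Prefer manually created captions when they exist.
            transcript = transcript_list.find_manually_created_transcript(["en"])
        except NoTranscriptFound:
            pass  # keep the generated transcript
    pieces = transcript.fetch()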
@@ -1,12 +1,16 @@
 from typing import Optional, List, Dict, Any
 import logging
+import json
 from sqlalchemy import (
+    func,
+    literal,
     cast,
     column,
     create_engine,
     Column,
     Integer,
     MetaData,
+    LargeBinary,
     select,
     text,
     Text,
@@ -28,7 +32,12 @@ from open_webui.retrieval.vector.main import (
     SearchResult,
     GetResult,
 )
-from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
+from open_webui.config import (
+    PGVECTOR_DB_URL,
+    PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
+    PGVECTOR_PGCRYPTO,
+    PGVECTOR_PGCRYPTO_KEY,
+)

 from open_webui.env import SRC_LOG_LEVELS

@@ -39,14 +48,27 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])


+def pgcrypto_encrypt(val, key):
+    return func.pgp_sym_encrypt(val, literal(key))
+
+
+def pgcrypto_decrypt(col, key, outtype="text"):
+    return func.cast(func.pgp_sym_decrypt(col, literal(key)), outtype)
+
+
 class DocumentChunk(Base):
     __tablename__ = "document_chunk"

     id = Column(Text, primary_key=True)
     vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
     collection_name = Column(Text, nullable=False)
-    text = Column(Text, nullable=True)
-    vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
+
+    if PGVECTOR_PGCRYPTO:
+        text = Column(LargeBinary, nullable=True)
+        vmetadata = Column(LargeBinary, nullable=True)
+    else:
+        text = Column(Text, nullable=True)
+        vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
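Note: pgp_sym_encrypt/pgp_sym_decrypt come from the PostgreSQL pgcrypto extension, so encryption happens server-side and the affected columns become BYTEA. A quick illustrative use of the two helpers above (key and collection name are placeholders):

    from sqlalchemy import select, Text

    KEY = "change-me"  # stands in for PGVECTOR_PGCRYPTO_KEY

    stmt = select(
        DocumentChunk.id,
        pgcrypto_decrypt(DocumentChunk.text, KEY, Text).label("text"),
    ).where(DocumentChunk.collection_name == "my-collection")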
@@ -147,44 +169,39 @@ class PgvectorClient(VectorDBBase):

     def insert(self, collection_name: str, items: List[VectorItem]) -> None:
         try:
-            new_items = []
-            for item in items:
-                vector = self.adjust_vector_length(item["vector"])
-                new_chunk = DocumentChunk(
-                    id=item["id"],
-                    vector=vector,
-                    collection_name=collection_name,
-                    text=item["text"],
-                    vmetadata=item["metadata"],
-                )
-                new_items.append(new_chunk)
-            self.session.bulk_save_objects(new_items)
-            self.session.commit()
-            log.info(
-                f"Inserted {len(new_items)} items into collection '{collection_name}'."
-            )
-        except Exception as e:
-            self.session.rollback()
-            log.exception(f"Error during insert: {e}")
-            raise
-
-    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
-        try:
-            for item in items:
-                vector = self.adjust_vector_length(item["vector"])
-                existing = (
-                    self.session.query(DocumentChunk)
-                    .filter(DocumentChunk.id == item["id"])
-                    .first()
-                )
-                if existing:
-                    existing.vector = vector
-                    existing.text = item["text"]
-                    existing.vmetadata = item["metadata"]
-                    existing.collection_name = (
-                        collection_name  # Update collection_name if necessary
+            if PGVECTOR_PGCRYPTO:
+                for item in items:
+                    vector = self.adjust_vector_length(item["vector"])
+                    # Use raw SQL for BYTEA/pgcrypto
+                    self.session.execute(
+                        text(
+                            """
+                            INSERT INTO document_chunk
+                            (id, vector, collection_name, text, vmetadata)
+                            VALUES (
+                                :id, :vector, :collection_name,
+                                pgp_sym_encrypt(:text, :key),
+                                pgp_sym_encrypt(:metadata::text, :key)
+                            )
+                            ON CONFLICT (id) DO NOTHING
+                            """
+                        ),
+                        {
+                            "id": item["id"],
+                            "vector": vector,
+                            "collection_name": collection_name,
+                            "text": item["text"],
+                            "metadata": json.dumps(item["metadata"]),
+                            "key": PGVECTOR_PGCRYPTO_KEY,
+                        },
+                    )
+                self.session.commit()
+                log.info(f"Encrypted & inserted {len(items)} into '{collection_name}'")
+            else:
+                new_items = []
+                for item in items:
+                    vector = self.adjust_vector_length(item["vector"])
+                    new_chunk = DocumentChunk(
+                        id=item["id"],
+                        vector=vector,
@@ -192,11 +209,78 @@ class PgvectorClient(VectorDBBase):
                         collection_name=collection_name,
                         text=item["text"],
                         vmetadata=item["metadata"],
                     )
-                    self.session.add(new_chunk)
-            self.session.commit()
-            log.info(
-                f"Upserted {len(items)} items into collection '{collection_name}'."
-            )
+                    new_items.append(new_chunk)
+                self.session.bulk_save_objects(new_items)
+                self.session.commit()
+                log.info(
+                    f"Inserted {len(new_items)} items into collection '{collection_name}'."
+                )
         except Exception as e:
             self.session.rollback()
             log.exception(f"Error during insert: {e}")
             raise

     def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
         try:
+            if PGVECTOR_PGCRYPTO:
+                for item in items:
+                    vector = self.adjust_vector_length(item["vector"])
+                    self.session.execute(
+                        text(
+                            """
+                            INSERT INTO document_chunk
+                            (id, vector, collection_name, text, vmetadata)
+                            VALUES (
+                                :id, :vector, :collection_name,
+                                pgp_sym_encrypt(:text, :key),
+                                pgp_sym_encrypt(:metadata::text, :key)
+                            )
+                            ON CONFLICT (id) DO UPDATE SET
+                                vector = EXCLUDED.vector,
+                                collection_name = EXCLUDED.collection_name,
+                                text = EXCLUDED.text,
+                                vmetadata = EXCLUDED.vmetadata
+                            """
+                        ),
+                        {
+                            "id": item["id"],
+                            "vector": vector,
+                            "collection_name": collection_name,
+                            "text": item["text"],
+                            "metadata": json.dumps(item["metadata"]),
+                            "key": PGVECTOR_PGCRYPTO_KEY,
+                        },
+                    )
+                self.session.commit()
+                log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
+            else:
+                for item in items:
+                    vector = self.adjust_vector_length(item["vector"])
+                    existing = (
+                        self.session.query(DocumentChunk)
+                        .filter(DocumentChunk.id == item["id"])
+                        .first()
+                    )
+                    if existing:
+                        existing.vector = vector
+                        existing.text = item["text"]
+                        existing.vmetadata = item["metadata"]
+                        existing.collection_name = (
+                            collection_name  # Update collection_name if necessary
+                        )
+                    else:
+                        new_chunk = DocumentChunk(
+                            id=item["id"],
+                            vector=vector,
+                            collection_name=collection_name,
+                            text=item["text"],
+                            vmetadata=item["metadata"],
+                        )
+                        self.session.add(new_chunk)
+                self.session.commit()
+                log.info(
+                    f"Upserted {len(items)} items into collection '{collection_name}'."
+                )
         except Exception as e:
             self.session.rollback()
             log.exception(f"Error during upsert: {e}")
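Note: both raw-SQL paths are identical except for the conflict clause: insert() keeps the existing row (ON CONFLICT ... DO NOTHING) while upsert() overwrites it with the EXCLUDED values. Reading a row back decrypts server-side as well; a minimal sketch:

    from sqlalchemy import text

    row = session.execute(
        text(
            "SELECT id, pgp_sym_decrypt(text, :key) AS text "
            "FROM document_chunk WHERE id = :id"
        ),
        {"id": "chunk-1", "key": PGVECTOR_PGCRYPTO_KEY},  # 'chunk-1' is a made-up id
    ).first()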
@@ -230,16 +314,32 @@ class PgvectorClient(VectorDBBase):
                 .alias("query_vectors")
             )

+            result_fields = [
+                DocumentChunk.id,
+            ]
+            if PGVECTOR_PGCRYPTO:
+                result_fields.append(
+                    pgcrypto_decrypt(
+                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
+                    ).label("text")
+                )
+                result_fields.append(
+                    pgcrypto_decrypt(
+                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
+                    ).label("vmetadata")
+                )
+            else:
+                result_fields.append(DocumentChunk.text)
+                result_fields.append(DocumentChunk.vmetadata)
+            result_fields.append(
+                (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
+                    "distance"
+                )
+            )
+
             # Build the lateral subquery for each query vector
             subq = (
-                select(
-                    DocumentChunk.id,
-                    DocumentChunk.text,
-                    DocumentChunk.vmetadata,
-                    (
-                        DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
-                    ).label("distance"),
-                )
+                select(*result_fields)
                 .where(DocumentChunk.collection_name == collection_name)
                 .order_by(
                     (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
@@ -299,17 +399,43 @@ class PgvectorClient(VectorDBBase):
         self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
     ) -> Optional[GetResult]:
         try:
-            query = self.session.query(DocumentChunk).filter(
-                DocumentChunk.collection_name == collection_name
-            )
+            if PGVECTOR_PGCRYPTO:
+                # Build where clause for vmetadata filter
+                where_clauses = [DocumentChunk.collection_name == collection_name]
+                for key, value in filter.items():
+                    # decrypt then check key: JSON filter after decryption
+                    where_clauses.append(
+                        pgcrypto_decrypt(
+                            DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
+                        )[key].astext
+                        == str(value)
+                    )
+                stmt = select(
+                    DocumentChunk.id,
+                    pgcrypto_decrypt(
+                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
+                    ).label("text"),
+                    pgcrypto_decrypt(
+                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
+                    ).label("vmetadata"),
+                ).where(*where_clauses)
+                if limit is not None:
+                    stmt = stmt.limit(limit)
+                results = self.session.execute(stmt).all()
+            else:
+                query = self.session.query(DocumentChunk).filter(
+                    DocumentChunk.collection_name == collection_name
+                )

-            for key, value in filter.items():
-                query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
+                for key, value in filter.items():
+                    query = query.filter(
+                        DocumentChunk.vmetadata[key].astext == str(value)
+                    )

-            if limit is not None:
-                query = query.limit(limit)
+                if limit is not None:
+                    query = query.limit(limit)

-            results = query.all()
+                results = query.all()

             if not results:
                 return None
@@ -331,20 +457,38 @@ class PgvectorClient(VectorDBBase):
         self, collection_name: str, limit: Optional[int] = None
     ) -> Optional[GetResult]:
         try:
-            query = self.session.query(DocumentChunk).filter(
-                DocumentChunk.collection_name == collection_name
-            )
-            if limit is not None:
-                query = query.limit(limit)
+            if PGVECTOR_PGCRYPTO:
+                stmt = select(
+                    DocumentChunk.id,
+                    pgcrypto_decrypt(
+                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
+                    ).label("text"),
+                    pgcrypto_decrypt(
+                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
+                    ).label("vmetadata"),
+                ).where(DocumentChunk.collection_name == collection_name)
+                if limit is not None:
+                    stmt = stmt.limit(limit)
+                results = self.session.execute(stmt).all()
+                ids = [[row.id for row in results]]
+                documents = [[row.text for row in results]]
+                metadatas = [[row.vmetadata for row in results]]
+            else:
+                query = self.session.query(DocumentChunk).filter(
+                    DocumentChunk.collection_name == collection_name
+                )
+                if limit is not None:
+                    query = query.limit(limit)

-            results = query.all()
+                results = query.all()

-            if not results:
-                return None
+                if not results:
+                    return None

-            ids = [[result.id for result in results]]
-            documents = [[result.text for result in results]]
-            metadatas = [[result.vmetadata for result in results]]
+                ids = [[result.id for result in results]]
+                documents = [[result.text for result in results]]
+                metadatas = [[result.vmetadata for result in results]]

             return GetResult(ids=ids, documents=documents, metadatas=metadatas)
         except Exception as e:
@@ -358,17 +502,33 @@ class PgvectorClient(VectorDBBase):
         filter: Optional[Dict[str, Any]] = None,
     ) -> None:
         try:
-            query = self.session.query(DocumentChunk).filter(
-                DocumentChunk.collection_name == collection_name
-            )
-            if ids:
-                query = query.filter(DocumentChunk.id.in_(ids))
-            if filter:
-                for key, value in filter.items():
-                    query = query.filter(
-                        DocumentChunk.vmetadata[key].astext == str(value)
-                    )
-            deleted = query.delete(synchronize_session=False)
+            if PGVECTOR_PGCRYPTO:
+                wheres = [DocumentChunk.collection_name == collection_name]
+                if ids:
+                    wheres.append(DocumentChunk.id.in_(ids))
+                if filter:
+                    for key, value in filter.items():
+                        wheres.append(
+                            pgcrypto_decrypt(
+                                DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
+                            )[key].astext
+                            == str(value)
+                        )
+                stmt = DocumentChunk.__table__.delete().where(*wheres)
+                result = self.session.execute(stmt)
+                deleted = result.rowcount
+            else:
+                query = self.session.query(DocumentChunk).filter(
+                    DocumentChunk.collection_name == collection_name
+                )
+                if ids:
+                    query = query.filter(DocumentChunk.id.in_(ids))
+                if filter:
+                    for key, value in filter.items():
+                        query = query.filter(
+                            DocumentChunk.vmetadata[key].astext == str(value)
+                        )
+                deleted = query.delete(synchronize_session=False)
             self.session.commit()
             log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
         except Exception as e:
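Note: with pgcrypto enabled, every metadata filter in get()/query()/delete() decrypts vmetadata per candidate row before comparing, so JSONB indexes cannot be used; that is the trade-off of storing text and vmetadata encrypted. One clause from the loop above, in isolation (filter key and value are hypothetical):

    from sqlalchemy.dialects.postgresql import JSONB

    clause = (
        pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB)[
            "file_id"
        ].astext
        == "abc123"
    )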
@@ -420,7 +420,7 @@ def load_b64_image_data(b64_str):
     try:
         if "," in b64_str:
             header, encoded = b64_str.split(",", 1)
-            mime_type = header.split(";")[0]
+            mime_type = header.split(";")[0].lstrip("data:")
             img_data = base64.b64decode(encoded)
         else:
             mime_type = "image/png"
@@ -428,7 +428,7 @@ def load_b64_image_data(b64_str):
         return img_data, mime_type
     except Exception as e:
         log.exception(f"Error loading image data: {e}")
-        return None
+        return None, None


 def load_url_image_data(url, headers=None):
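Note: str.lstrip strips a character set, not a literal prefix, so .lstrip("data:") also eats a leading "t" from mime types like "data:text/plain" (it is harmless for "data:image/..."). Python 3.9+ removeprefix avoids the pitfall:

    header = "data:text/plain"
    header.split(";")[0].lstrip("data:")        # -> 'ext/plain' (the leading 't' is stripped too)
    header.split(";")[0].removeprefix("data:")  # -> 'text/plain'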
@@ -600,6 +600,9 @@ async def get_rag_config(request: Request, collectionForm: CollectionForm, user=
         "DOCLING_OCR_ENGINE": rag_config.get("DOCLING_OCR_ENGINE", request.app.state.config.DOCLING_OCR_ENGINE),
         "DOCLING_OCR_LANG": rag_config.get("DOCLING_OCR_LANG", request.app.state.config.DOCLING_OCR_LANG),
         "DOCLING_DO_PICTURE_DESCRIPTION": rag_config.get("DOCLING_DO_PICTURE_DESCRIPTION", request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION),
+        "DOCLING_PICTURE_DESCRIPTION_MODE": rag_config.get("DOCLING_PICTURE_DESCRIPTION_MODE", request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE),
+        "DOCLING_PICTURE_DESCRIPTION_LOCAL": rag_config.get("DOCLING_PICTURE_DESCRIPTION_LOCAL", request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL),
+        "DOCLING_PICTURE_DESCRIPTION_API": rag_config.get("DOCLING_PICTURE_DESCRIPTION_API", request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API),
         "DOCUMENT_INTELLIGENCE_ENDPOINT": rag_config.get("DOCUMENT_INTELLIGENCE_ENDPOINT", request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT),
         "DOCUMENT_INTELLIGENCE_KEY": rag_config.get("DOCUMENT_INTELLIGENCE_KEY", request.app.state.config.DOCUMENT_INTELLIGENCE_KEY),
         "MISTRAL_OCR_API_KEY": rag_config.get("MISTRAL_OCR_API_KEY", request.app.state.config.MISTRAL_OCR_API_KEY),
@@ -766,6 +769,9 @@ class ConfigForm(BaseModel):
     DOCLING_OCR_ENGINE: Optional[str] = None
     DOCLING_OCR_LANG: Optional[str] = None
     DOCLING_DO_PICTURE_DESCRIPTION: Optional[bool] = None
+    DOCLING_PICTURE_DESCRIPTION_MODE: Optional[str] = None
+    DOCLING_PICTURE_DESCRIPTION_LOCAL: Optional[dict] = None
+    DOCLING_PICTURE_DESCRIPTION_API: Optional[dict] = None
     DOCUMENT_INTELLIGENCE_ENDPOINT: Optional[str] = None
     DOCUMENT_INTELLIGENCE_KEY: Optional[str] = None
     MISTRAL_OCR_API_KEY: Optional[str] = None
@@ -1050,6 +1056,22 @@ async def update_rag_config(
         else request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION
     )

+    request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE = (
+        form_data.DOCLING_PICTURE_DESCRIPTION_MODE
+        if form_data.DOCLING_PICTURE_DESCRIPTION_MODE is not None
+        else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE
+    )
+    request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL = (
+        form_data.DOCLING_PICTURE_DESCRIPTION_LOCAL
+        if form_data.DOCLING_PICTURE_DESCRIPTION_LOCAL is not None
+        else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL
+    )
+    request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API = (
+        form_data.DOCLING_PICTURE_DESCRIPTION_API
+        if form_data.DOCLING_PICTURE_DESCRIPTION_API is not None
+        else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API
+    )
+
     request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = (
         form_data.DOCUMENT_INTELLIGENCE_ENDPOINT
         if form_data.DOCUMENT_INTELLIGENCE_ENDPOINT is not None
@@ -1307,6 +1329,9 @@ async def update_rag_config(
             "DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE,
             "DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG,
             "DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
+            "DOCLING_PICTURE_DESCRIPTION_MODE": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE,
+            "DOCLING_PICTURE_DESCRIPTION_LOCAL": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL,
+            "DOCLING_PICTURE_DESCRIPTION_API": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API,
             "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
             "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
             "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
@@ -1667,6 +1692,15 @@ def process_file(
         docling_do_picture_description = rag_config.get(
             "DOCLING_DO_PICTURE_DESCRIPTION", request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION
         )
+        picture_description_mode = rag_config.get(
+            "PICTURE_DESCRIPTION_MODE", request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE
+        )
+        picture_description_local = rag_config.get(
+            "PICTURE_DESCRIPTION_LOCAL", request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL
+        )
+        picture_description_api = rag_config.get(
+            "PICTURE_DESCRIPTION_API", request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API
+        )
         pdf_extract_images = rag_config.get(
             "PDF_EXTRACT_IMAGES", request.app.state.config.PDF_EXTRACT_IMAGES
         )
@@ -1757,9 +1791,14 @@ def process_file(
             EXTERNAL_DOCUMENT_LOADER_API_KEY=external_document_loader_api_key,
             TIKA_SERVER_URL=tika_server_url,
             DOCLING_SERVER_URL=docling_server_url,
-            DOCLING_OCR_ENGINE=docling_ocr_engine,
-            DOCLING_OCR_LANG=docling_ocr_lang,
-            DOCLING_DO_PICTURE_DESCRIPTION=docling_do_picture_description,
+            DOCLING_PARAMS={
+                "ocr_engine": docling_ocr_engine,
+                "ocr_lang": docling_ocr_lang,
+                "do_picture_description": docling_do_picture_description,
+                "picture_description_mode": picture_description_mode,
+                "picture_description_local": picture_description_local,
+                "picture_description_api": picture_description_api,
+            },
             PDF_EXTRACT_IMAGES=pdf_extract_images,
             DOCUMENT_INTELLIGENCE_ENDPOINT=document_intelligence_endpoint,
             DOCUMENT_INTELLIGENCE_KEY=document_intelligence_key,
@@ -33,7 +33,7 @@ class CodeForm(BaseModel):


 @router.post("/code/format")
-async def format_code(form_data: CodeForm, user=Depends(get_verified_user)):
+async def format_code(form_data: CodeForm, user=Depends(get_admin_user)):
     try:
         formatted_code = black.format_str(form_data.code, mode=black.Mode())
         return {"code": formatted_code}
@@ -2,16 +2,87 @@
 import asyncio
-from typing import Dict
 from uuid import uuid4
+import json
+from redis.asyncio import Redis
+from fastapi import Request
+from typing import Dict, List, Optional

 # A dictionary to keep track of active tasks
 tasks: Dict[str, asyncio.Task] = {}
 chat_tasks = {}


-def cleanup_task(task_id: str, id=None):
+REDIS_TASKS_KEY = "open-webui:tasks"
+REDIS_CHAT_TASKS_KEY = "open-webui:tasks:chat"
+REDIS_PUBSUB_CHANNEL = "open-webui:tasks:commands"
+
+
+def is_redis(request: Request) -> bool:
+    # Called everywhere a request is available to check Redis
+    return hasattr(request.app.state, "redis") and (request.app.state.redis is not None)
+
+
+async def redis_task_command_listener(app):
+    redis: Redis = app.state.redis
+    pubsub = redis.pubsub()
+    await pubsub.subscribe(REDIS_PUBSUB_CHANNEL)
+
+    async for message in pubsub.listen():
+        if message["type"] != "message":
+            continue
+        try:
+            command = json.loads(message["data"])
+            if command.get("action") == "stop":
+                task_id = command.get("task_id")
+                local_task = tasks.get(task_id)
+                if local_task:
+                    local_task.cancel()
+        except Exception as e:
+            print(f"Error handling distributed task command: {e}")
+
+
+### ------------------------------
+### REDIS-ENABLED HANDLERS
+### ------------------------------
+
+
+async def redis_save_task(redis: Redis, task_id: str, chat_id: Optional[str]):
+    pipe = redis.pipeline()
+    pipe.hset(REDIS_TASKS_KEY, task_id, chat_id or "")
+    if chat_id:
+        pipe.sadd(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
+    await pipe.execute()
+
+
+async def redis_cleanup_task(redis: Redis, task_id: str, chat_id: Optional[str]):
+    pipe = redis.pipeline()
+    pipe.hdel(REDIS_TASKS_KEY, task_id)
+    if chat_id:
+        pipe.srem(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
+        if (await pipe.scard(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}").execute())[-1] == 0:
+            pipe.delete(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}")  # Remove if empty set
+    await pipe.execute()
+
+
+async def redis_list_tasks(redis: Redis) -> List[str]:
+    return list(await redis.hkeys(REDIS_TASKS_KEY))
+
+
+async def redis_list_chat_tasks(redis: Redis, chat_id: str) -> List[str]:
+    return list(await redis.smembers(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}"))
+
+
+async def redis_send_command(redis: Redis, command: dict):
+    await redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command))
+
+
+async def cleanup_task(request, task_id: str, id=None):
     """
     Remove a completed or canceled task from the global `tasks` dictionary.
     """
+    if is_redis(request):
+        await redis_cleanup_task(request.app.state.redis, task_id, id)
+
     tasks.pop(task_id, None)  # Remove the task if it exists

     # If an ID is provided, remove the task from the chat_tasks dictionary
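Note: the task registry is now mirrored into Redis (one hash of task ids plus one set per chat), so any replica can answer the task-listing endpoints, and a stop request is broadcast over pub/sub so that only the instance that owns the asyncio task cancels it. The sending side, reduced to its core (a sketch, assuming an initialized async Redis client):

    import json

    async def request_stop(redis, task_id: str):
        # Every instance's listener receives this; whichever one holds
        # tasks[task_id] locally cancels it.
        await redis.publish(
            "open-webui:tasks:commands",
            json.dumps({"action": "stop", "task_id": task_id}),
        )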
@@ -21,7 +92,7 @@ def cleanup_task(task_id: str, id=None):
             chat_tasks.pop(id, None)


-def create_task(coroutine, id=None):
+async def create_task(request, coroutine, id=None):
     """
     Create a new asyncio task and add it to the global task dictionary.
     """
@@ -29,7 +100,9 @@ def create_task(coroutine, id=None):
     task = asyncio.create_task(coroutine)  # Create the task

     # Add a done callback for cleanup
-    task.add_done_callback(lambda t: cleanup_task(task_id, id))
+    task.add_done_callback(
+        lambda t: asyncio.create_task(cleanup_task(request, task_id, id))
+    )
     tasks[task_id] = task

     # If an ID is provided, associate the task with that ID
@@ -38,34 +111,46 @@ def create_task(coroutine, id=None):
     else:
         chat_tasks[id] = [task_id]

+    if is_redis(request):
+        await redis_save_task(request.app.state.redis, task_id, id)
+
     return task_id, task


 def get_task(task_id: str):
     """
     Retrieve a task by its task ID.
     """
     return tasks.get(task_id)


-def list_tasks():
+async def list_tasks(request):
     """
     List all currently active task IDs.
     """
+    if is_redis(request):
+        return await redis_list_tasks(request.app.state.redis)
     return list(tasks.keys())


-def list_task_ids_by_chat_id(id):
+async def list_task_ids_by_chat_id(request, id):
     """
     List all tasks associated with a specific ID.
     """
+    if is_redis(request):
+        return await redis_list_chat_tasks(request.app.state.redis, id)
     return chat_tasks.get(id, [])


-async def stop_task(task_id: str):
+async def stop_task(request, task_id: str):
     """
     Cancel a running task and remove it from the global task list.
     """
+    if is_redis(request):
+        # PUBSUB: All instances check if they have this task, and stop if so.
+        await redis_send_command(
+            request.app.state.redis,
+            {
+                "action": "stop",
+                "task_id": task_id,
+            },
+        )
+        # Optionally check if task_id still in Redis a few moments later for feedback?
+        return {"status": True, "message": f"Stop signal sent for {task_id}"}
+
     task = tasks.get(task_id)
     if not task:
         raise ValueError(f"Task with ID {task_id} not found.")
@@ -23,6 +23,7 @@ from open_webui.env import (
     TRUSTED_SIGNATURE_KEY,
     STATIC_DIR,
     SRC_LOG_LEVELS,
+    WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
 )

 from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
@@ -157,6 +158,7 @@ def get_http_authorization_cred(auth_header: Optional[str]):

 def get_current_user(
     request: Request,
+    response: Response,
     background_tasks: BackgroundTasks,
     auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
 ):
@@ -225,6 +227,19 @@ def get_current_user(
                 detail=ERROR_MESSAGES.INVALID_TOKEN,
             )
         else:
+            if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
+                trusted_email = request.headers.get(WEBUI_AUTH_TRUSTED_EMAIL_HEADER)
+                if trusted_email and user.email != trusted_email:
+                    # Delete the token cookie
+                    response.delete_cookie("token")
+                    # Delete OAuth token if present
+                    if request.cookies.get("oauth_id_token"):
+                        response.delete_cookie("oauth_id_token")
+                    raise HTTPException(
+                        status_code=status.HTTP_401_UNAUTHORIZED,
+                        detail="User mismatch. Please sign in again.",
+                    )
+
             # Add user info to current span
             current_span = trace.get_current_span()
             if current_span:
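Note: this guards deployments where a reverse proxy authenticates the user and forwards their email in a trusted header; if the session cookie belongs to a different user than the header claims, the cookies are cleared and a 401 is returned. A hedged sketch of exercising the check (route, header name, and token are illustrative):

    from fastapi.testclient import TestClient

    client = TestClient(app)
    r = client.get(
        "/api/v1/auths/",  # any route that resolves the current user
        headers={"X-Forwarded-Email": "alice@example.com"},  # proxy-supplied identity
        cookies={"token": "<jwt-issued-to-another-user>"},   # stale session
    )
    assert r.status_code == 401  # "User mismatch. Please sign in again."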
@@ -37,7 +37,12 @@ from open_webui.routers.tasks import (
     generate_chat_tags,
 )
 from open_webui.routers.retrieval import process_web_search, SearchForm
-from open_webui.routers.images import image_generations, GenerateImageForm
+from open_webui.routers.images import (
+    load_b64_image_data,
+    image_generations,
+    GenerateImageForm,
+    upload_image,
+)
 from open_webui.routers.pipelines import (
     process_pipeline_inlet_filter,
     process_pipeline_outlet_filter,
@@ -2278,28 +2283,21 @@ async def process_chat_response(
                                    stdoutLines = stdout.split("\n")
                                    for idx, line in enumerate(stdoutLines):
                                        if "data:image/png;base64" in line:
-                                            id = str(uuid4())
-
-                                            # ensure the path exists
-                                            os.makedirs(
-                                                os.path.join(CACHE_DIR, "images"),
-                                                exist_ok=True,
-                                            )
-
-                                            image_path = os.path.join(
-                                                CACHE_DIR,
-                                                f"images/{id}.png",
-                                            )
-
-                                            with open(image_path, "wb") as f:
-                                                f.write(
-                                                    base64.b64decode(
-                                                        line.split(",")[1]
-                                                    )
-                                                )
+                                            image_url = ""
+                                            # Extract base64 image data from the line
+                                            image_data, content_type = (
+                                                load_b64_image_data(line)
+                                            )
+
+                                            if image_data is not None:
+                                                image_url = upload_image(
+                                                    request,
+                                                    image_data,
+                                                    content_type,
+                                                    metadata,
+                                                    user,
+                                                )

                                            stdoutLines[idx] = (
-                                                f"![Output Image {idx}](/cache/images/{id}.png)"
+                                                f"![Output Image {idx}]({image_url})"
                                            )

                                    output["stdout"] = "\n".join(stdoutLines)
@@ -2310,30 +2308,22 @@ async def process_chat_response(
                                    resultLines = result.split("\n")
                                    for idx, line in enumerate(resultLines):
                                        if "data:image/png;base64" in line:
-                                            id = str(uuid4())
-
-                                            # ensure the path exists
-                                            os.makedirs(
-                                                os.path.join(CACHE_DIR, "images"),
-                                                exist_ok=True,
-                                            )
-
-                                            image_path = os.path.join(
-                                                CACHE_DIR,
-                                                f"images/{id}.png",
-                                            )
-
-                                            with open(image_path, "wb") as f:
-                                                f.write(
-                                                    base64.b64decode(
-                                                        line.split(",")[1]
-                                                    )
-                                                )
+                                            image_url = ""
+                                            # Extract base64 image data from the line
+                                            image_data, content_type = (
+                                                load_b64_image_data(line)
+                                            )
+
+                                            if image_data is not None:
+                                                image_url = upload_image(
+                                                    request,
+                                                    image_data,
+                                                    content_type,
+                                                    metadata,
+                                                    user,
+                                                )

                                            resultLines[idx] = (
-                                                f"![Output Image {idx}](/cache/images/{id}.png)"
+                                                f"![Output Image {idx}]({image_url})"
                                            )

                                    output["result"] = "\n".join(resultLines)
                            except Exception as e:
                                output = str(e)
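Note: images emitted by code execution no longer get written straight into CACHE_DIR; they pass through upload_image() like any other chat attachment, and the markdown link points at the returned URL. Both blocks above reduce to the same per-line rewrite (sketch):

    for idx, line in enumerate(lines):
        if "data:image/png;base64" in line:
            image_url = ""
            image_data, content_type = load_b64_image_data(line)  # decode the data URL
            if image_data is not None:
                image_url = upload_image(request, image_data, content_type, metadata, user)
            lines[idx] = f"![Output Image {idx}]({image_url})"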
@@ -2442,8 +2432,8 @@ async def process_chat_response(
                     await response.background()

             # background_tasks.add_task(post_response_handler, response, events)
-            task_id, _ = create_task(
-                post_response_handler(response, events), id=metadata["chat_id"]
+            task_id, _ = await create_task(
+                request, post_response_handler(response, events), id=metadata["chat_id"]
             )
             return {"status": True, "task_id": task_id}
@@ -1,7 +1,6 @@
 import socketio
 import redis
-from redis import asyncio as aioredis
 from urllib.parse import urlparse
 from typing import Optional


 def parse_redis_service_url(redis_url):
@@ -18,23 +17,46 @@ def parse_redis_service_url(redis_url):
     }


-def get_redis_connection(redis_url, redis_sentinels, decode_responses=True):
-    if redis_sentinels:
-        redis_config = parse_redis_service_url(redis_url)
-        sentinel = redis.sentinel.Sentinel(
-            redis_sentinels,
-            port=redis_config["port"],
-            db=redis_config["db"],
-            username=redis_config["username"],
-            password=redis_config["password"],
-            decode_responses=decode_responses,
-        )
-
-        # Get a master connection from Sentinel
-        return sentinel.master_for(redis_config["service"])
+def get_redis_connection(
+    redis_url, redis_sentinels, async_mode=False, decode_responses=True
+):
+    if async_mode:
+        import redis.asyncio as redis
+
+        # If using sentinel in async mode
+        if redis_sentinels:
+            redis_config = parse_redis_service_url(redis_url)
+            sentinel = redis.sentinel.Sentinel(
+                redis_sentinels,
+                port=redis_config["port"],
+                db=redis_config["db"],
+                username=redis_config["username"],
+                password=redis_config["password"],
+                decode_responses=decode_responses,
+            )
+            return sentinel.master_for(redis_config["service"])
+        elif redis_url:
+            return redis.from_url(redis_url, decode_responses=decode_responses)
+        else:
+            return None
     else:
         # Standard Redis connection
-        return redis.Redis.from_url(redis_url, decode_responses=decode_responses)
+        import redis
+
+        if redis_sentinels:
+            redis_config = parse_redis_service_url(redis_url)
+            sentinel = redis.sentinel.Sentinel(
+                redis_sentinels,
+                port=redis_config["port"],
+                db=redis_config["db"],
+                username=redis_config["username"],
+                password=redis_config["password"],
+                decode_responses=decode_responses,
+            )
+            return sentinel.master_for(redis_config["service"])
+        elif redis_url:
+            return redis.Redis.from_url(redis_url, decode_responses=decode_responses)
+        else:
+            return None


 def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):
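Note: async_mode only decides which client module backs the returned handle; the sentinel and URL logic is the same in both branches. A minimal sketch of the two call sites (URL is a placeholder):

    # Sync client, as used for AppConfig:
    r = get_redis_connection("redis://localhost:6379/0", redis_sentinels=None)

    # Async client, as stored on app.state.redis by lifespan():
    ar = get_redis_connection(
        "redis://localhost:6379/0", redis_sentinels=None, async_mode=True
    )
    # `ar` is a redis.asyncio client, so calls are awaited: await ar.ping()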
@@ -14,7 +14,11 @@ if [[ "${WEB_LOADER_ENGINE,,}" == "playwright" ]]; then
   python -c "import nltk; nltk.download('punkt_tab')"
 fi

-KEY_FILE=.webui_secret_key
+if [ -n "${WEBUI_SECRET_KEY_FILE}" ]; then
+  KEY_FILE="${WEBUI_SECRET_KEY_FILE}"
+else
+  KEY_FILE=".webui_secret_key"
+fi

 PORT="${PORT:-8080}"
 HOST="${HOST:-0.0.0.0}"
@@ -18,6 +18,10 @@ IF /I "%WEB_LOADER_ENGINE%" == "playwright" (
 )

 SET "KEY_FILE=.webui_secret_key"
+IF NOT "%WEBUI_SECRET_KEY_FILE%" == "" (
+    SET "KEY_FILE=%WEBUI_SECRET_KEY_FILE%"
+)
+
 IF "%PORT%"=="" SET PORT=8080
 IF "%HOST%"=="" SET HOST=0.0.0.0
 SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"