import json
import logging
import mimetypes
import os
import shutil
import asyncio
import re
from typing import List as TypingList


import uuid
from datetime import datetime
from pathlib import Path
from typing import Iterator, List, Optional, Sequence, Union

from fastapi import (
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    UploadFile,
    Request,
    status,
    APIRouter,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
import tiktoken


from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
from langchain_core.documents import Document

from open_webui.models.files import FileModel, Files
from open_webui.models.knowledge import Knowledges
from open_webui.storage.provider import Storage


from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT

# Document loaders
from open_webui.retrieval.loaders.main import Loader
from open_webui.retrieval.loaders.youtube import YoutubeLoader

# Web search engines
from open_webui.retrieval.web.main import SearchResult
from open_webui.retrieval.web.utils import get_web_loader
from open_webui.retrieval.web.brave import search_brave
from open_webui.retrieval.web.kagi import search_kagi
from open_webui.retrieval.web.mojeek import search_mojeek
from open_webui.retrieval.web.bocha import search_bocha
from open_webui.retrieval.web.duckduckgo import search_duckduckgo
from open_webui.retrieval.web.google_pse import search_google_pse
from open_webui.retrieval.web.jina_search import search_jina
from open_webui.retrieval.web.searchapi import search_searchapi
from open_webui.retrieval.web.serpapi import search_serpapi
from open_webui.retrieval.web.searxng import search_searxng
from open_webui.retrieval.web.yacy import search_yacy
from open_webui.retrieval.web.serper import search_serper
from open_webui.retrieval.web.serply import search_serply
from open_webui.retrieval.web.serpstack import search_serpstack
from open_webui.retrieval.web.tavily import search_tavily
from open_webui.retrieval.web.bing import search_bing
from open_webui.retrieval.web.exa import search_exa
from open_webui.retrieval.web.perplexity import search_perplexity
from open_webui.retrieval.web.sougou import search_sougou
from open_webui.retrieval.web.firecrawl import search_firecrawl
from open_webui.retrieval.web.external import search_external

from open_webui.retrieval.utils import (
    get_embedding_function,
    get_model_path,
    query_collection,
    query_collection_with_hybrid_search,
    query_doc,
    query_doc_with_hybrid_search,
)
from open_webui.utils.misc import (
    calculate_sha256_string,
)
from open_webui.utils.auth import get_admin_user, get_verified_user

from open_webui.config import (
    ENV,
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    RAG_RERANKING_MODEL_AUTO_UPDATE,
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    UPLOAD_DIR,
    DEFAULT_LOCALE,
    RAG_EMBEDDING_CONTENT_PREFIX,
    RAG_EMBEDDING_QUERY_PREFIX,
)
from open_webui.env import (
    SRC_LOG_LEVELS,
    DEVICE_TYPE,
    DOCKER,
    SENTENCE_TRANSFORMERS_BACKEND,
    SENTENCE_TRANSFORMERS_MODEL_KWARGS,
    SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
    SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
)

from open_webui.constants import ERROR_MESSAGES

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

##########################################
#
# Utility functions
#
##########################################


def get_ef(
    engine: str,
    embedding_model: str,
    auto_update: bool = False,
):
    ef = None
    if embedding_model and engine == "":
        from sentence_transformers import SentenceTransformer

        try:
            ef = SentenceTransformer(
                get_model_path(embedding_model, auto_update),
                device=DEVICE_TYPE,
                trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
                backend=SENTENCE_TRANSFORMERS_BACKEND,
                model_kwargs=SENTENCE_TRANSFORMERS_MODEL_KWARGS,
            )
        except Exception as e:
            log.debug(f"Error loading SentenceTransformer: {e}")

    return ef
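

# NOTE (illustrative): get_ef only loads a local SentenceTransformer when the
# engine is "" (the built-in sentence-transformers path); for "openai",
# "ollama", or "azure_openai" it returns None and embeddings are produced
# remotely via get_embedding_function. A minimal sketch, assuming the model
# below can be fetched locally:
#
#   ef = get_ef("", "sentence-transformers/all-MiniLM-L6-v2")
#   vectors = ef.encode(["hello world"])  # numpy array, one row per input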


def get_rf(
    engine: str = "",
    reranking_model: Optional[str] = None,
    external_reranker_url: str = "",
    external_reranker_api_key: str = "",
    auto_update: bool = False,
):
    rf = None
    if reranking_model:
        if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]):
            try:
                from open_webui.retrieval.models.colbert import ColBERT

                rf = ColBERT(
                    get_model_path(reranking_model, auto_update),
                    env="docker" if DOCKER else None,
                )

            except Exception as e:
                log.error(f"ColBERT: {e}")
                raise Exception(ERROR_MESSAGES.DEFAULT(e))
        else:
            if engine == "external":
                try:
                    from open_webui.retrieval.models.external import ExternalReranker

                    rf = ExternalReranker(
                        url=external_reranker_url,
                        api_key=external_reranker_api_key,
                        model=reranking_model,
                    )
                except Exception as e:
                    log.error(f"ExternalReranking: {e}")
                    raise Exception(ERROR_MESSAGES.DEFAULT(e))
            else:
                import sentence_transformers

                try:
                    rf = sentence_transformers.CrossEncoder(
                        get_model_path(reranking_model, auto_update),
                        device=DEVICE_TYPE,
                        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
                        backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
                        model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
                    )
                except Exception as e:
                    log.error(f"CrossEncoder: {e}")
                    raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error"))

    return rf
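

# NOTE (illustrative): get_rf dispatches on model and engine: ColBERT for
# "jinaai/jina-colbert-v2", ExternalReranker when engine == "external", and a
# sentence-transformers CrossEncoder otherwise. A sketch, assuming the model
# below is downloadable:
#
#   rf = get_rf(engine="", reranking_model="cross-encoder/ms-marco-MiniLM-L-6-v2")
#   scores = rf.predict([("query text", "candidate passage")])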


##########################################
#
# Text cleaning and processing functions
#
##########################################


class TextCleaner:
    """Modular text cleaning system for document processing and embedding preparation."""

    @staticmethod
    def normalize_escape_sequences(text: str) -> str:
        """Normalize escape sequences from various document formats."""
        if not text:
            return ""

        # Handle double-escaped sequences (common in PPTX)
        replacements = [
            ("\\\\n", "\n"),  # Double-escaped newlines
            ("\\\\t", " "),  # Double-escaped tabs
            ('\\\\"', '"'),  # Double-escaped quotes
            ("\\\\r", ""),  # Double-escaped carriage returns
            ("\\\\/", "/"),  # Double-escaped slashes
            ("\\\\", "\\"),  # Convert double backslashes to single
        ]

        for old, new in replacements:
            text = text.replace(old, new)

        # Handle single-escaped sequences
        single_replacements = [
            ("\\n", "\n"),  # Single-escaped newlines
            ("\\t", " "),  # Single-escaped tabs
            ('\\"', '"'),  # Single-escaped quotes
            ("\\'", "'"),  # Single-escaped single quotes
            ("\\r", ""),  # Single-escaped carriage returns
            ("\\/", "/"),  # Single-escaped slashes
        ]

        for old, new in single_replacements:
            text = text.replace(old, new)

        # Remove any remaining backslash artifacts
        text = re.sub(r"\\[a-zA-Z]", "", text)  # Remove \letter patterns
        text = re.sub(r"\\[0-9]", "", text)  # Remove \number patterns
        text = re.sub(r"\\[^a-zA-Z0-9\s]", "", text)  # Remove \symbol patterns
        text = re.sub(r"\\+", "", text)  # Remove remaining backslashes

        return text

    @staticmethod
    def normalize_unicode(text: str) -> str:
        """Convert special Unicode characters to ASCII equivalents."""
        if not text:
            return ""

        unicode_map = {
            "–": "-",  # En dash
            "—": "-",  # Em dash
            "\u2018": "'",  # Smart single quote left
            "\u2019": "'",  # Smart single quote right
            "\u201c": '"',  # Smart double quote left
            "\u201d": '"',  # Smart double quote right
            "…": "...",  # Ellipsis
            "™": " TM",  # Trademark
            "®": " R",  # Registered
            "©": " C",  # Copyright
            "°": " deg",  # Degree symbol
        }

        for unicode_char, ascii_char in unicode_map.items():
            text = text.replace(unicode_char, ascii_char)

        return text

    @staticmethod
    def normalize_quotes(text: str) -> str:
        """Clean up quote-related artifacts and normalize quote marks."""
        if not text:
            return ""

        # Remove quote artifacts
        quote_patterns = [
            (r'\\+"', '"'),  # Multiple backslashes before quotes
            (r'\\"', '"'),  # Escaped double quotes
            (r"\\'", "'"),  # Escaped single quotes
            (r"\\&", "&"),  # Escaped ampersands
            (r'""', '"'),  # Double quotes
            (r"''", "'"),  # Double single quotes
        ]

        for pattern, replacement in quote_patterns:
            text = re.sub(pattern, replacement, text)

        return text

    @staticmethod
    def normalize_whitespace(text: str, preserve_paragraphs: bool = True) -> str:
        """Normalize whitespace while optionally preserving paragraph structure."""
        if not text:
            return ""

        if preserve_paragraphs:
            # Preserve paragraph breaks (double newlines) but clean up excessive spacing
            text = re.sub(r"[ \t]+", " ", text)  # Multiple spaces/tabs -> single space
            text = re.sub(
                r"\n\s*\n\s*\n+", "\n\n", text
            )  # Multiple empty lines -> double line break
            text = re.sub(
                r"^\s+|\s+$", "", text, flags=re.MULTILINE
            )  # Trim line-level whitespace
        else:
            # Flatten all whitespace for embedding
            text = re.sub(r"\n+", " ", text)  # All newlines to spaces
            text = re.sub(r"\s+", " ", text)  # All whitespace to single spaces

        return text.strip()

    @staticmethod
    def remove_artifacts(text: str) -> str:
        """Remove document format artifacts and orphaned elements."""
        if not text:
            return ""

        # Remove orphaned punctuation
        text = re.sub(r"^\s*[)\]}]+\s*", "", text)  # Orphaned closing brackets at start
        text = re.sub(
            r"\n\s*[)\]}]+\s*\n", "\n\n", text
        )  # Orphaned closing brackets on own lines

        # Remove excessive punctuation
        text = re.sub(r"[.]{3,}", "...", text)  # Multiple dots to ellipsis
        text = re.sub(r"[-]{3,}", "---", text)  # Multiple dashes

        # Remove empty parentheses and brackets
        text = re.sub(r"\(\s*\)", "", text)  # Empty parentheses
        text = re.sub(r"\[\s*\]", "", text)  # Empty square brackets
        text = re.sub(r"\{\s*\}", "", text)  # Empty curly brackets

        return text

    @classmethod
    def clean_for_chunking(cls, text: str) -> str:
        """Clean text for semantic chunking - preserves structure but normalizes content."""
        if not text:
            return ""

        # Apply all cleaning steps while preserving paragraph structure
        text = cls.normalize_escape_sequences(text)
        text = cls.normalize_unicode(text)
        text = cls.normalize_quotes(text)
        text = cls.remove_artifacts(text)
        text = cls.normalize_whitespace(text, preserve_paragraphs=True)

        return text

    @classmethod
    def clean_for_embedding(cls, text: str) -> str:
        """Clean text for embedding - flattens structure and optimizes for vector similarity."""
        if not text:
            return ""

        # Start with chunking-level cleaning
        text = cls.clean_for_chunking(text)

        # Flatten for embedding
        text = cls.normalize_whitespace(text, preserve_paragraphs=False)

        return text

    @classmethod
    def clean_for_storage(cls, text: str) -> str:
        """Clean text for storage - most aggressive cleaning for database storage."""
        if not text:
            return ""

        # Start with embedding-level cleaning
        text = cls.clean_for_embedding(text)

        # Additional aggressive cleaning for storage
        text = re.sub(
            r"\\([^a-zA-Z0-9\s])", r"\1", text
        )  # Remove any remaining escape sequences

        return text


def clean_text_content(text: str) -> str:
    """Legacy function wrapper for backward compatibility."""
    return TextCleaner.clean_for_chunking(text)
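

# Example (illustrative): the cleaners are cumulative. clean_for_chunking
# preserves paragraph breaks, clean_for_embedding additionally flattens all
# whitespace, and clean_for_storage also strips residual escapes:
#
#   raw = 'First line.\\nSecond – line…'
#   TextCleaner.clean_for_chunking(raw)   # 'First line.\nSecond - line...'
#   TextCleaner.clean_for_embedding(raw)  # 'First line. Second - line...'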


def create_semantic_chunks(
    text: str, max_chunk_size: int, overlap_size: int
) -> TypingList[str]:
    """Create semantically aware chunks that respect document structure"""
    if not text or len(text) <= max_chunk_size:
        return [text] if text else []

    chunks = []

    # Split by double line breaks (paragraphs) first
    paragraphs = text.split("\n\n")

    current_chunk = ""

    for paragraph in paragraphs:
        paragraph = paragraph.strip()
        if not paragraph:
            continue

        # If adding this paragraph would exceed chunk size
        if current_chunk and len(current_chunk) + len(paragraph) + 2 > max_chunk_size:
            # Try to split the current chunk at sentence boundaries if it's too long
            if len(current_chunk) > max_chunk_size:
                sentence_chunks = split_by_sentences(
                    current_chunk, max_chunk_size, overlap_size
                )
                chunks.extend(sentence_chunks)
            else:
                chunks.append(current_chunk.strip())

            # Start new chunk with overlap from previous chunk if applicable
            if chunks and overlap_size > 0:
                prev_chunk = chunks[-1]
                overlap_text = get_text_overlap(prev_chunk, overlap_size)
                current_chunk = (
                    overlap_text + "\n\n" + paragraph if overlap_text else paragraph
                )
            else:
                current_chunk = paragraph
        else:
            # Add paragraph to current chunk
            if current_chunk:
                current_chunk += "\n\n" + paragraph
            else:
                current_chunk = paragraph

    # Add the last chunk
    if current_chunk:
        if len(current_chunk) > max_chunk_size:
            sentence_chunks = split_by_sentences(
                current_chunk, max_chunk_size, overlap_size
            )
            chunks.extend(sentence_chunks)
        else:
            chunks.append(current_chunk.strip())

    return [chunk for chunk in chunks if chunk.strip()]
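

# Example (illustrative): paragraphs are packed greedily up to max_chunk_size,
# and get_text_overlap() carries trailing context into the next chunk:
#
#   text = ("A" * 80) + "\n\n" + ("B" * 80)
#   create_semantic_chunks(text, max_chunk_size=100, overlap_size=20)
#   # -> two chunks; the second starts with 20 characters of "A" overlap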


def split_by_sentences(
    text: str, max_chunk_size: int, overlap_size: int
) -> TypingList[str]:
    """Split text by sentences when paragraph-level splitting isn't sufficient"""
    # Split by sentence endings
    sentences = re.split(r"(?<=[.!?])\s+", text)

    chunks = []
    current_chunk = ""

    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue

        # If adding this sentence would exceed chunk size
        if current_chunk and len(current_chunk) + len(sentence) + 1 > max_chunk_size:
            chunks.append(current_chunk.strip())

            # Start new chunk with overlap
            if overlap_size > 0:
                overlap_text = get_text_overlap(current_chunk, overlap_size)
                current_chunk = (
                    overlap_text + " " + sentence if overlap_text else sentence
                )
            else:
                current_chunk = sentence
        else:
            # Add sentence to current chunk
            if current_chunk:
                current_chunk += " " + sentence
            else:
                current_chunk = sentence

    # Add the last chunk
    if current_chunk:
        chunks.append(current_chunk.strip())

    return [chunk for chunk in chunks if chunk.strip()]


def get_text_overlap(text: str, overlap_size: int) -> str:
    """Get the last overlap_size characters from text, preferring word boundaries"""
    if not text or overlap_size <= 0:
        return ""

    if len(text) <= overlap_size:
        return text

    # Try to find a good word boundary within the overlap region
    overlap_text = text[-overlap_size:]

    # Find the first space to avoid cutting words
    space_index = overlap_text.find(" ")
    if space_index > 0:
        return overlap_text[space_index:].strip()

    return overlap_text.strip()
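

# Example (illustrative): the overlap snaps to a word boundary when one exists
# inside the window:
#
#   get_text_overlap("the quick brown fox", 9)  # window = "brown fox" -> "fox"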


##########################################
#
# API routes
#
##########################################


router = APIRouter()


class CollectionNameForm(BaseModel):
    collection_name: Optional[str] = None


class ProcessUrlForm(CollectionNameForm):
    url: str


class SearchForm(BaseModel):
    queries: List[str]


@router.get("/")
async def get_status(request: Request):
    return {
        "status": True,
        "chunk_size": request.app.state.config.CHUNK_SIZE,
        "chunk_overlap": request.app.state.config.CHUNK_OVERLAP,
        "template": request.app.state.config.RAG_TEMPLATE,
        "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
        "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
        "reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
        "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
    }
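

# Example (illustrative) response, assuming this router is mounted under
# /api/v1/retrieval; the values shown are hypothetical defaults:
#
#   GET /api/v1/retrieval/
#   {"status": true, "chunk_size": 1000, "chunk_overlap": 100, ...}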


@router.get("/embedding")
async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
    return {
        "status": True,
        "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
        "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
        "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
        "openai_config": {
            "url": request.app.state.config.RAG_OPENAI_API_BASE_URL,
            "key": request.app.state.config.RAG_OPENAI_API_KEY,
        },
        "ollama_config": {
            "url": request.app.state.config.RAG_OLLAMA_BASE_URL,
            "key": request.app.state.config.RAG_OLLAMA_API_KEY,
        },
        "azure_openai_config": {
            "url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL,
            "key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY,
            "version": request.app.state.config.RAG_AZURE_OPENAI_API_VERSION,
        },
    }


class OpenAIConfigForm(BaseModel):
    url: str
    key: str


class OllamaConfigForm(BaseModel):
    url: str
    key: str


class AzureOpenAIConfigForm(BaseModel):
    url: str
    key: str
    version: str


class EmbeddingModelUpdateForm(BaseModel):
    openai_config: Optional[OpenAIConfigForm] = None
    ollama_config: Optional[OllamaConfigForm] = None
    azure_openai_config: Optional[AzureOpenAIConfigForm] = None
    embedding_engine: str
    embedding_model: str
    embedding_batch_size: Optional[int] = 1
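

# Example (illustrative) payload for POST /embedding/update; only the config
# block matching the selected engine needs to be supplied, and the key/model
# values here are hypothetical:
#
#   {
#     "embedding_engine": "openai",
#     "embedding_model": "text-embedding-3-small",
#     "embedding_batch_size": 32,
#     "openai_config": {"url": "https://api.openai.com/v1", "key": "sk-..."}
#   }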


@router.post("/embedding/update")
async def update_embedding_config(
    request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    log.info(
        f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} "
        f"to {form_data.embedding_model}"
    )
    try:
        request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model

        if request.app.state.config.RAG_EMBEDDING_ENGINE in [
            "ollama",
            "openai",
            "azure_openai",
        ]:
            if form_data.openai_config is not None:
                request.app.state.config.RAG_OPENAI_API_BASE_URL = (
                    form_data.openai_config.url
                )
                request.app.state.config.RAG_OPENAI_API_KEY = (
                    form_data.openai_config.key
                )

            if form_data.ollama_config is not None:
                request.app.state.config.RAG_OLLAMA_BASE_URL = (
                    form_data.ollama_config.url
                )
                request.app.state.config.RAG_OLLAMA_API_KEY = (
                    form_data.ollama_config.key
                )

            if form_data.azure_openai_config is not None:
                request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
                    form_data.azure_openai_config.url
                )
                request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
                    form_data.azure_openai_config.key
                )
                request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
                    form_data.azure_openai_config.version
                )

            request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
                form_data.embedding_batch_size
            )

        request.app.state.ef = get_ef(
            request.app.state.config.RAG_EMBEDDING_ENGINE,
            request.app.state.config.RAG_EMBEDDING_MODEL,
        )

        request.app.state.EMBEDDING_FUNCTION = get_embedding_function(
            request.app.state.config.RAG_EMBEDDING_ENGINE,
            request.app.state.config.RAG_EMBEDDING_MODEL,
            request.app.state.ef,
            (
                request.app.state.config.RAG_OPENAI_API_BASE_URL
                if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
                else (
                    request.app.state.config.RAG_OLLAMA_BASE_URL
                    if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
                    else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL
                )
            ),
            (
                request.app.state.config.RAG_OPENAI_API_KEY
                if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
                else (
                    request.app.state.config.RAG_OLLAMA_API_KEY
                    if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
                    else request.app.state.config.RAG_AZURE_OPENAI_API_KEY
                )
            ),
            request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
            azure_api_version=(
                request.app.state.config.RAG_AZURE_OPENAI_API_VERSION
                if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
                else None
            ),
        )

        return {
            "status": True,
            "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
            "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
            "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
            "openai_config": {
                "url": request.app.state.config.RAG_OPENAI_API_BASE_URL,
                "key": request.app.state.config.RAG_OPENAI_API_KEY,
            },
            "ollama_config": {
                "url": request.app.state.config.RAG_OLLAMA_BASE_URL,
                "key": request.app.state.config.RAG_OLLAMA_API_KEY,
            },
            "azure_openai_config": {
                "url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL,
                "key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY,
                "version": request.app.state.config.RAG_AZURE_OPENAI_API_VERSION,
            },
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


@router.get("/config")
async def get_rag_config(request: Request, user=Depends(get_admin_user)):
    return {
        "status": True,
        # RAG settings
        "RAG_TEMPLATE": request.app.state.config.RAG_TEMPLATE,
        "TOP_K": request.app.state.config.TOP_K,
        "BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL,
        "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
        # Hybrid search settings
        "ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
        "TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER,
        "RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD,
        "HYBRID_BM25_WEIGHT": request.app.state.config.HYBRID_BM25_WEIGHT,
        # Content extraction settings
        "CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
        "PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES,
        "DATALAB_MARKER_API_KEY": request.app.state.config.DATALAB_MARKER_API_KEY,
        "DATALAB_MARKER_LANGS": request.app.state.config.DATALAB_MARKER_LANGS,
        "DATALAB_MARKER_SKIP_CACHE": request.app.state.config.DATALAB_MARKER_SKIP_CACHE,
        "DATALAB_MARKER_FORCE_OCR": request.app.state.config.DATALAB_MARKER_FORCE_OCR,
        "DATALAB_MARKER_PAGINATE": request.app.state.config.DATALAB_MARKER_PAGINATE,
        "DATALAB_MARKER_STRIP_EXISTING_OCR": request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR,
        "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION": request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION,
        "DATALAB_MARKER_USE_LLM": request.app.state.config.DATALAB_MARKER_USE_LLM,
        "DATALAB_MARKER_OUTPUT_FORMAT": request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT,
        "EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
        "EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
        "TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL,
        "DOCLING_SERVER_URL": request.app.state.config.DOCLING_SERVER_URL,
        "DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE,
        "DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG,
        "DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
        "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
        "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
        "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
        # Reranking settings
        "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL,
        "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE,
        "RAG_EXTERNAL_RERANKER_URL": request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
        "RAG_EXTERNAL_RERANKER_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
        # Chunking settings
        "TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER,
        "CHUNK_SIZE": request.app.state.config.CHUNK_SIZE,
        "CHUNK_OVERLAP": request.app.state.config.CHUNK_OVERLAP,
        # File upload settings
        "FILE_MAX_SIZE": request.app.state.config.FILE_MAX_SIZE,
        "FILE_MAX_COUNT": request.app.state.config.FILE_MAX_COUNT,
        "ALLOWED_FILE_EXTENSIONS": request.app.state.config.ALLOWED_FILE_EXTENSIONS,
        # Integration settings
        "ENABLE_GOOGLE_DRIVE_INTEGRATION": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
        "ENABLE_ONEDRIVE_INTEGRATION": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
        # Web search settings
        "web": {
            "ENABLE_WEB_SEARCH": request.app.state.config.ENABLE_WEB_SEARCH,
            "WEB_SEARCH_ENGINE": request.app.state.config.WEB_SEARCH_ENGINE,
            "WEB_SEARCH_TRUST_ENV": request.app.state.config.WEB_SEARCH_TRUST_ENV,
            "WEB_SEARCH_RESULT_COUNT": request.app.state.config.WEB_SEARCH_RESULT_COUNT,
            "WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
            "WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
            "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
            "BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER,
            "SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL,
            "YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL,
            "YACY_USERNAME": request.app.state.config.YACY_USERNAME,
            "YACY_PASSWORD": request.app.state.config.YACY_PASSWORD,
            "GOOGLE_PSE_API_KEY": request.app.state.config.GOOGLE_PSE_API_KEY,
            "GOOGLE_PSE_ENGINE_ID": request.app.state.config.GOOGLE_PSE_ENGINE_ID,
            "BRAVE_SEARCH_API_KEY": request.app.state.config.BRAVE_SEARCH_API_KEY,
            "KAGI_SEARCH_API_KEY": request.app.state.config.KAGI_SEARCH_API_KEY,
            "MOJEEK_SEARCH_API_KEY": request.app.state.config.MOJEEK_SEARCH_API_KEY,
            "BOCHA_SEARCH_API_KEY": request.app.state.config.BOCHA_SEARCH_API_KEY,
            "SERPSTACK_API_KEY": request.app.state.config.SERPSTACK_API_KEY,
            "SERPSTACK_HTTPS": request.app.state.config.SERPSTACK_HTTPS,
            "SERPER_API_KEY": request.app.state.config.SERPER_API_KEY,
            "SERPLY_API_KEY": request.app.state.config.SERPLY_API_KEY,
            "TAVILY_API_KEY": request.app.state.config.TAVILY_API_KEY,
            "SEARCHAPI_API_KEY": request.app.state.config.SEARCHAPI_API_KEY,
            "SEARCHAPI_ENGINE": request.app.state.config.SEARCHAPI_ENGINE,
            "SERPAPI_API_KEY": request.app.state.config.SERPAPI_API_KEY,
            "SERPAPI_ENGINE": request.app.state.config.SERPAPI_ENGINE,
            "JINA_API_KEY": request.app.state.config.JINA_API_KEY,
            "BING_SEARCH_V7_ENDPOINT": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
            "BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
            "EXA_API_KEY": request.app.state.config.EXA_API_KEY,
            "PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY,
            "SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID,
            "SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK,
            "WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE,
            "ENABLE_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
            "PLAYWRIGHT_WS_URL": request.app.state.config.PLAYWRIGHT_WS_URL,
            "PLAYWRIGHT_TIMEOUT": request.app.state.config.PLAYWRIGHT_TIMEOUT,
            "FIRECRAWL_API_KEY": request.app.state.config.FIRECRAWL_API_KEY,
            "FIRECRAWL_API_BASE_URL": request.app.state.config.FIRECRAWL_API_BASE_URL,
            "TAVILY_EXTRACT_DEPTH": request.app.state.config.TAVILY_EXTRACT_DEPTH,
            "EXTERNAL_WEB_SEARCH_URL": request.app.state.config.EXTERNAL_WEB_SEARCH_URL,
            "EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY,
            "EXTERNAL_WEB_LOADER_URL": request.app.state.config.EXTERNAL_WEB_LOADER_URL,
            "EXTERNAL_WEB_LOADER_API_KEY": request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY,
            "YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
            "YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
            "YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION,
        },
    }


class WebConfig(BaseModel):
    ENABLE_WEB_SEARCH: Optional[bool] = None
    WEB_SEARCH_ENGINE: Optional[str] = None
    WEB_SEARCH_TRUST_ENV: Optional[bool] = None
    WEB_SEARCH_RESULT_COUNT: Optional[int] = None
    WEB_SEARCH_CONCURRENT_REQUESTS: Optional[int] = None
    WEB_SEARCH_DOMAIN_FILTER_LIST: Optional[List[str]] = []
    BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
    BYPASS_WEB_SEARCH_WEB_LOADER: Optional[bool] = None
    SEARXNG_QUERY_URL: Optional[str] = None
    YACY_QUERY_URL: Optional[str] = None
    YACY_USERNAME: Optional[str] = None
    YACY_PASSWORD: Optional[str] = None
    GOOGLE_PSE_API_KEY: Optional[str] = None
    GOOGLE_PSE_ENGINE_ID: Optional[str] = None
    BRAVE_SEARCH_API_KEY: Optional[str] = None
    KAGI_SEARCH_API_KEY: Optional[str] = None
    MOJEEK_SEARCH_API_KEY: Optional[str] = None
    BOCHA_SEARCH_API_KEY: Optional[str] = None
    SERPSTACK_API_KEY: Optional[str] = None
    SERPSTACK_HTTPS: Optional[bool] = None
    SERPER_API_KEY: Optional[str] = None
    SERPLY_API_KEY: Optional[str] = None
    TAVILY_API_KEY: Optional[str] = None
    SEARCHAPI_API_KEY: Optional[str] = None
    SEARCHAPI_ENGINE: Optional[str] = None
    SERPAPI_API_KEY: Optional[str] = None
    SERPAPI_ENGINE: Optional[str] = None
    JINA_API_KEY: Optional[str] = None
    BING_SEARCH_V7_ENDPOINT: Optional[str] = None
    BING_SEARCH_V7_SUBSCRIPTION_KEY: Optional[str] = None
    EXA_API_KEY: Optional[str] = None
    PERPLEXITY_API_KEY: Optional[str] = None
    SOUGOU_API_SID: Optional[str] = None
    SOUGOU_API_SK: Optional[str] = None
    WEB_LOADER_ENGINE: Optional[str] = None
    ENABLE_WEB_LOADER_SSL_VERIFICATION: Optional[bool] = None
    PLAYWRIGHT_WS_URL: Optional[str] = None
    PLAYWRIGHT_TIMEOUT: Optional[int] = None
    FIRECRAWL_API_KEY: Optional[str] = None
    FIRECRAWL_API_BASE_URL: Optional[str] = None
    TAVILY_EXTRACT_DEPTH: Optional[str] = None
    EXTERNAL_WEB_SEARCH_URL: Optional[str] = None
    EXTERNAL_WEB_SEARCH_API_KEY: Optional[str] = None
    EXTERNAL_WEB_LOADER_URL: Optional[str] = None
    EXTERNAL_WEB_LOADER_API_KEY: Optional[str] = None
    YOUTUBE_LOADER_LANGUAGE: Optional[List[str]] = None
    YOUTUBE_LOADER_PROXY_URL: Optional[str] = None
    YOUTUBE_LOADER_TRANSLATION: Optional[str] = None


class ConfigForm(BaseModel):
    # RAG settings
    RAG_TEMPLATE: Optional[str] = None
    TOP_K: Optional[int] = None
    BYPASS_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
    RAG_FULL_CONTEXT: Optional[bool] = None

    # Hybrid search settings
    ENABLE_RAG_HYBRID_SEARCH: Optional[bool] = None
    TOP_K_RERANKER: Optional[int] = None
    RELEVANCE_THRESHOLD: Optional[float] = None
    HYBRID_BM25_WEIGHT: Optional[float] = None

    # Content extraction settings
    CONTENT_EXTRACTION_ENGINE: Optional[str] = None
    PDF_EXTRACT_IMAGES: Optional[bool] = None
    DATALAB_MARKER_API_KEY: Optional[str] = None
    DATALAB_MARKER_LANGS: Optional[str] = None
    DATALAB_MARKER_SKIP_CACHE: Optional[bool] = None
    DATALAB_MARKER_FORCE_OCR: Optional[bool] = None
    DATALAB_MARKER_PAGINATE: Optional[bool] = None
    DATALAB_MARKER_STRIP_EXISTING_OCR: Optional[bool] = None
    DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION: Optional[bool] = None
    DATALAB_MARKER_USE_LLM: Optional[bool] = None
    DATALAB_MARKER_OUTPUT_FORMAT: Optional[str] = None
    EXTERNAL_DOCUMENT_LOADER_URL: Optional[str] = None
    EXTERNAL_DOCUMENT_LOADER_API_KEY: Optional[str] = None

    TIKA_SERVER_URL: Optional[str] = None
    DOCLING_SERVER_URL: Optional[str] = None
    DOCLING_OCR_ENGINE: Optional[str] = None
    DOCLING_OCR_LANG: Optional[str] = None
    DOCLING_DO_PICTURE_DESCRIPTION: Optional[bool] = None
    DOCUMENT_INTELLIGENCE_ENDPOINT: Optional[str] = None
    DOCUMENT_INTELLIGENCE_KEY: Optional[str] = None
    MISTRAL_OCR_API_KEY: Optional[str] = None

    # Reranking settings
    RAG_RERANKING_MODEL: Optional[str] = None
    RAG_RERANKING_ENGINE: Optional[str] = None
    RAG_EXTERNAL_RERANKER_URL: Optional[str] = None
    RAG_EXTERNAL_RERANKER_API_KEY: Optional[str] = None

    # Chunking settings
    TEXT_SPLITTER: Optional[str] = None
    CHUNK_SIZE: Optional[int] = None
    CHUNK_OVERLAP: Optional[int] = None

    # File upload settings
    FILE_MAX_SIZE: Optional[int] = None
    FILE_MAX_COUNT: Optional[int] = None
    ALLOWED_FILE_EXTENSIONS: Optional[List[str]] = None

    # Integration settings
    ENABLE_GOOGLE_DRIVE_INTEGRATION: Optional[bool] = None
    ENABLE_ONEDRIVE_INTEGRATION: Optional[bool] = None

    # Web search settings
    web: Optional[WebConfig] = None


@router.post("/config/update")
async def update_rag_config(
    request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
):
    # RAG settings
    request.app.state.config.RAG_TEMPLATE = (
        form_data.RAG_TEMPLATE
        if form_data.RAG_TEMPLATE is not None
        else request.app.state.config.RAG_TEMPLATE
    )
    request.app.state.config.TOP_K = (
        form_data.TOP_K
        if form_data.TOP_K is not None
        else request.app.state.config.TOP_K
    )
    request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = (
        form_data.BYPASS_EMBEDDING_AND_RETRIEVAL
        if form_data.BYPASS_EMBEDDING_AND_RETRIEVAL is not None
        else request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
    )
    request.app.state.config.RAG_FULL_CONTEXT = (
        form_data.RAG_FULL_CONTEXT
        if form_data.RAG_FULL_CONTEXT is not None
        else request.app.state.config.RAG_FULL_CONTEXT
    )

    # Hybrid search settings
    request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
        form_data.ENABLE_RAG_HYBRID_SEARCH
        if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
        else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
    )
    # Free up memory if hybrid search is disabled
    if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
        request.app.state.rf = None

    request.app.state.config.TOP_K_RERANKER = (
        form_data.TOP_K_RERANKER
        if form_data.TOP_K_RERANKER is not None
        else request.app.state.config.TOP_K_RERANKER
    )
    request.app.state.config.RELEVANCE_THRESHOLD = (
        form_data.RELEVANCE_THRESHOLD
        if form_data.RELEVANCE_THRESHOLD is not None
        else request.app.state.config.RELEVANCE_THRESHOLD
    )
    request.app.state.config.HYBRID_BM25_WEIGHT = (
        form_data.HYBRID_BM25_WEIGHT
        if form_data.HYBRID_BM25_WEIGHT is not None
        else request.app.state.config.HYBRID_BM25_WEIGHT
    )

    # Content extraction settings
    request.app.state.config.CONTENT_EXTRACTION_ENGINE = (
        form_data.CONTENT_EXTRACTION_ENGINE
        if form_data.CONTENT_EXTRACTION_ENGINE is not None
        else request.app.state.config.CONTENT_EXTRACTION_ENGINE
    )
    request.app.state.config.PDF_EXTRACT_IMAGES = (
        form_data.PDF_EXTRACT_IMAGES
        if form_data.PDF_EXTRACT_IMAGES is not None
        else request.app.state.config.PDF_EXTRACT_IMAGES
    )
    request.app.state.config.DATALAB_MARKER_API_KEY = (
        form_data.DATALAB_MARKER_API_KEY
        if form_data.DATALAB_MARKER_API_KEY is not None
        else request.app.state.config.DATALAB_MARKER_API_KEY
    )
    request.app.state.config.DATALAB_MARKER_LANGS = (
        form_data.DATALAB_MARKER_LANGS
        if form_data.DATALAB_MARKER_LANGS is not None
        else request.app.state.config.DATALAB_MARKER_LANGS
    )
    request.app.state.config.DATALAB_MARKER_SKIP_CACHE = (
        form_data.DATALAB_MARKER_SKIP_CACHE
        if form_data.DATALAB_MARKER_SKIP_CACHE is not None
        else request.app.state.config.DATALAB_MARKER_SKIP_CACHE
    )
    request.app.state.config.DATALAB_MARKER_FORCE_OCR = (
        form_data.DATALAB_MARKER_FORCE_OCR
        if form_data.DATALAB_MARKER_FORCE_OCR is not None
        else request.app.state.config.DATALAB_MARKER_FORCE_OCR
    )
    request.app.state.config.DATALAB_MARKER_PAGINATE = (
        form_data.DATALAB_MARKER_PAGINATE
        if form_data.DATALAB_MARKER_PAGINATE is not None
        else request.app.state.config.DATALAB_MARKER_PAGINATE
    )
    request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR = (
        form_data.DATALAB_MARKER_STRIP_EXISTING_OCR
        if form_data.DATALAB_MARKER_STRIP_EXISTING_OCR is not None
        else request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR
    )
    request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = (
        form_data.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION
        if form_data.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION is not None
        else request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION
    )
    request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT = (
        form_data.DATALAB_MARKER_OUTPUT_FORMAT
        if form_data.DATALAB_MARKER_OUTPUT_FORMAT is not None
        else request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT
    )
    request.app.state.config.DATALAB_MARKER_USE_LLM = (
        form_data.DATALAB_MARKER_USE_LLM
        if form_data.DATALAB_MARKER_USE_LLM is not None
        else request.app.state.config.DATALAB_MARKER_USE_LLM
    )
    request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL = (
        form_data.EXTERNAL_DOCUMENT_LOADER_URL
        if form_data.EXTERNAL_DOCUMENT_LOADER_URL is not None
        else request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL
    )
    request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY = (
        form_data.EXTERNAL_DOCUMENT_LOADER_API_KEY
        if form_data.EXTERNAL_DOCUMENT_LOADER_API_KEY is not None
        else request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY
    )
    request.app.state.config.TIKA_SERVER_URL = (
        form_data.TIKA_SERVER_URL
        if form_data.TIKA_SERVER_URL is not None
        else request.app.state.config.TIKA_SERVER_URL
    )
    request.app.state.config.DOCLING_SERVER_URL = (
        form_data.DOCLING_SERVER_URL
        if form_data.DOCLING_SERVER_URL is not None
        else request.app.state.config.DOCLING_SERVER_URL
    )
    request.app.state.config.DOCLING_OCR_ENGINE = (
        form_data.DOCLING_OCR_ENGINE
        if form_data.DOCLING_OCR_ENGINE is not None
        else request.app.state.config.DOCLING_OCR_ENGINE
    )
    request.app.state.config.DOCLING_OCR_LANG = (
        form_data.DOCLING_OCR_LANG
        if form_data.DOCLING_OCR_LANG is not None
        else request.app.state.config.DOCLING_OCR_LANG
    )

    request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION = (
        form_data.DOCLING_DO_PICTURE_DESCRIPTION
        if form_data.DOCLING_DO_PICTURE_DESCRIPTION is not None
        else request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION
    )

    request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = (
        form_data.DOCUMENT_INTELLIGENCE_ENDPOINT
        if form_data.DOCUMENT_INTELLIGENCE_ENDPOINT is not None
        else request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT
    )
    request.app.state.config.DOCUMENT_INTELLIGENCE_KEY = (
        form_data.DOCUMENT_INTELLIGENCE_KEY
        if form_data.DOCUMENT_INTELLIGENCE_KEY is not None
        else request.app.state.config.DOCUMENT_INTELLIGENCE_KEY
    )
    request.app.state.config.MISTRAL_OCR_API_KEY = (
        form_data.MISTRAL_OCR_API_KEY
        if form_data.MISTRAL_OCR_API_KEY is not None
        else request.app.state.config.MISTRAL_OCR_API_KEY
    )

    # Reranking settings
    request.app.state.config.RAG_RERANKING_ENGINE = (
        form_data.RAG_RERANKING_ENGINE
        if form_data.RAG_RERANKING_ENGINE is not None
        else request.app.state.config.RAG_RERANKING_ENGINE
    )

    request.app.state.config.RAG_EXTERNAL_RERANKER_URL = (
        form_data.RAG_EXTERNAL_RERANKER_URL
        if form_data.RAG_EXTERNAL_RERANKER_URL is not None
        else request.app.state.config.RAG_EXTERNAL_RERANKER_URL
    )

    request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY = (
        form_data.RAG_EXTERNAL_RERANKER_API_KEY
        if form_data.RAG_EXTERNAL_RERANKER_API_KEY is not None
        else request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY
    )

    log.info(
        f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}"
    )
    try:
        request.app.state.config.RAG_RERANKING_MODEL = form_data.RAG_RERANKING_MODEL

        try:
            request.app.state.rf = get_rf(
                request.app.state.config.RAG_RERANKING_ENGINE,
                request.app.state.config.RAG_RERANKING_MODEL,
                request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
                request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
                True,
            )
        except Exception as e:
            log.error(f"Error loading reranking model: {e}")
            request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

    # Chunking settings
    request.app.state.config.TEXT_SPLITTER = (
        form_data.TEXT_SPLITTER
        if form_data.TEXT_SPLITTER is not None
        else request.app.state.config.TEXT_SPLITTER
    )
    request.app.state.config.CHUNK_SIZE = (
        form_data.CHUNK_SIZE
        if form_data.CHUNK_SIZE is not None
        else request.app.state.config.CHUNK_SIZE
    )
    request.app.state.config.CHUNK_OVERLAP = (
        form_data.CHUNK_OVERLAP
        if form_data.CHUNK_OVERLAP is not None
        else request.app.state.config.CHUNK_OVERLAP
    )

    # File upload settings
    request.app.state.config.FILE_MAX_SIZE = (
        form_data.FILE_MAX_SIZE
        if form_data.FILE_MAX_SIZE is not None
        else request.app.state.config.FILE_MAX_SIZE
    )
    request.app.state.config.FILE_MAX_COUNT = (
        form_data.FILE_MAX_COUNT
        if form_data.FILE_MAX_COUNT is not None
        else request.app.state.config.FILE_MAX_COUNT
    )
    request.app.state.config.ALLOWED_FILE_EXTENSIONS = (
        form_data.ALLOWED_FILE_EXTENSIONS
        if form_data.ALLOWED_FILE_EXTENSIONS is not None
        else request.app.state.config.ALLOWED_FILE_EXTENSIONS
    )

    # Integration settings
    request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = (
        form_data.ENABLE_GOOGLE_DRIVE_INTEGRATION
        if form_data.ENABLE_GOOGLE_DRIVE_INTEGRATION is not None
        else request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION
    )
    request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION = (
        form_data.ENABLE_ONEDRIVE_INTEGRATION
        if form_data.ENABLE_ONEDRIVE_INTEGRATION is not None
        else request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION
    )

    if form_data.web is not None:
        # Web search settings
        request.app.state.config.ENABLE_WEB_SEARCH = form_data.web.ENABLE_WEB_SEARCH
        request.app.state.config.WEB_SEARCH_ENGINE = form_data.web.WEB_SEARCH_ENGINE
        request.app.state.config.WEB_SEARCH_TRUST_ENV = (
            form_data.web.WEB_SEARCH_TRUST_ENV
        )
        request.app.state.config.WEB_SEARCH_RESULT_COUNT = (
            form_data.web.WEB_SEARCH_RESULT_COUNT
        )
        request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS = (
            form_data.web.WEB_SEARCH_CONCURRENT_REQUESTS
        )
        request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST = (
            form_data.web.WEB_SEARCH_DOMAIN_FILTER_LIST
        )
        request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
            form_data.web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
        )
        request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = (
            form_data.web.BYPASS_WEB_SEARCH_WEB_LOADER
        )
        request.app.state.config.SEARXNG_QUERY_URL = form_data.web.SEARXNG_QUERY_URL
        request.app.state.config.YACY_QUERY_URL = form_data.web.YACY_QUERY_URL
        request.app.state.config.YACY_USERNAME = form_data.web.YACY_USERNAME
        request.app.state.config.YACY_PASSWORD = form_data.web.YACY_PASSWORD
        request.app.state.config.GOOGLE_PSE_API_KEY = form_data.web.GOOGLE_PSE_API_KEY
        request.app.state.config.GOOGLE_PSE_ENGINE_ID = (
            form_data.web.GOOGLE_PSE_ENGINE_ID
        )
        request.app.state.config.BRAVE_SEARCH_API_KEY = (
            form_data.web.BRAVE_SEARCH_API_KEY
        )
        request.app.state.config.KAGI_SEARCH_API_KEY = form_data.web.KAGI_SEARCH_API_KEY
        request.app.state.config.MOJEEK_SEARCH_API_KEY = (
            form_data.web.MOJEEK_SEARCH_API_KEY
        )
        request.app.state.config.BOCHA_SEARCH_API_KEY = (
            form_data.web.BOCHA_SEARCH_API_KEY
        )
        request.app.state.config.SERPSTACK_API_KEY = form_data.web.SERPSTACK_API_KEY
        request.app.state.config.SERPSTACK_HTTPS = form_data.web.SERPSTACK_HTTPS
        request.app.state.config.SERPER_API_KEY = form_data.web.SERPER_API_KEY
        request.app.state.config.SERPLY_API_KEY = form_data.web.SERPLY_API_KEY
        request.app.state.config.TAVILY_API_KEY = form_data.web.TAVILY_API_KEY
        request.app.state.config.SEARCHAPI_API_KEY = form_data.web.SEARCHAPI_API_KEY
        request.app.state.config.SEARCHAPI_ENGINE = form_data.web.SEARCHAPI_ENGINE
        request.app.state.config.SERPAPI_API_KEY = form_data.web.SERPAPI_API_KEY
        request.app.state.config.SERPAPI_ENGINE = form_data.web.SERPAPI_ENGINE
        request.app.state.config.JINA_API_KEY = form_data.web.JINA_API_KEY
        request.app.state.config.BING_SEARCH_V7_ENDPOINT = (
            form_data.web.BING_SEARCH_V7_ENDPOINT
        )
        request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = (
            form_data.web.BING_SEARCH_V7_SUBSCRIPTION_KEY
        )
        request.app.state.config.EXA_API_KEY = form_data.web.EXA_API_KEY
        request.app.state.config.PERPLEXITY_API_KEY = form_data.web.PERPLEXITY_API_KEY
        request.app.state.config.SOUGOU_API_SID = form_data.web.SOUGOU_API_SID
        request.app.state.config.SOUGOU_API_SK = form_data.web.SOUGOU_API_SK

        # Web loader settings
        request.app.state.config.WEB_LOADER_ENGINE = form_data.web.WEB_LOADER_ENGINE
        request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = (
            form_data.web.ENABLE_WEB_LOADER_SSL_VERIFICATION
        )
        request.app.state.config.PLAYWRIGHT_WS_URL = form_data.web.PLAYWRIGHT_WS_URL
        request.app.state.config.PLAYWRIGHT_TIMEOUT = form_data.web.PLAYWRIGHT_TIMEOUT
        request.app.state.config.FIRECRAWL_API_KEY = form_data.web.FIRECRAWL_API_KEY
        request.app.state.config.FIRECRAWL_API_BASE_URL = (
            form_data.web.FIRECRAWL_API_BASE_URL
        )
        request.app.state.config.EXTERNAL_WEB_SEARCH_URL = (
            form_data.web.EXTERNAL_WEB_SEARCH_URL
        )
        request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY = (
            form_data.web.EXTERNAL_WEB_SEARCH_API_KEY
        )
        request.app.state.config.EXTERNAL_WEB_LOADER_URL = (
            form_data.web.EXTERNAL_WEB_LOADER_URL
        )
        request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY = (
            form_data.web.EXTERNAL_WEB_LOADER_API_KEY
        )
        request.app.state.config.TAVILY_EXTRACT_DEPTH = (
            form_data.web.TAVILY_EXTRACT_DEPTH
        )
        request.app.state.config.YOUTUBE_LOADER_LANGUAGE = (
            form_data.web.YOUTUBE_LOADER_LANGUAGE
        )
        request.app.state.config.YOUTUBE_LOADER_PROXY_URL = (
            form_data.web.YOUTUBE_LOADER_PROXY_URL
        )
        request.app.state.YOUTUBE_LOADER_TRANSLATION = (
            form_data.web.YOUTUBE_LOADER_TRANSLATION
        )

    return {
        "status": True,
        # RAG settings
        "RAG_TEMPLATE": request.app.state.config.RAG_TEMPLATE,
        "TOP_K": request.app.state.config.TOP_K,
        "BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL,
        "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
        # Hybrid search settings
        "ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
        "TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER,
        "RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD,
        "HYBRID_BM25_WEIGHT": request.app.state.config.HYBRID_BM25_WEIGHT,
        # Content extraction settings
        "CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
        "PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES,
        "DATALAB_MARKER_API_KEY": request.app.state.config.DATALAB_MARKER_API_KEY,
        "DATALAB_MARKER_LANGS": request.app.state.config.DATALAB_MARKER_LANGS,
        "DATALAB_MARKER_SKIP_CACHE": request.app.state.config.DATALAB_MARKER_SKIP_CACHE,
        "DATALAB_MARKER_FORCE_OCR": request.app.state.config.DATALAB_MARKER_FORCE_OCR,
        "DATALAB_MARKER_PAGINATE": request.app.state.config.DATALAB_MARKER_PAGINATE,
        "DATALAB_MARKER_STRIP_EXISTING_OCR": request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR,
        "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION": request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION,
        "DATALAB_MARKER_USE_LLM": request.app.state.config.DATALAB_MARKER_USE_LLM,
        "DATALAB_MARKER_OUTPUT_FORMAT": request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT,
        "EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
        "EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
        "TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL,
        "DOCLING_SERVER_URL": request.app.state.config.DOCLING_SERVER_URL,
        "DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE,
        "DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG,
        "DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
        "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
        "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
        "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
        # Reranking settings
        "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL,
        "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE,
        "RAG_EXTERNAL_RERANKER_URL": request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
        "RAG_EXTERNAL_RERANKER_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
        # Chunking settings
        "TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER,
        "CHUNK_SIZE": request.app.state.config.CHUNK_SIZE,
        "CHUNK_OVERLAP": request.app.state.config.CHUNK_OVERLAP,
        # File upload settings
        "FILE_MAX_SIZE": request.app.state.config.FILE_MAX_SIZE,
        "FILE_MAX_COUNT": request.app.state.config.FILE_MAX_COUNT,
        "ALLOWED_FILE_EXTENSIONS": request.app.state.config.ALLOWED_FILE_EXTENSIONS,
        # Integration settings
        "ENABLE_GOOGLE_DRIVE_INTEGRATION": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
        "ENABLE_ONEDRIVE_INTEGRATION": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
        # Web search settings
        "web": {
            "ENABLE_WEB_SEARCH": request.app.state.config.ENABLE_WEB_SEARCH,
            "WEB_SEARCH_ENGINE": request.app.state.config.WEB_SEARCH_ENGINE,
            "WEB_SEARCH_TRUST_ENV": request.app.state.config.WEB_SEARCH_TRUST_ENV,
            "WEB_SEARCH_RESULT_COUNT": request.app.state.config.WEB_SEARCH_RESULT_COUNT,
            "WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
            "WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
            "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
            "BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER,
            "SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL,
            "YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL,
            "YACY_USERNAME": request.app.state.config.YACY_USERNAME,
            "YACY_PASSWORD": request.app.state.config.YACY_PASSWORD,
            "GOOGLE_PSE_API_KEY": request.app.state.config.GOOGLE_PSE_API_KEY,
            "GOOGLE_PSE_ENGINE_ID": request.app.state.config.GOOGLE_PSE_ENGINE_ID,
            "BRAVE_SEARCH_API_KEY": request.app.state.config.BRAVE_SEARCH_API_KEY,
            "KAGI_SEARCH_API_KEY": request.app.state.config.KAGI_SEARCH_API_KEY,
            "MOJEEK_SEARCH_API_KEY": request.app.state.config.MOJEEK_SEARCH_API_KEY,
            "BOCHA_SEARCH_API_KEY": request.app.state.config.BOCHA_SEARCH_API_KEY,
            "SERPSTACK_API_KEY": request.app.state.config.SERPSTACK_API_KEY,
            "SERPSTACK_HTTPS": request.app.state.config.SERPSTACK_HTTPS,
            "SERPER_API_KEY": request.app.state.config.SERPER_API_KEY,
            "SERPLY_API_KEY": request.app.state.config.SERPLY_API_KEY,
            "TAVILY_API_KEY": request.app.state.config.TAVILY_API_KEY,
            "SEARCHAPI_API_KEY": request.app.state.config.SEARCHAPI_API_KEY,
            "SEARCHAPI_ENGINE": request.app.state.config.SEARCHAPI_ENGINE,
            "SERPAPI_API_KEY": request.app.state.config.SERPAPI_API_KEY,
            "SERPAPI_ENGINE": request.app.state.config.SERPAPI_ENGINE,
            "JINA_API_KEY": request.app.state.config.JINA_API_KEY,
            "BING_SEARCH_V7_ENDPOINT": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
            "BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
            "EXA_API_KEY": request.app.state.config.EXA_API_KEY,
            "PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY,
            "SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID,
            "SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK,
            "WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE,
            "ENABLE_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
            "PLAYWRIGHT_WS_URL": request.app.state.config.PLAYWRIGHT_WS_URL,
            "PLAYWRIGHT_TIMEOUT": request.app.state.config.PLAYWRIGHT_TIMEOUT,
            "FIRECRAWL_API_KEY": request.app.state.config.FIRECRAWL_API_KEY,
            "FIRECRAWL_API_BASE_URL": request.app.state.config.FIRECRAWL_API_BASE_URL,
            "TAVILY_EXTRACT_DEPTH": request.app.state.config.TAVILY_EXTRACT_DEPTH,
            "EXTERNAL_WEB_SEARCH_URL": request.app.state.config.EXTERNAL_WEB_SEARCH_URL,
            "EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY,
            "EXTERNAL_WEB_LOADER_URL": request.app.state.config.EXTERNAL_WEB_LOADER_URL,
            "EXTERNAL_WEB_LOADER_API_KEY": request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY,
            "YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
            "YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
            "YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION,
        },
    }
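

# NOTE: update_rag_config treats omitted/None fields as "keep the current
# value", so partial updates are valid. Illustrative payload:
#
#   POST /config/update
#   {"CHUNK_SIZE": 1500, "CHUNK_OVERLAP": 150}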
|
||
|
||
|
||
####################################
|
||
#
|
||
# Document process and retrieval
|
||
#
|
||
####################################
|
||
|
||
|
||
def save_docs_to_vector_db(
    request: Request,
    docs,
    collection_name,
    metadata: Optional[dict] = None,
    overwrite: bool = False,
    split: bool = True,
    add: bool = False,
    user=None,
) -> bool:
    def _get_docs_info(docs: list[Document]) -> str:
        docs_info = set()

        # Trying to select relevant metadata identifying the document.
        for doc in docs:
            metadata = getattr(doc, "metadata", {})
            doc_name = metadata.get("name", "")
            if not doc_name:
                doc_name = metadata.get("title", "")
            if not doc_name:
                doc_name = metadata.get("source", "")
            if doc_name:
                docs_info.add(doc_name)

        return ", ".join(docs_info)

    log.info(
        f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}"
    )

    # Check if entries with the same hash (metadata.hash) already exist
    if metadata and "hash" in metadata:
        result = VECTOR_DB_CLIENT.query(
            collection_name=collection_name,
            filter={"hash": metadata["hash"]},
        )

        if result is not None:
            existing_doc_ids = result.ids[0]
            if existing_doc_ids:
                log.info(f"Document with hash {metadata['hash']} already exists")
                raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)

    if split:
        # Apply advanced content-aware splitting and text cleaning
        processed_docs = []

        for doc in docs:
            # Skip documents with no content
            if not doc.page_content:
                continue

            # Apply text cleaning before chunking using the modular cleaning system
            cleaned_content = TextCleaner.clean_for_chunking(doc.page_content)

            # Create semantic chunks from cleaned content
            chunks = create_semantic_chunks(
                cleaned_content,
                request.app.state.config.CHUNK_SIZE,
                request.app.state.config.CHUNK_OVERLAP,
            )

            # Create new documents for each chunk
            for i, chunk in enumerate(chunks):
                chunk_metadata = {
                    **doc.metadata,
                    "chunk_index": i,
                    "total_chunks": len(chunks),
                }
                processed_docs.append(
                    Document(page_content=chunk, metadata=chunk_metadata)
                )

        docs = processed_docs

    if len(docs) == 0:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    texts = [doc.page_content for doc in docs]
    metadatas = [
        {
            **doc.metadata,
            **(metadata if metadata else {}),
            "embedding_config": json.dumps(
                {
                    "engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
                    "model": request.app.state.config.RAG_EMBEDDING_MODEL,
                }
            ),
        }
        for doc in docs
    ]

    # ChromaDB does not like datetime formats
    # for metadata so convert them to string.
    for metadata in metadatas:
        for key, value in metadata.items():
            if (
                isinstance(value, datetime)
                or isinstance(value, list)
                or isinstance(value, dict)
            ):
                metadata[key] = str(value)

    try:
        if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
            log.info(f"collection {collection_name} already exists")

            if overwrite:
                VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
                log.info(f"deleting existing collection {collection_name}")
            elif add is False:
                log.info(
                    f"collection {collection_name} already exists, overwrite is False and add is False"
                )
                return True

        log.info(f"adding to collection {collection_name}")
        embedding_function = get_embedding_function(
            request.app.state.config.RAG_EMBEDDING_ENGINE,
            request.app.state.config.RAG_EMBEDDING_MODEL,
            request.app.state.ef,
            (
                request.app.state.config.RAG_OPENAI_API_BASE_URL
                if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
                else (
                    request.app.state.config.RAG_OLLAMA_BASE_URL
                    if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
                    else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL
                )
            ),
            (
                request.app.state.config.RAG_OPENAI_API_KEY
                if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
                else (
                    request.app.state.config.RAG_OLLAMA_API_KEY
                    if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
                    else request.app.state.config.RAG_AZURE_OPENAI_API_KEY
                )
            ),
            request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
            azure_api_version=(
                request.app.state.config.RAG_AZURE_OPENAI_API_VERSION
                if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
                else None
            ),
        )

        # Prepare texts for embedding using the modular cleaning system
        cleaned_texts = [TextCleaner.clean_for_embedding(text) for text in texts]

        embeddings = embedding_function(
            cleaned_texts,
            prefix=RAG_EMBEDDING_CONTENT_PREFIX,
            user=user,
        )

        # Store the cleaned text using the modular cleaning system
        items = []
        for idx in range(len(texts)):
            # Apply consistent storage-level cleaning
            text_to_store = TextCleaner.clean_for_storage(texts[idx])

            items.append(
                {
                    "id": str(uuid.uuid4()),
                    "text": text_to_store,
                    "vector": embeddings[idx],
                    "metadata": metadatas[idx],
                }
            )

        VECTOR_DB_CLIENT.insert(
            collection_name=collection_name,
            items=items,
        )

        return True
    except Exception as e:
        log.exception(e)
        raise e


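# Illustration only (not executed): a typical call into save_docs_to_vector_db,
# assuming `request` is the FastAPI request and `docs` are langchain Documents.
# With the default split=True, each document is cleaned, chunked with
# create_semantic_chunks, embedded, and inserted into the collection:
#
#   save_docs_to_vector_db(
#       request,
#       docs=[Document(page_content="hello world", metadata={"name": "demo.txt"})],
#       collection_name="file-<file_id>",
#       metadata={"file_id": "<file_id>", "hash": "<sha256>"},
#       overwrite=True,
#   )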
class ProcessFileForm(BaseModel):
    file_id: str
    content: Optional[str] = None
    collection_name: Optional[str] = None


@router.post("/process/file")
def process_file(
    request: Request,
    form_data: ProcessFileForm,
    user=Depends(get_verified_user),
):
    try:
        file = Files.get_file_by_id(form_data.file_id)

        collection_name = form_data.collection_name

        if collection_name is None:
            collection_name = f"file-{file.id}"

        if form_data.content:
            # Update the content in the file
            # Usage: /files/{file_id}/data/content/update, /files/ (audio file upload pipeline)

            try:
                # /files/{file_id}/data/content/update
                VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
            except Exception:
                # Audio file upload pipeline
                pass

            docs = [
                Document(
                    page_content=TextCleaner.clean_for_chunking(
                        form_data.content.replace("<br/>", "\n")
                    ),
                    metadata={
                        **file.meta,
                        "name": file.filename,
                        "created_by": file.user_id,
                        "file_id": file.id,
                        "source": file.filename,
                    },
                )
            ]

            text_content = form_data.content
        elif form_data.collection_name:
            # Check if the file has already been processed and save the content
            # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update

            result = VECTOR_DB_CLIENT.query(
                collection_name=f"file-{file.id}", filter={"file_id": file.id}
            )

            if result is not None and len(result.ids[0]) > 0:
                docs = [
                    Document(
                        page_content=TextCleaner.clean_for_chunking(
                            result.documents[0][idx]
                        ),
                        metadata=result.metadatas[0][idx],
                    )
                    for idx, id in enumerate(result.ids[0])
                ]
            else:
                docs = [
                    Document(
                        page_content=TextCleaner.clean_for_chunking(
                            file.data.get("content", "")
                        ),
                        metadata={
                            **file.meta,
                            "name": file.filename,
                            "created_by": file.user_id,
                            "file_id": file.id,
                            "source": file.filename,
                        },
                    )
                ]

            text_content = file.data.get("content", "")
        else:
            # Process the file and save the content
            # Usage: /files/
            file_path = file.path
            if file_path:
                file_path = Storage.get_file(file_path)
                loader = Loader(
                    engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
                    DATALAB_MARKER_API_KEY=request.app.state.config.DATALAB_MARKER_API_KEY,
                    DATALAB_MARKER_LANGS=request.app.state.config.DATALAB_MARKER_LANGS,
                    DATALAB_MARKER_SKIP_CACHE=request.app.state.config.DATALAB_MARKER_SKIP_CACHE,
                    DATALAB_MARKER_FORCE_OCR=request.app.state.config.DATALAB_MARKER_FORCE_OCR,
                    DATALAB_MARKER_PAGINATE=request.app.state.config.DATALAB_MARKER_PAGINATE,
                    DATALAB_MARKER_STRIP_EXISTING_OCR=request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR,
                    DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION=request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION,
                    DATALAB_MARKER_USE_LLM=request.app.state.config.DATALAB_MARKER_USE_LLM,
                    DATALAB_MARKER_OUTPUT_FORMAT=request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT,
                    EXTERNAL_DOCUMENT_LOADER_URL=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
                    EXTERNAL_DOCUMENT_LOADER_API_KEY=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
                    TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
                    DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
                    DOCLING_OCR_ENGINE=request.app.state.config.DOCLING_OCR_ENGINE,
                    DOCLING_OCR_LANG=request.app.state.config.DOCLING_OCR_LANG,
                    DOCLING_DO_PICTURE_DESCRIPTION=request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
                    PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
                    DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
                    DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
                    MISTRAL_OCR_API_KEY=request.app.state.config.MISTRAL_OCR_API_KEY,
                )
                docs = loader.load(
                    file.filename, file.meta.get("content_type"), file_path
                )

                # Clean the loaded documents before processing
                cleaned_docs = []
                for doc in docs:
                    cleaned_content = TextCleaner.clean_for_chunking(doc.page_content)

                    cleaned_docs.append(
                        Document(
                            page_content=cleaned_content,
                            metadata={
                                **doc.metadata,
                                "name": file.filename,
                                "created_by": file.user_id,
                                "file_id": file.id,
                                "source": file.filename,
                            },
                        )
                    )
                docs = cleaned_docs
            else:
                docs = [
                    Document(
                        page_content=TextCleaner.clean_for_chunking(
                            file.data.get("content", "")
                        ),
                        metadata={
                            **file.meta,
                            "name": file.filename,
                            "created_by": file.user_id,
                            "file_id": file.id,
                            "source": file.filename,
                        },
                    )
                ]
            text_content = " ".join(
                [doc.page_content for doc in docs if doc.page_content]
            )

        # Ensure text_content is never None or empty for hash calculation
        if not text_content:
            text_content = ""

        log.debug(f"text_content: {text_content}")
        Files.update_file_data_by_id(
            file.id,
            {"content": text_content},
        )

        # Ensure we always pass a valid string to calculate_sha256_string
        hash_input = text_content if text_content else ""
        hash = calculate_sha256_string(hash_input)
        Files.update_file_hash_by_id(file.id, hash)

        if not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
            try:
                result = save_docs_to_vector_db(
                    request,
                    docs=docs,
                    collection_name=collection_name,
                    metadata={
                        "file_id": file.id,
                        "name": file.filename,
                        "hash": hash,
                    },
                    add=(True if form_data.collection_name else False),
                    user=user,
                )

                if result:
                    Files.update_file_metadata_by_id(
                        file.id,
                        {
                            "collection_name": collection_name,
                        },
                    )

                    return {
                        "status": True,
                        "collection_name": collection_name,
                        "filename": file.filename,
                        "content": text_content,
                    }
            except Exception as e:
                raise e
        else:
            return {
                "status": True,
                "collection_name": None,
                "filename": file.filename,
                "content": text_content,
            }

    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=str(e),
            )


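# Note: process_file serves three call paths, selected by the form fields:
# `content` set -> re-index edited text; `collection_name` set -> copy an
# already processed file into a knowledge collection; neither -> extract with
# the configured Loader. Illustration only (not executed):
#
#   POST /process/file
#   {"file_id": "<file_id>", "collection_name": "my-knowledge"}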
class ProcessTextForm(BaseModel):
    name: str
    content: str
    collection_name: Optional[str] = None


@router.post("/process/text")
def process_text(
    request: Request,
    form_data: ProcessTextForm,
    user=Depends(get_verified_user),
):
    collection_name = form_data.collection_name
    if collection_name is None:
        collection_name = calculate_sha256_string(form_data.content)

    docs = [
        Document(
            page_content=TextCleaner.clean_for_chunking(form_data.content),
            metadata={"name": form_data.name, "created_by": user.id},
        )
    ]
    text_content = form_data.content
    log.debug(f"text_content: {text_content}")

    result = save_docs_to_vector_db(request, docs, collection_name, user=user)
    if result:
        return {
            "status": True,
            "collection_name": collection_name,
            "content": text_content,
        }
    else:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )


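# Illustration only (not executed): raw text is indexed under a collection
# derived from the sha256 of the content when no collection name is given:
#
#   POST /process/text
#   {"name": "notes", "content": "Retrieval-augmented generation pairs ..."}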
@router.post("/process/youtube")
|
||
def process_youtube_video(
|
||
request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user)
|
||
):
|
||
try:
|
||
collection_name = form_data.collection_name
|
||
if not collection_name:
|
||
collection_name = calculate_sha256_string(form_data.url)[:63]
|
||
|
||
loader = YoutubeLoader(
|
||
form_data.url,
|
||
language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
|
||
proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
|
||
)
|
||
|
||
docs = loader.load()
|
||
content = " ".join([doc.page_content for doc in docs])
|
||
log.debug(f"text_content: {content}")
|
||
|
||
save_docs_to_vector_db(
|
||
request, docs, collection_name, overwrite=True, user=user
|
||
)
|
||
|
||
return {
|
||
"status": True,
|
||
"collection_name": collection_name,
|
||
"filename": form_data.url,
|
||
"file": {
|
||
"data": {
|
||
"content": content,
|
||
},
|
||
"meta": {
|
||
"name": form_data.url,
|
||
},
|
||
},
|
||
}
|
||
except Exception as e:
|
||
log.exception(e)
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||
)
|
||
|
||
|
||
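# Illustration only (not executed): the transcript is fetched with the
# configured language/proxy and indexed under sha256(url)[:63] by default:
#
#   POST /process/youtube
#   {"url": "https://www.youtube.com/watch?v=<video_id>"}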
@router.post("/process/web")
|
||
def process_web(
|
||
request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user)
|
||
):
|
||
try:
|
||
collection_name = form_data.collection_name
|
||
if not collection_name:
|
||
collection_name = calculate_sha256_string(form_data.url)[:63]
|
||
|
||
loader = get_web_loader(
|
||
form_data.url,
|
||
verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
|
||
requests_per_second=request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
|
||
)
|
||
docs = loader.load()
|
||
content = " ".join([doc.page_content for doc in docs])
|
||
|
||
log.debug(f"text_content: {content}")
|
||
|
||
if not request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
|
||
save_docs_to_vector_db(
|
||
request, docs, collection_name, overwrite=True, user=user
|
||
)
|
||
else:
|
||
collection_name = None
|
||
|
||
return {
|
||
"status": True,
|
||
"collection_name": collection_name,
|
||
"filename": form_data.url,
|
||
"file": {
|
||
"data": {
|
||
"content": content,
|
||
},
|
||
"meta": {
|
||
"name": form_data.url,
|
||
"source": form_data.url,
|
||
},
|
||
},
|
||
}
|
||
except Exception as e:
|
||
log.exception(e)
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||
)
|
||
|
||
|
||
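# Note: collection names derived from URLs and queries are truncated to 63
# characters, which keeps them within common vector-DB collection name limits
# (ChromaDB, for one, enforces a 63-character maximum).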
def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
    """Search the web using the given search engine and return the results as a list of SearchResult objects.

    Dispatches on `engine` and reads the matching credentials from the app
    config, raising if the required values are missing:
    - SEARXNG_QUERY_URL
    - YACY_QUERY_URL + YACY_USERNAME + YACY_PASSWORD
    - GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
    - BRAVE_SEARCH_API_KEY
    - KAGI_SEARCH_API_KEY
    - MOJEEK_SEARCH_API_KEY
    - BOCHA_SEARCH_API_KEY
    - SERPSTACK_API_KEY
    - SERPER_API_KEY
    - SERPLY_API_KEY
    - TAVILY_API_KEY
    - EXA_API_KEY
    - PERPLEXITY_API_KEY
    - SOUGOU_API_SID + SOUGOU_API_SK
    - SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`)
    - SERPAPI_API_KEY + SERPAPI_ENGINE (by default `google`)

    Args:
        engine (str): The search engine to use
        query (str): The query to search for
    """

    # TODO: add playwright to search the web
    if engine == "searxng":
        if request.app.state.config.SEARXNG_QUERY_URL:
            return search_searxng(
                request.app.state.config.SEARXNG_QUERY_URL,
                query,
                request.app.state.config.WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception("No SEARXNG_QUERY_URL found in environment variables")
    elif engine == "yacy":
        if request.app.state.config.YACY_QUERY_URL:
            return search_yacy(
                request.app.state.config.YACY_QUERY_URL,
                request.app.state.config.YACY_USERNAME,
                request.app.state.config.YACY_PASSWORD,
                query,
                request.app.state.config.WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception("No YACY_QUERY_URL found in environment variables")
    elif engine == "google_pse":
        if (
            request.app.state.config.GOOGLE_PSE_API_KEY
            and request.app.state.config.GOOGLE_PSE_ENGINE_ID
        ):
            return search_google_pse(
                request.app.state.config.GOOGLE_PSE_API_KEY,
                request.app.state.config.GOOGLE_PSE_ENGINE_ID,
                query,
                request.app.state.config.WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception(
                "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
            )
    elif engine == "brave":
        if request.app.state.config.BRAVE_SEARCH_API_KEY:
            return search_brave(
                request.app.state.config.BRAVE_SEARCH_API_KEY,
                query,
                request.app.state.config.WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
    elif engine == "kagi":
        if request.app.state.config.KAGI_SEARCH_API_KEY:
            return search_kagi(
                request.app.state.config.KAGI_SEARCH_API_KEY,
                query,
                request.app.state.config.WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception("No KAGI_SEARCH_API_KEY found in environment variables")
    elif engine == "mojeek":
        if request.app.state.config.MOJEEK_SEARCH_API_KEY:
            return search_mojeek(
                request.app.state.config.MOJEEK_SEARCH_API_KEY,
                query,
                request.app.state.config.WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables")
    elif engine == "bocha":
        if request.app.state.config.BOCHA_SEARCH_API_KEY:
            return search_bocha(
                request.app.state.config.BOCHA_SEARCH_API_KEY,
                query,
                request.app.state.config.WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception("No BOCHA_SEARCH_API_KEY found in environment variables")
    elif engine == "serpstack":
        if request.app.state.config.SERPSTACK_API_KEY:
            return search_serpstack(
                request.app.state.config.SERPSTACK_API_KEY,
                query,
                request.app.state.config.WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
                https_enabled=request.app.state.config.SERPSTACK_HTTPS,
            )
        else:
            raise Exception("No SERPSTACK_API_KEY found in environment variables")
    elif engine == "serper":
        if request.app.state.config.SERPER_API_KEY:
            return search_serper(
                request.app.state.config.SERPER_API_KEY,
                query,
                request.app.state.config.WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception("No SERPER_API_KEY found in environment variables")
    elif engine == "serply":
        if request.app.state.config.SERPLY_API_KEY:
            return search_serply(
                request.app.state.config.SERPLY_API_KEY,
                query,
                request.app.state.config.WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception("No SERPLY_API_KEY found in environment variables")
    elif engine == "duckduckgo":
        return search_duckduckgo(
            query,
            request.app.state.config.WEB_SEARCH_RESULT_COUNT,
            request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
        )
    elif engine == "tavily":
        if request.app.state.config.TAVILY_API_KEY:
            return search_tavily(
                request.app.state.config.TAVILY_API_KEY,
                query,
                request.app.state.config.WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception("No TAVILY_API_KEY found in environment variables")
    elif engine == "searchapi":
        if request.app.state.config.SEARCHAPI_API_KEY:
            return search_searchapi(
                request.app.state.config.SEARCHAPI_API_KEY,
                request.app.state.config.SEARCHAPI_ENGINE,
                query,
                request.app.state.config.WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception("No SEARCHAPI_API_KEY found in environment variables")
    elif engine == "serpapi":
        if request.app.state.config.SERPAPI_API_KEY:
            return search_serpapi(
                request.app.state.config.SERPAPI_API_KEY,
                request.app.state.config.SERPAPI_ENGINE,
                query,
                request.app.state.config.WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception("No SERPAPI_API_KEY found in environment variables")
    elif engine == "jina":
        return search_jina(
            request.app.state.config.JINA_API_KEY,
            query,
            request.app.state.config.WEB_SEARCH_RESULT_COUNT,
        )
    elif engine == "bing":
        return search_bing(
            request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
            request.app.state.config.BING_SEARCH_V7_ENDPOINT,
            str(DEFAULT_LOCALE),
            query,
            request.app.state.config.WEB_SEARCH_RESULT_COUNT,
            request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
        )
    elif engine == "exa":
        return search_exa(
            request.app.state.config.EXA_API_KEY,
            query,
            request.app.state.config.WEB_SEARCH_RESULT_COUNT,
            request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
        )
    elif engine == "perplexity":
        return search_perplexity(
            request.app.state.config.PERPLEXITY_API_KEY,
            query,
            request.app.state.config.WEB_SEARCH_RESULT_COUNT,
            request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
        )
    elif engine == "sougou":
        if (
            request.app.state.config.SOUGOU_API_SID
            and request.app.state.config.SOUGOU_API_SK
        ):
            return search_sougou(
                request.app.state.config.SOUGOU_API_SID,
                request.app.state.config.SOUGOU_API_SK,
                query,
                request.app.state.config.WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception(
                "No SOUGOU_API_SID or SOUGOU_API_SK found in environment variables"
            )
    elif engine == "firecrawl":
        return search_firecrawl(
            request.app.state.config.FIRECRAWL_API_BASE_URL,
            request.app.state.config.FIRECRAWL_API_KEY,
            query,
            request.app.state.config.WEB_SEARCH_RESULT_COUNT,
            request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
        )
    elif engine == "external":
        return search_external(
            request.app.state.config.EXTERNAL_WEB_SEARCH_URL,
            request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY,
            query,
            request.app.state.config.WEB_SEARCH_RESULT_COUNT,
            request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
        )
    else:
        raise Exception(f"Unsupported web search engine: {engine}")


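# Illustration only (not executed): dispatching a query through a chosen
# engine; engines without credentials (e.g. duckduckgo) work out of the box:
#
#   results = search_web(request, "duckduckgo", "open source rag pipelines")
#   links = [result.link for result in results]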
@router.post("/process/web/search")
|
||
async def process_web_search(
|
||
request: Request, form_data: SearchForm, user=Depends(get_verified_user)
|
||
):
|
||
|
||
urls = []
|
||
try:
|
||
logging.info(
|
||
f"trying to web search with {request.app.state.config.WEB_SEARCH_ENGINE, form_data.queries}"
|
||
)
|
||
|
||
search_tasks = [
|
||
run_in_threadpool(
|
||
search_web,
|
||
request,
|
||
request.app.state.config.WEB_SEARCH_ENGINE,
|
||
query,
|
||
)
|
||
for query in form_data.queries
|
||
]
|
||
|
||
search_results = await asyncio.gather(*search_tasks)
|
||
|
||
for result in search_results:
|
||
if result:
|
||
for item in result:
|
||
if item and item.link:
|
||
urls.append(item.link)
|
||
|
||
urls = list(dict.fromkeys(urls))
|
||
log.debug(f"urls: {urls}")
|
||
|
||
except Exception as e:
|
||
log.exception(e)
|
||
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
|
||
)
|
||
|
||
try:
|
||
if request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER:
|
||
docs = [
|
||
Document(
|
||
page_content=result.snippet,
|
||
metadata={
|
||
"source": result.link,
|
||
"title": result.title,
|
||
"snippet": result.snippet,
|
||
"link": result.link,
|
||
},
|
||
)
|
||
for result in search_results
|
||
if hasattr(result, "snippet")
|
||
]
|
||
else:
|
||
loader = get_web_loader(
|
||
urls,
|
||
verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
|
||
requests_per_second=request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
|
||
trust_env=request.app.state.config.WEB_SEARCH_TRUST_ENV,
|
||
)
|
||
docs = await loader.aload()
|
||
|
||
urls = [
|
||
doc.metadata.get("source") for doc in docs if doc.metadata.get("source")
|
||
] # only keep the urls returned by the loader
|
||
|
||
if request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
|
||
return {
|
||
"status": True,
|
||
"collection_name": None,
|
||
"filenames": urls,
|
||
"docs": [
|
||
{
|
||
"content": doc.page_content,
|
||
"metadata": doc.metadata,
|
||
}
|
||
for doc in docs
|
||
],
|
||
"loaded_count": len(docs),
|
||
}
|
||
else:
|
||
# Create a single collection for all documents
|
||
collection_name = (
|
||
f"web-search-{calculate_sha256_string('-'.join(form_data.queries))}"[
|
||
:63
|
||
]
|
||
)
|
||
|
||
try:
|
||
await run_in_threadpool(
|
||
save_docs_to_vector_db,
|
||
request,
|
||
docs,
|
||
collection_name,
|
||
overwrite=True,
|
||
user=user,
|
||
)
|
||
except Exception as e:
|
||
log.debug(f"error saving docs: {e}")
|
||
|
||
return {
|
||
"status": True,
|
||
"collection_names": [collection_name],
|
||
"filenames": urls,
|
||
"loaded_count": len(docs),
|
||
}
|
||
except Exception as e:
|
||
log.exception(e)
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||
)
|
||
|
||
|
||
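# Note: the response shape differs by config. With
# BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL the raw docs are returned inline
# ("collection_name": None); otherwise only "collection_names" is returned and
# the caller retrieves content by querying the vector DB.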
class QueryDocForm(BaseModel):
    collection_name: str
    query: str
    k: Optional[int] = None
    k_reranker: Optional[int] = None
    r: Optional[float] = None
    hybrid: Optional[bool] = None
    hybrid_bm25_weight: Optional[float] = None


@router.post("/query/doc")
|
||
def query_doc_handler(
|
||
request: Request,
|
||
form_data: QueryDocForm,
|
||
user=Depends(get_verified_user),
|
||
):
|
||
try:
|
||
if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
|
||
collection_results = {}
|
||
collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get(
|
||
collection_name=form_data.collection_name
|
||
)
|
||
return query_doc_with_hybrid_search(
|
||
collection_name=form_data.collection_name,
|
||
collection_result=collection_results[form_data.collection_name],
|
||
query=form_data.query,
|
||
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
|
||
query, prefix=prefix, user=user
|
||
),
|
||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||
reranking_function=request.app.state.rf,
|
||
k_reranker=form_data.k_reranker
|
||
or request.app.state.config.TOP_K_RERANKER,
|
||
r=(
|
||
form_data.r
|
||
if form_data.r
|
||
else request.app.state.config.RELEVANCE_THRESHOLD
|
||
),
|
||
hybrid_bm25_weight=(
|
||
form_data.hybrid_bm25_weight
|
||
if form_data.hybrid_bm25_weight
|
||
else request.app.state.config.HYBRID_BM25_WEIGHT
|
||
),
|
||
user=user,
|
||
)
|
||
else:
|
||
return query_doc(
|
||
collection_name=form_data.collection_name,
|
||
query_embedding=request.app.state.EMBEDDING_FUNCTION(
|
||
form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
|
||
),
|
||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||
user=user,
|
||
)
|
||
except Exception as e:
|
||
log.exception(e)
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||
)
|
||
|
||
|
||
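# Illustration only (not executed): k, k_reranker, r and hybrid_bm25_weight all
# fall back to the server-side config when omitted:
#
#   POST /query/doc
#   {"collection_name": "file-<file_id>", "query": "what does the report conclude?", "k": 5}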
class QueryCollectionsForm(BaseModel):
    collection_names: list[str]
    query: str
    k: Optional[int] = None
    k_reranker: Optional[int] = None
    r: Optional[float] = None
    hybrid: Optional[bool] = None
    hybrid_bm25_weight: Optional[float] = None


@router.post("/query/collection")
def query_collection_handler(
    request: Request,
    form_data: QueryCollectionsForm,
    user=Depends(get_verified_user),
):
    try:
        if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
            return query_collection_with_hybrid_search(
                collection_names=form_data.collection_names,
                queries=[form_data.query],
                embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
                    query, prefix=prefix, user=user
                ),
                k=form_data.k if form_data.k else request.app.state.config.TOP_K,
                reranking_function=request.app.state.rf,
                k_reranker=form_data.k_reranker
                or request.app.state.config.TOP_K_RERANKER,
                r=(
                    form_data.r
                    if form_data.r
                    else request.app.state.config.RELEVANCE_THRESHOLD
                ),
                hybrid_bm25_weight=(
                    form_data.hybrid_bm25_weight
                    if form_data.hybrid_bm25_weight
                    else request.app.state.config.HYBRID_BM25_WEIGHT
                ),
            )
        else:
            return query_collection(
                collection_names=form_data.collection_names,
                queries=[form_data.query],
                embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
                    query, prefix=prefix, user=user
                ),
                k=form_data.k if form_data.k else request.app.state.config.TOP_K,
            )

    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


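# Note: /query/collection mirrors /query/doc but fans a single query out over
# several collections; the hybrid path additionally combines BM25-weighted
# scoring with the configured reranker.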
####################################
#
# Vector DB operations
#
####################################


class DeleteForm(BaseModel):
    collection_name: str
    file_id: str


@router.post("/delete")
def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)):
    try:
        if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name):
            file = Files.get_file_by_id(form_data.file_id)
            hash = file.hash

            VECTOR_DB_CLIENT.delete(
                collection_name=form_data.collection_name,
                metadata={"hash": hash},
            )
            return {"status": True}
        else:
            return {"status": False}
    except Exception as e:
        log.exception(e)
        return {"status": False}


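# Note: deletion is scoped by the file's content hash, so only entries that
# were inserted for that exact file content are removed from the collection.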
@router.post("/reset/db")
|
||
def reset_vector_db(user=Depends(get_admin_user)):
|
||
VECTOR_DB_CLIENT.reset()
|
||
Knowledges.delete_all_knowledge()
|
||
|
||
|
||
@router.post("/reset/uploads")
|
||
def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
|
||
folder = f"{UPLOAD_DIR}"
|
||
try:
|
||
# Check if the directory exists
|
||
if os.path.exists(folder):
|
||
# Iterate over all the files and directories in the specified directory
|
||
for filename in os.listdir(folder):
|
||
file_path = os.path.join(folder, filename)
|
||
try:
|
||
if os.path.isfile(file_path) or os.path.islink(file_path):
|
||
os.unlink(file_path) # Remove the file or link
|
||
elif os.path.isdir(file_path):
|
||
shutil.rmtree(file_path) # Remove the directory
|
||
except Exception as e:
|
||
log.exception(f"Failed to delete {file_path}. Reason: {e}")
|
||
else:
|
||
log.warning(f"The directory {folder} does not exist")
|
||
except Exception as e:
|
||
log.exception(f"Failed to process the directory {folder}. Reason: {e}")
|
||
return True
|
||
|
||
|
||
if ENV == "dev":
|
||
|
||
@router.get("/ef/{text}")
|
||
async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
|
||
return {
|
||
"result": request.app.state.EMBEDDING_FUNCTION(
|
||
text, prefix=RAG_EMBEDDING_QUERY_PREFIX
|
||
)
|
||
}
|
||
|
||
|
||
class BatchProcessFilesForm(BaseModel):
    files: List[FileModel]
    collection_name: str


class BatchProcessFilesResult(BaseModel):
    file_id: str
    status: str
    error: Optional[str] = None


class BatchProcessFilesResponse(BaseModel):
    results: List[BatchProcessFilesResult]
    errors: List[BatchProcessFilesResult]


@router.post("/process/files/batch")
|
||
def process_files_batch(
|
||
request: Request,
|
||
form_data: BatchProcessFilesForm,
|
||
user=Depends(get_verified_user),
|
||
) -> BatchProcessFilesResponse:
|
||
"""
|
||
Process a batch of files and save them to the vector database.
|
||
"""
|
||
results: List[BatchProcessFilesResult] = []
|
||
errors: List[BatchProcessFilesResult] = []
|
||
collection_name = form_data.collection_name
|
||
|
||
# Prepare all documents first
|
||
all_docs: List[Document] = []
|
||
for file in form_data.files:
|
||
try:
|
||
text_content = file.data.get("content", "")
|
||
|
||
docs: List[Document] = [
|
||
Document(
|
||
page_content=TextCleaner.clean_for_chunking(
|
||
text_content.replace("<br/>", "\n")
|
||
),
|
||
metadata={
|
||
**file.meta,
|
||
"name": file.filename,
|
||
"created_by": file.user_id,
|
||
"file_id": file.id,
|
||
"source": file.filename,
|
||
},
|
||
)
|
||
]
|
||
|
||
hash = calculate_sha256_string(text_content or "")
|
||
Files.update_file_hash_by_id(file.id, hash)
|
||
Files.update_file_data_by_id(file.id, {"content": text_content})
|
||
|
||
all_docs.extend(docs)
|
||
results.append(BatchProcessFilesResult(file_id=file.id, status="prepared"))
|
||
|
||
except Exception as e:
|
||
log.error(f"process_files_batch: Error processing file {file.id}: {str(e)}")
|
||
errors.append(
|
||
BatchProcessFilesResult(file_id=file.id, status="failed", error=str(e))
|
||
)
|
||
|
||
# Save all documents in one batch
|
||
if all_docs:
|
||
try:
|
||
save_docs_to_vector_db(
|
||
request=request,
|
||
docs=all_docs,
|
||
collection_name=collection_name,
|
||
add=True,
|
||
user=user,
|
||
)
|
||
|
||
# Update all files with collection name
|
||
for result in results:
|
||
Files.update_file_metadata_by_id(
|
||
result.file_id, {"collection_name": collection_name}
|
||
)
|
||
result.status = "completed"
|
||
|
||
except Exception as e:
|
||
log.error(
|
||
f"process_files_batch: Error saving documents to vector DB: {str(e)}"
|
||
)
|
||
for result in results:
|
||
result.status = "failed"
|
||
errors.append(
|
||
BatchProcessFilesResult(file_id=result.file_id, error=str(e))
|
||
)
|
||
|
||
return BatchProcessFilesResponse(results=results, errors=errors)
|
||
|
||
|
||
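# Illustration only (not executed): files are prepared individually but written
# in a single save_docs_to_vector_db call, so a storage failure marks every
# prepared file as "failed":
#
#   POST /process/files/batch
#   {"files": [<FileModel>, ...], "collection_name": "my-knowledge"}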
def delete_file_from_vector_db(file_id: str) -> bool:
    """
    Delete all vector embeddings for a specific file from the vector database.
    This function works with any vector database (Pinecone, ChromaDB, etc.) and
    handles the cleanup when a file is deleted from the chat.

    Args:
        file_id (str): The ID of the file to delete from the vector database

    Returns:
        bool: True if deletion was successful, False otherwise
    """
    try:
        # Get the file record to access its hash and collection info
        file = Files.get_file_by_id(file_id)
        if not file:
            return False

        # Get the file hash for vector deletion
        file_hash = file.hash
        if not file_hash:
            return False

        # Try to get collection name from file metadata
        collection_name = None
        if hasattr(file, "meta") and file.meta:
            collection_name = file.meta.get("collection_name")

        # If no collection name in metadata, try common patterns used by Open WebUI
        if not collection_name:
            # Open WebUI typically uses these patterns:
            possible_collections = [
                f"open-webui_file-{file_id}",  # Most common pattern
                f"file-{file_id}",  # Alternative pattern
                f"open-webui_{file_id}",  # Another possible pattern
            ]

            # Try each possible collection name
            for possible_collection in possible_collections:
                try:
                    if VECTOR_DB_CLIENT.has_collection(
                        collection_name=possible_collection
                    ):
                        VECTOR_DB_CLIENT.delete(
                            collection_name=possible_collection,
                            filter={"hash": file_hash},
                        )
                        # Pinecone returns None on successful deletion
                        return True
                except Exception:
                    continue

            # If none of the standard patterns work, try searching through all collections
            try:
                deleted_count = 0

                # Get all collections (this method varies by vector DB implementation)
                if hasattr(VECTOR_DB_CLIENT, "list_collections"):
                    try:
                        collections = VECTOR_DB_CLIENT.list_collections()

                        for collection in collections:
                            try:
                                if VECTOR_DB_CLIENT.has_collection(
                                    collection_name=collection
                                ):
                                    VECTOR_DB_CLIENT.delete(
                                        collection_name=collection,
                                        filter={"hash": file_hash},
                                    )
                                    # Pinecone returns None on successful deletion,
                                    # so any non-exception means success
                                    deleted_count += 1
                            except Exception:
                                continue
                    except Exception:
                        pass

                return deleted_count > 0

            except Exception:
                return False

        # Delete from the specific collection found in metadata
        if collection_name and VECTOR_DB_CLIENT.has_collection(
            collection_name=collection_name
        ):
            try:
                VECTOR_DB_CLIENT.delete(
                    collection_name=collection_name,
                    filter={"hash": file_hash},
                )
                # Pinecone returns None on successful deletion, so we check for
                # no exception rather than checking the return value
                return True
            except Exception:
                return False
        else:
            return False

    except Exception:
        return False
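

# Illustration only (not executed): cleanup hook for chat file deletion,
# assuming the file record still exists so its hash can be resolved:
#
#   if delete_file_from_vector_db(file_id):
#       log.info(f"removed vectors for {file_id}")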