mirror of
https://github.com/open-webui/open-webui
synced 2025-06-22 18:07:17 +00:00
2686 lines
122 KiB
Python
2686 lines
122 KiB
Python
import json
|
|
import logging
|
|
import mimetypes
|
|
import os
|
|
import shutil
|
|
import asyncio
|
|
|
|
|
|
import uuid
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Iterator, List, Optional, Sequence, Union
|
|
|
|
from fastapi import (
|
|
Depends,
|
|
FastAPI,
|
|
File,
|
|
Form,
|
|
HTTPException,
|
|
UploadFile,
|
|
Request,
|
|
status,
|
|
APIRouter,
|
|
)
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.concurrency import run_in_threadpool
|
|
from pydantic import BaseModel
|
|
import tiktoken
|
|
|
|
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
|
|
from langchain_core.documents import Document
|
|
|
|
from open_webui.models.files import FileModel, Files
|
|
from open_webui.models.knowledge import Knowledges
|
|
from open_webui.storage.provider import Storage
|
|
|
|
|
|
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
|
|
|
# Document loaders
|
|
from open_webui.retrieval.loaders.main import Loader
|
|
from open_webui.retrieval.loaders.youtube import YoutubeLoader
|
|
|
|
# Web search engines
|
|
from open_webui.retrieval.web.main import SearchResult
|
|
from open_webui.retrieval.web.utils import get_web_loader
|
|
from open_webui.retrieval.web.brave import search_brave
|
|
from open_webui.retrieval.web.kagi import search_kagi
|
|
from open_webui.retrieval.web.mojeek import search_mojeek
|
|
from open_webui.retrieval.web.bocha import search_bocha
|
|
from open_webui.retrieval.web.duckduckgo import search_duckduckgo
|
|
from open_webui.retrieval.web.google_pse import search_google_pse
|
|
from open_webui.retrieval.web.jina_search import search_jina
|
|
from open_webui.retrieval.web.searchapi import search_searchapi
|
|
from open_webui.retrieval.web.serpapi import search_serpapi
|
|
from open_webui.retrieval.web.searxng import search_searxng
|
|
from open_webui.retrieval.web.yacy import search_yacy
|
|
from open_webui.retrieval.web.serper import search_serper
|
|
from open_webui.retrieval.web.serply import search_serply
|
|
from open_webui.retrieval.web.serpstack import search_serpstack
|
|
from open_webui.retrieval.web.tavily import search_tavily
|
|
from open_webui.retrieval.web.bing import search_bing
|
|
from open_webui.retrieval.web.exa import search_exa
|
|
from open_webui.retrieval.web.perplexity import search_perplexity
|
|
from open_webui.retrieval.web.sougou import search_sougou
|
|
from open_webui.retrieval.web.firecrawl import search_firecrawl
|
|
from open_webui.retrieval.web.external import search_external
|
|
|
|
from open_webui.retrieval.utils import (
|
|
get_embedding_function,
|
|
get_model_path,
|
|
query_collection,
|
|
query_collection_with_hybrid_search,
|
|
query_doc,
|
|
query_doc_with_hybrid_search,
|
|
)
|
|
from open_webui.utils.misc import (
|
|
calculate_sha256_string,
|
|
)
|
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
|
|
|
from open_webui.config import (
|
|
ENV,
|
|
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
|
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
|
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
|
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
|
UPLOAD_DIR,
|
|
DEFAULT_LOCALE,
|
|
RAG_EMBEDDING_CONTENT_PREFIX,
|
|
RAG_EMBEDDING_QUERY_PREFIX,
|
|
)
|
|
from open_webui.env import (
|
|
SRC_LOG_LEVELS,
|
|
DEVICE_TYPE,
|
|
DOCKER,
|
|
SENTENCE_TRANSFORMERS_BACKEND,
|
|
SENTENCE_TRANSFORMERS_MODEL_KWARGS,
|
|
SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
|
|
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
|
|
)
|
|
|
|
from open_webui.constants import ERROR_MESSAGES
|
|
|
|
# Module-level logger; its level comes from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
|
|
|
##########################################
|
|
#
|
|
# Utility functions
|
|
#
|
|
##########################################
|
|
|
|
|
|
def get_ef(
    engine: str,
    embedding_model: str,
    auto_update: bool = False,
):
    """Load the local SentenceTransformer embedding model, if one applies.

    Only the local engine (``engine == ""``) is backed by an in-process model.
    For any other engine, or when ``embedding_model`` is empty, no model is
    loaded.

    Args:
        engine: Embedding engine name; ``""`` selects SentenceTransformers.
        embedding_model: Model identifier, resolved through ``get_model_path``.
        auto_update: Forwarded to ``get_model_path``.

    Returns:
        The ``SentenceTransformer`` instance, or ``None`` when no local model
        applies or loading it failed.
    """
    if not embedding_model or engine != "":
        return None

    from sentence_transformers import SentenceTransformer

    try:
        return SentenceTransformer(
            get_model_path(embedding_model, auto_update),
            device=DEVICE_TYPE,
            trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
            backend=SENTENCE_TRANSFORMERS_BACKEND,
            model_kwargs=SENTENCE_TRANSFORMERS_MODEL_KWARGS,
        )
    except Exception as e:
        # Loading stays best-effort (the caller gets None), but the failure
        # must be visible. It was previously logged at debug level, which is
        # hidden under typical INFO-and-above log configurations.
        log.exception(f"Error loading SentenceTransformer: {e}")
        return None
|
|
|
|
|
|
def get_rf(
    engine: str = "",
    reranking_model: Optional[str] = None,
    external_reranker_url: str = "",
    external_reranker_api_key: str = "",
    auto_update: bool = False,
):
    """Build the reranker for the configured engine and model.

    Selection order:

    * model name contains ``jinaai/jina-colbert-v2`` -> ``ColBERT``
      (the engine is ignored);
    * ``engine == "external"`` -> ``ExternalReranker`` using
      ``external_reranker_url`` / ``external_reranker_api_key``;
    * otherwise -> a local ``sentence_transformers.CrossEncoder``.

    Args:
        engine: Reranking engine name.
        reranking_model: Model identifier; empty/None disables reranking.
        external_reranker_url: Endpoint for the external reranker.
        external_reranker_api_key: API key for the external reranker.
        auto_update: Forwarded to ``get_model_path`` for local models.

    Returns:
        The reranker instance, or ``None`` when ``reranking_model`` is empty.

    Raises:
        Exception: if constructing the reranker fails. The underlying error
            is chained as ``__cause__`` so its traceback is not lost.
    """
    if not reranking_model:
        return None

    colbert_models = ("jinaai/jina-colbert-v2",)
    if any(name in reranking_model for name in colbert_models):
        try:
            from open_webui.retrieval.models.colbert import ColBERT

            return ColBERT(
                get_model_path(reranking_model, auto_update),
                env="docker" if DOCKER else None,
            )
        except Exception as e:
            log.error(f"ColBERT: {e}")
            raise Exception(ERROR_MESSAGES.DEFAULT(e)) from e

    if engine == "external":
        try:
            from open_webui.retrieval.models.external import ExternalReranker

            return ExternalReranker(
                url=external_reranker_url,
                api_key=external_reranker_api_key,
                model=reranking_model,
            )
        except Exception as e:
            log.error(f"ExternalReranking: {e}")
            raise Exception(ERROR_MESSAGES.DEFAULT(e)) from e

    # Import outside the try, as before: a missing package surfaces as ImportError.
    import sentence_transformers

    try:
        return sentence_transformers.CrossEncoder(
            get_model_path(reranking_model, auto_update),
            device=DEVICE_TYPE,
            trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
            backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
            model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
        )
    except Exception as e:
        log.error(f"CrossEncoder: {e}")
        # The public message stays generic; the real error is kept via chaining.
        raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error")) from e
|
|
|
|
|
|
##########################################
|
|
#
|
|
# API routes
|
|
#
|
|
##########################################
|
|
|
|
|
|
# Router for the retrieval API routes defined below (presumably included by
# the main application elsewhere; the mount prefix is not visible here).
router = APIRouter()
|
|
|
|
|
|
# Request body carrying an optional collection name (presumably a vector-DB
# collection; the endpoints using it are outside this chunk).
class CollectionNameForm(BaseModel):
    collection_name: Optional[str] = None
|
|
|
|
|
|
# Request body with a required URL plus the optional collection_name inherited
# from CollectionNameForm (presumably for URL ingestion; usage not visible here).
class ProcessUrlForm(CollectionNameForm):
    url: str
|
|
|
|
|
|
# Request body holding a list of search query strings.
class SearchForm(BaseModel):
    queries: List[str]
|
|
|
|
|
|
# Request body selecting a knowledge base by id for the config endpoints
# (/embedding, /config). When the id does not resolve to a knowledge base with
# its own settings, those endpoints report the global configuration.
class CollectionForm(BaseModel):
    knowledge_id: Optional[str] = None
|
|
|
|
|
|
@router.get("/")
async def get_status(request: Request):
    """Status endpoint for the retrieval router.

    Always reports ``status: True`` together with the global chunking,
    template, embedding and reranking settings read from
    ``request.app.state.config``.
    """
    cfg = request.app.state.config
    return dict(
        status=True,
        chunk_size=cfg.CHUNK_SIZE,
        chunk_overlap=cfg.CHUNK_OVERLAP,
        template=cfg.RAG_TEMPLATE,
        embedding_engine=cfg.RAG_EMBEDDING_ENGINE,
        embedding_model=cfg.RAG_EMBEDDING_MODEL,
        reranking_model=cfg.RAG_RERANKING_MODEL,
        embedding_batch_size=cfg.RAG_EMBEDDING_BATCH_SIZE,
    )
|
|
|
|
|
|
@router.post("/embedding")
async def get_embedding_config(
    request: Request,
    collectionForm: Optional[CollectionForm] = None,
    user=Depends(get_verified_user),
):
    """Return the embedding configuration in effect for a knowledge base.

    If ``collectionForm.knowledge_id`` resolves to a knowledge base whose
    ``rag_config`` sets ``DEFAULT_RAG_SETTINGS`` to False, values stored there
    are returned. Any key missing from that config falls back to the global
    app configuration. So does every key when no form or id is given, or when
    the knowledge base uses the defaults.
    """
    # NOTE(review): any verified user receives the API keys below (the global
    # ones, or any knowledge base's by id) — confirm this exposure is intended.
    config = request.app.state.config

    # The body is declared Optional, so guard before dereferencing it.
    knowledge_id = collectionForm.knowledge_id if collectionForm is not None else None

    rag_config = {}
    if knowledge_id:
        knowledge_base = Knowledges.get_knowledge_by_id(knowledge_id)
        # `or {}` tolerates a knowledge base with no stored rag_config
        # (TODO confirm whether the column can be NULL).
        stored = (knowledge_base.rag_config or {}) if knowledge_base else {}
        if not stored.get("DEFAULT_RAG_SETTINGS", True):
            rag_config = stored

    return {
        "status": True,
        "embedding_engine": rag_config.get("embedding_engine", config.RAG_EMBEDDING_ENGINE),
        "embedding_model": rag_config.get("embedding_model", config.RAG_EMBEDDING_MODEL),
        "embedding_batch_size": rag_config.get(
            "embedding_batch_size", config.RAG_EMBEDDING_BATCH_SIZE
        ),
        "openai_config": rag_config.get(
            "openai_config",
            {
                "url": config.RAG_OPENAI_API_BASE_URL,
                "key": config.RAG_OPENAI_API_KEY,
            },
        ),
        "ollama_config": rag_config.get(
            "ollama_config",
            {
                "url": config.RAG_OLLAMA_BASE_URL,
                "key": config.RAG_OLLAMA_API_KEY,
            },
        ),
        "azure_openai_config": rag_config.get(
            "azure_openai_config",
            {
                "url": config.RAG_AZURE_OPENAI_BASE_URL,
                "key": config.RAG_AZURE_OPENAI_API_KEY,
                "version": config.RAG_AZURE_OPENAI_API_VERSION,
            },
        ),
    }
|
|
|
|
|
|
# Connection settings (base URL and API key) for the "openai" embedding engine;
# stored as openai_config / RAG_OPENAI_API_BASE_URL + RAG_OPENAI_API_KEY.
class OpenAIConfigForm(BaseModel):
    url: str
    key: str
|
|
|
|
|
|
# Connection settings (base URL and API key) for the "ollama" embedding engine;
# stored as ollama_config / RAG_OLLAMA_BASE_URL + RAG_OLLAMA_API_KEY.
class OllamaConfigForm(BaseModel):
    url: str
    key: str
|
|
|
|
|
|
# Connection settings for the "azure_openai" embedding engine: base URL, API key
# and API version (the version is also passed as azure_api_version).
class AzureOpenAIConfigForm(BaseModel):
    url: str
    key: str
    version: str
|
|
|
|
|
|
# Request body for POST /embedding/update.
class EmbeddingModelUpdateForm(BaseModel):
    # Per-backend connection settings; each is applied only when not None.
    openai_config: Optional[OpenAIConfigForm] = None
    ollama_config: Optional[OllamaConfigForm] = None
    azure_openai_config: Optional[AzureOpenAIConfigForm] = None
    # Engine handled by the endpoint: "" (local SentenceTransformers),
    # "ollama", "openai" or "azure_openai"; plus the model name to use.
    embedding_engine: str
    embedding_model: str
    embedding_batch_size: Optional[int] = 1
    # When this names a knowledge base whose DEFAULT_RAG_SETTINGS is False, the
    # update targets that knowledge base instead of the global configuration.
    knowledge_id: Optional[str] = None
|
|
|
|
|
|
def _embedding_entry_matches(entry, model) -> bool:
    """Return True if a LOADED_EMBEDDING_MODELS list entry refers to *model*.

    Entries are plain model names, except for the azure_openai engine, where
    they are one-item ``{model: api_version}`` dicts.
    """
    if isinstance(entry, dict):
        return model in entry
    return entry == model


def _release_embedding_model(request: Request, engine, model) -> None:
    """Drop *model*'s cached instance and LOADED_EMBEDDING_MODELS entry, then free memory.

    A model that was never cached, or an engine with no list, is tolerated
    (both previously raised KeyError). Azure dict entries are matched too;
    previously they were never removed.
    """
    request.app.state.ef.pop(model, None)

    models_list = request.app.state.config.LOADED_EMBEDDING_MODELS.get(engine)
    if models_list is not None:
        # Mutate in place, as before, so the object handed to save() is updated.
        models_list[:] = [
            entry for entry in models_list if not _embedding_entry_matches(entry, model)
        ]
        request.app.state.config._state["LOADED_EMBEDDING_MODELS"].save()

    # NOTE(review): EMBEDDING_FUNCTION[model] is left in place (unchanged from
    # before). If it holds a reference to the model instance, this collection
    # cannot reclaim the memory — confirm.
    import gc

    import torch

    gc.collect()
    torch.cuda.empty_cache()


def _register_embedding_model(request: Request, engine, model, azure_api_version=None) -> None:
    """Record *model* for reloading on startup and as a selectable downloaded model.

    Any previous LOADED entry for the same model is replaced. Previously every
    update appended a duplicate, and an azure api-version change left a stale
    entry.
    """
    config = request.app.state.config

    loaded = config.LOADED_EMBEDDING_MODELS.setdefault(engine, [])
    loaded[:] = [entry for entry in loaded if not _embedding_entry_matches(entry, model)]
    loaded.append({model: azure_api_version} if engine == "azure_openai" else model)
    config._state["LOADED_EMBEDDING_MODELS"].save()

    downloaded = config.DOWNLOADED_EMBEDDING_MODELS.setdefault(engine, [])
    if model not in downloaded:
        downloaded.append(model)
        config._state["DOWNLOADED_EMBEDDING_MODELS"].save()


def _global_embedding_connections(config):
    """Return the global ``(openai, ollama, azure_openai)`` connection dicts."""
    return (
        {"url": config.RAG_OPENAI_API_BASE_URL, "key": config.RAG_OPENAI_API_KEY},
        {"url": config.RAG_OLLAMA_BASE_URL, "key": config.RAG_OLLAMA_API_KEY},
        {
            "url": config.RAG_AZURE_OPENAI_BASE_URL,
            "key": config.RAG_AZURE_OPENAI_API_KEY,
            "version": config.RAG_AZURE_OPENAI_API_VERSION,
        },
    )


def _select_embedding_connection(engine, openai_config, ollama_config, azure_openai_config):
    """Return ``(url, key, azure_api_version)`` for *engine*.

    openai -> openai_config, ollama -> ollama_config, anything else ->
    azure_openai_config. The API version is only returned for azure_openai.
    """
    if engine == "openai":
        chosen = openai_config
    elif engine == "ollama":
        chosen = ollama_config
    else:
        chosen = azure_openai_config
    version = azure_openai_config.get("version") if engine == "azure_openai" else None
    return chosen.get("url"), chosen.get("key"), version


def _update_knowledge_embedding_config(
    request: Request, form_data: EmbeddingModelUpdateForm, knowledge_base
) -> dict:
    """Apply *form_data* to a knowledge base's own rag_config and persist it."""
    state = request.app.state
    config = state.config
    rag_config = knowledge_base.rag_config
    old_engine = rag_config.get("embedding_engine")
    old_model = rag_config.get("embedding_model")

    log.info(f"Updating embedding model: {old_model} to {form_data.embedding_model}")

    # Free the previous model unless the global default or another knowledge
    # base still uses it. The old guard compared a cached model object with a
    # model name, so it never protected the global default model.
    if old_model and old_model != config.RAG_EMBEDDING_MODEL:
        in_use = Knowledges.is_model_in_use_elsewhere(
            model=old_model, model_type="embedding_model", id=form_data.knowledge_id
        )
        if not in_use:
            _release_embedding_model(request, old_engine, old_model)

    rag_config["embedding_engine"] = form_data.embedding_engine
    rag_config["embedding_model"] = form_data.embedding_model
    rag_config["embedding_batch_size"] = form_data.embedding_batch_size

    if form_data.openai_config is not None:
        rag_config["openai_config"] = {
            "url": form_data.openai_config.url,
            "key": form_data.openai_config.key,
        }
    if form_data.ollama_config is not None:
        rag_config["ollama_config"] = {
            "url": form_data.ollama_config.url,
            "key": form_data.ollama_config.key,
        }
    if form_data.azure_openai_config is not None:
        rag_config["azure_openai_config"] = {
            "url": form_data.azure_openai_config.url,
            "key": form_data.azure_openai_config.key,
            "version": form_data.azure_openai_config.version,
        }

    engine = rag_config["embedding_engine"]
    model = rag_config["embedding_model"]

    # Connection settings missing from the stored config fall back to the
    # global ones (as get_embedding_config reports them). Previously a missing
    # entry raised KeyError, and azure_openai was given the ollama url/key.
    global_openai, global_ollama, global_azure = _global_embedding_connections(config)
    url, key, azure_api_version = _select_embedding_connection(
        engine,
        rag_config.get("openai_config") or global_openai,
        rag_config.get("ollama_config") or global_ollama,
        rag_config.get("azure_openai_config") or global_azure,
    )

    if model not in state.ef:
        state.ef[model] = get_ef(engine, model)

    state.EMBEDDING_FUNCTION[model] = get_embedding_function(
        engine,
        model,
        state.ef[model],
        url,
        key,
        rag_config["embedding_batch_size"],
        azure_api_version=azure_api_version,
    )

    _register_embedding_model(request, engine, model, azure_api_version)
    rag_config["DOWNLOADED_EMBEDDING_MODELS"] = config.DOWNLOADED_EMBEDDING_MODELS
    rag_config["LOADED_EMBEDDING_MODELS"] = config.LOADED_EMBEDDING_MODELS

    Knowledges.update_rag_config_by_id(id=form_data.knowledge_id, rag_config=rag_config)

    return {
        "status": True,
        "embedding_engine": rag_config["embedding_engine"],
        "embedding_model": rag_config["embedding_model"],
        "embedding_batch_size": rag_config["embedding_batch_size"],
        "openai_config": rag_config.get("openai_config", {}),
        "ollama_config": rag_config.get("ollama_config", {}),
        "azure_openai_config": rag_config.get("azure_openai_config", {}),
        "DOWNLOADED_EMBEDDING_MODELS": rag_config["DOWNLOADED_EMBEDDING_MODELS"],
        "LOADED_EMBEDDING_MODELS": rag_config["LOADED_EMBEDDING_MODELS"],
    }


def _update_global_embedding_config(request: Request, form_data: EmbeddingModelUpdateForm) -> dict:
    """Apply *form_data* to the global (app-wide) embedding configuration."""
    state = request.app.state
    config = state.config

    log.info(
        f"Updating embedding model: {config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )

    # Free the current global model unless a knowledge base still uses it.
    in_use = Knowledges.is_model_in_use_elsewhere(
        model=config.RAG_EMBEDDING_MODEL, model_type="embedding_model"
    )
    if not in_use:
        _release_embedding_model(request, config.RAG_EMBEDDING_ENGINE, config.RAG_EMBEDDING_MODEL)

    config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
    config.RAG_EMBEDDING_MODEL = form_data.embedding_model

    if config.RAG_EMBEDDING_ENGINE in ["ollama", "openai", "azure_openai"]:
        if form_data.openai_config is not None:
            config.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url
            config.RAG_OPENAI_API_KEY = form_data.openai_config.key
        if form_data.ollama_config is not None:
            config.RAG_OLLAMA_BASE_URL = form_data.ollama_config.url
            config.RAG_OLLAMA_API_KEY = form_data.ollama_config.key
        if form_data.azure_openai_config is not None:
            config.RAG_AZURE_OPENAI_BASE_URL = form_data.azure_openai_config.url
            config.RAG_AZURE_OPENAI_API_KEY = form_data.azure_openai_config.key
            config.RAG_AZURE_OPENAI_API_VERSION = form_data.azure_openai_config.version

    config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size

    engine = config.RAG_EMBEDDING_ENGINE
    model = config.RAG_EMBEDDING_MODEL
    openai_config, ollama_config, azure_openai_config = _global_embedding_connections(config)
    url, key, azure_api_version = _select_embedding_connection(
        engine, openai_config, ollama_config, azure_openai_config
    )

    if model not in state.ef:
        state.ef[model] = get_ef(engine, model)

    state.EMBEDDING_FUNCTION[model] = get_embedding_function(
        engine,
        model,
        state.ef[model],
        url,
        key,
        config.RAG_EMBEDDING_BATCH_SIZE,
        azure_api_version=azure_api_version,
    )

    _register_embedding_model(request, engine, model, azure_api_version)

    return {
        "status": True,
        "embedding_engine": config.RAG_EMBEDDING_ENGINE,
        "embedding_model": config.RAG_EMBEDDING_MODEL,
        "embedding_batch_size": config.RAG_EMBEDDING_BATCH_SIZE,
        "openai_config": openai_config,
        "ollama_config": ollama_config,
        "azure_openai_config": azure_openai_config,
        "LOADED_EMBEDDING_MODELS": config.LOADED_EMBEDDING_MODELS,
        "DOWNLOADED_EMBEDDING_MODELS": config.DOWNLOADED_EMBEDDING_MODELS,
    }


@router.post("/embedding/update")
async def update_embedding_config(
    request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_verified_user)
):
    """Update the embedding model configuration.

    If ``form_data.knowledge_id`` names a knowledge base whose rag_config has
    ``DEFAULT_RAG_SETTINGS`` set to False, that knowledge base's own config is
    updated and persisted. Otherwise the global configuration is updated. That
    path is restricted to admins because it changes settings shared by every
    user.

    Returns:
        The resulting embedding settings, plus the loaded/downloaded model lists.

    Raises:
        HTTPException: 403 when a non-admin attempts a global update; 500 for
            any other failure while updating.
    """
    try:
        knowledge_base = (
            Knowledges.get_knowledge_by_id(form_data.knowledge_id)
            if form_data.knowledge_id
            else None
        )
        if knowledge_base and not (knowledge_base.rag_config or {}).get(
            "DEFAULT_RAG_SETTINGS", True
        ):
            # NOTE(review): there is no ownership/write-access check on the
            # knowledge base — any verified user can update any one by id.
            return _update_knowledge_embedding_config(request, form_data, knowledge_base)

        if getattr(user, "role", None) != "admin":
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
            )
        return _update_global_embedding_config(request, form_data)
    except HTTPException:
        # Deliberate HTTP errors pass through instead of becoming a 500.
        raise
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
|
|
|
|
|
|
# Top-level fields of the get_rag_config response, in response order. Each is
# read from the knowledge base's stored config under the same name (or its
# alias below), falling back to the same-named attribute of the global config.
_RAG_CONFIG_RESPONSE_FIELDS = (
    "RAG_TEMPLATE",
    "TOP_K",
    "BYPASS_EMBEDDING_AND_RETRIEVAL",
    "RAG_FULL_CONTEXT",
    # Hybrid search settings
    "ENABLE_RAG_HYBRID_SEARCH",
    "TOP_K_RERANKER",
    "RELEVANCE_THRESHOLD",
    "HYBRID_BM25_WEIGHT",
    # Content extraction settings
    "CONTENT_EXTRACTION_ENGINE",
    "PDF_EXTRACT_IMAGES",
    "DATALAB_MARKER_API_KEY",
    "DATALAB_MARKER_LANGS",
    "DATALAB_MARKER_SKIP_CACHE",
    "DATALAB_MARKER_FORCE_OCR",
    "DATALAB_MARKER_PAGINATE",
    "DATALAB_MARKER_STRIP_EXISTING_OCR",
    "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION",
    "DATALAB_MARKER_USE_LLM",
    "DATALAB_MARKER_OUTPUT_FORMAT",
    "EXTERNAL_DOCUMENT_LOADER_URL",
    "EXTERNAL_DOCUMENT_LOADER_API_KEY",
    "TIKA_SERVER_URL",
    "DOCLING_SERVER_URL",
    "DOCLING_OCR_ENGINE",
    "DOCLING_OCR_LANG",
    "DOCLING_DO_PICTURE_DESCRIPTION",
    "DOCLING_PICTURE_DESCRIPTION_MODE",
    "DOCLING_PICTURE_DESCRIPTION_LOCAL",
    "DOCLING_PICTURE_DESCRIPTION_API",
    "DOCUMENT_INTELLIGENCE_ENDPOINT",
    "DOCUMENT_INTELLIGENCE_KEY",
    "MISTRAL_OCR_API_KEY",
    # Reranking settings
    "RAG_RERANKING_MODEL",
    "RAG_RERANKING_ENGINE",
    "RAG_EXTERNAL_RERANKER_URL",
    "RAG_EXTERNAL_RERANKER_API_KEY",
    # Chunking settings
    "TEXT_SPLITTER",
    "CHUNK_SIZE",
    "CHUNK_OVERLAP",
    # File upload settings
    "FILE_MAX_SIZE",
    "FILE_MAX_COUNT",
    "FILE_IMAGE_COMPRESSION_WIDTH",
    "FILE_IMAGE_COMPRESSION_HEIGHT",
    "ALLOWED_FILE_EXTENSIONS",
    # Integration settings
    "ENABLE_GOOGLE_DRIVE_INTEGRATION",
    "ENABLE_ONEDRIVE_INTEGRATION",
)

# Fields of the nested "web" section, in response order (YOUTUBE_LOADER_TRANSLATION
# is appended separately because its default lives on app.state, not app.state.config).
_RAG_WEB_CONFIG_RESPONSE_FIELDS = (
    "ENABLE_WEB_SEARCH",
    "WEB_SEARCH_ENGINE",
    "WEB_SEARCH_TRUST_ENV",
    "WEB_SEARCH_RESULT_COUNT",
    "WEB_SEARCH_CONCURRENT_REQUESTS",
    "WEB_SEARCH_DOMAIN_FILTER_LIST",
    "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL",
    "BYPASS_WEB_SEARCH_WEB_LOADER",
    "SEARXNG_QUERY_URL",
    "YACY_QUERY_URL",
    "YACY_USERNAME",
    "YACY_PASSWORD",
    "GOOGLE_PSE_API_KEY",
    "GOOGLE_PSE_ENGINE_ID",
    "BRAVE_SEARCH_API_KEY",
    "KAGI_SEARCH_API_KEY",
    "MOJEEK_SEARCH_API_KEY",
    "BOCHA_SEARCH_API_KEY",
    "SERPSTACK_API_KEY",
    "SERPSTACK_HTTPS",
    "SERPER_API_KEY",
    "SERPLY_API_KEY",
    "TAVILY_API_KEY",
    "SEARCHAPI_API_KEY",
    "SEARCHAPI_ENGINE",
    "SERPAPI_API_KEY",
    "SERPAPI_ENGINE",
    "JINA_API_KEY",
    "BING_SEARCH_V7_ENDPOINT",
    "BING_SEARCH_V7_SUBSCRIPTION_KEY",
    "EXA_API_KEY",
    "PERPLEXITY_API_KEY",
    "PERPLEXITY_MODEL",
    "PERPLEXITY_SEARCH_CONTEXT_USAGE",
    "SOUGOU_API_SID",
    "SOUGOU_API_SK",
    "WEB_LOADER_ENGINE",
    "ENABLE_WEB_LOADER_SSL_VERIFICATION",
    "PLAYWRIGHT_WS_URL",
    "PLAYWRIGHT_TIMEOUT",
    "FIRECRAWL_API_KEY",
    "FIRECRAWL_API_BASE_URL",
    "TAVILY_EXTRACT_DEPTH",
    "EXTERNAL_WEB_SEARCH_URL",
    "EXTERNAL_WEB_SEARCH_API_KEY",
    "EXTERNAL_WEB_LOADER_URL",
    "EXTERNAL_WEB_LOADER_API_KEY",
    "YOUTUBE_LOADER_LANGUAGE",
    "YOUTUBE_LOADER_PROXY_URL",
)

# Response field -> key under which the stored (per-knowledge-base) config keeps it,
# for the fields whose stored key differs from the response name.
# NOTE(review): "enable_onedrive_integration" is lowercase unlike its neighbours —
# confirm it matches what the config writer stores.
_RAG_CONFIG_STORED_KEY_ALIASES = {
    "RAG_TEMPLATE": "TEMPLATE",
    "ENABLE_ONEDRIVE_INTEGRATION": "enable_onedrive_integration",
    "YACY_USERNAME": "YACY_QUERY_USERNAME",
    "YACY_PASSWORD": "YACY_QUERY_PASSWORD",
    "EXTERNAL_WEB_SEARCH_URL": "WEB_SEARCH_URL",
    "EXTERNAL_WEB_SEARCH_API_KEY": "WEB_SEARCH_KEY",
    "EXTERNAL_WEB_LOADER_URL": "WEB_LOADER_URL",
    "EXTERNAL_WEB_LOADER_API_KEY": "WEB_LOADER_KEY",
}

# Global model lists appended verbatim at the end of the response.
_RAG_CONFIG_MODEL_LISTS = (
    "DOWNLOADED_EMBEDDING_MODELS",
    "DOWNLOADED_RERANKING_MODELS",
    "LOADED_EMBEDDING_MODELS",
    "LOADED_RERANKING_MODELS",
)


def _layered_rag_config_values(stored, names, config) -> dict:
    """Map each name to its stored value (via alias), defaulting to ``getattr(config, name)``."""
    return {
        name: stored.get(_RAG_CONFIG_STORED_KEY_ALIASES.get(name, name), getattr(config, name))
        for name in names
    }


@router.post("/config")
async def get_rag_config(request: Request, collectionForm: CollectionForm, user=Depends(get_verified_user)):
    """Return the full RAG configuration for a knowledge base.

    A knowledge base whose ``rag_config`` sets ``DEFAULT_RAG_SETTINGS`` to
    False contributes its stored values (web settings come from its "web"
    sub-dict). Every value it does not define falls back to the global app
    configuration. The global downloaded/loaded model lists are always
    included.

    NOTE(review): the response contains API keys and is served to any verified
    user — confirm this exposure is intended.
    """
    state = request.app.state
    config = state.config

    knowledge_base = Knowledges.get_knowledge_by_id(collectionForm.knowledge_id)
    use_own_settings = bool(knowledge_base) and not knowledge_base.rag_config.get(
        "DEFAULT_RAG_SETTINGS", True
    )
    rag_config = knowledge_base.rag_config if use_own_settings else {}
    web_config = rag_config.get("web", {}) if use_own_settings else {}

    response = {"status": True}
    response.update(_layered_rag_config_values(rag_config, _RAG_CONFIG_RESPONSE_FIELDS, config))

    web_section = _layered_rag_config_values(web_config, _RAG_WEB_CONFIG_RESPONSE_FIELDS, config)
    web_section["YOUTUBE_LOADER_TRANSLATION"] = web_config.get(
        "YOUTUBE_LOADER_TRANSLATION", state.YOUTUBE_LOADER_TRANSLATION
    )
    response["web"] = web_section

    response["DEFAULT_RAG_SETTINGS"] = rag_config.get(
        "DEFAULT_RAG_SETTINGS", config.DEFAULT_RAG_SETTINGS
    )
    for list_name in _RAG_CONFIG_MODEL_LISTS:
        response[list_name] = getattr(config, list_name)

    return response
|
|
|
|
|
|
class WebConfig(BaseModel):
    """Web search and web loader settings carried in ``ConfigForm.web``.

    Every field is optional and defaults to ``None``, except
    ``WEB_SEARCH_DOMAIN_FILTER_LIST``, which defaults to an empty list.
    """

    # General web search settings
    ENABLE_WEB_SEARCH: Optional[bool] = None
    WEB_SEARCH_ENGINE: Optional[str] = None
    WEB_SEARCH_TRUST_ENV: Optional[bool] = None
    WEB_SEARCH_RESULT_COUNT: Optional[int] = None
    WEB_SEARCH_CONCURRENT_REQUESTS: Optional[int] = None
    WEB_SEARCH_DOMAIN_FILTER_LIST: Optional[List[str]] = []
    BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
    BYPASS_WEB_SEARCH_WEB_LOADER: Optional[bool] = None

    # Search provider endpoints and credentials
    SEARXNG_QUERY_URL: Optional[str] = None
    YACY_QUERY_URL: Optional[str] = None
    YACY_USERNAME: Optional[str] = None
    YACY_PASSWORD: Optional[str] = None
    GOOGLE_PSE_API_KEY: Optional[str] = None
    GOOGLE_PSE_ENGINE_ID: Optional[str] = None
    BRAVE_SEARCH_API_KEY: Optional[str] = None
    KAGI_SEARCH_API_KEY: Optional[str] = None
    MOJEEK_SEARCH_API_KEY: Optional[str] = None
    BOCHA_SEARCH_API_KEY: Optional[str] = None
    SERPSTACK_API_KEY: Optional[str] = None
    SERPSTACK_HTTPS: Optional[bool] = None
    SERPER_API_KEY: Optional[str] = None
    SERPLY_API_KEY: Optional[str] = None
    TAVILY_API_KEY: Optional[str] = None
    SEARCHAPI_API_KEY: Optional[str] = None
    SEARCHAPI_ENGINE: Optional[str] = None
    SERPAPI_API_KEY: Optional[str] = None
    SERPAPI_ENGINE: Optional[str] = None
    JINA_API_KEY: Optional[str] = None
    BING_SEARCH_V7_ENDPOINT: Optional[str] = None
    BING_SEARCH_V7_SUBSCRIPTION_KEY: Optional[str] = None
    EXA_API_KEY: Optional[str] = None
    PERPLEXITY_API_KEY: Optional[str] = None
    PERPLEXITY_MODEL: Optional[str] = None
    PERPLEXITY_SEARCH_CONTEXT_USAGE: Optional[str] = None
    SOUGOU_API_SID: Optional[str] = None
    SOUGOU_API_SK: Optional[str] = None

    # Web loader settings
    WEB_LOADER_ENGINE: Optional[str] = None
    ENABLE_WEB_LOADER_SSL_VERIFICATION: Optional[bool] = None
    PLAYWRIGHT_WS_URL: Optional[str] = None
    PLAYWRIGHT_TIMEOUT: Optional[int] = None
    FIRECRAWL_API_KEY: Optional[str] = None
    FIRECRAWL_API_BASE_URL: Optional[str] = None
    TAVILY_EXTRACT_DEPTH: Optional[str] = None

    # External web search / web loader services
    EXTERNAL_WEB_SEARCH_URL: Optional[str] = None
    EXTERNAL_WEB_SEARCH_API_KEY: Optional[str] = None
    EXTERNAL_WEB_LOADER_URL: Optional[str] = None
    EXTERNAL_WEB_LOADER_API_KEY: Optional[str] = None

    # YouTube loader settings
    YOUTUBE_LOADER_LANGUAGE: Optional[List[str]] = None
    YOUTUBE_LOADER_PROXY_URL: Optional[str] = None
    YOUTUBE_LOADER_TRANSLATION: Optional[str] = None
|
|
|
|
|
|
class ConfigForm(BaseModel):
    """Request body accepted by the ``/config/update`` endpoint.

    All fields default to ``None``. When ``knowledge_id`` names a knowledge
    base whose stored config does not use the default RAG settings, the
    endpoint updates that knowledge base's config instead of the global one.
    """

    # --- RAG settings ---
    RAG_TEMPLATE: Optional[str] = None
    TOP_K: Optional[int] = None
    BYPASS_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
    RAG_FULL_CONTEXT: Optional[bool] = None

    # --- Hybrid search settings ---
    ENABLE_RAG_HYBRID_SEARCH: Optional[bool] = None
    TOP_K_RERANKER: Optional[int] = None
    RELEVANCE_THRESHOLD: Optional[float] = None
    HYBRID_BM25_WEIGHT: Optional[float] = None

    # --- Content extraction settings ---
    CONTENT_EXTRACTION_ENGINE: Optional[str] = None
    PDF_EXTRACT_IMAGES: Optional[bool] = None
    DATALAB_MARKER_API_KEY: Optional[str] = None
    DATALAB_MARKER_LANGS: Optional[str] = None
    DATALAB_MARKER_SKIP_CACHE: Optional[bool] = None
    DATALAB_MARKER_FORCE_OCR: Optional[bool] = None
    DATALAB_MARKER_PAGINATE: Optional[bool] = None
    DATALAB_MARKER_STRIP_EXISTING_OCR: Optional[bool] = None
    DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION: Optional[bool] = None
    DATALAB_MARKER_USE_LLM: Optional[bool] = None
    DATALAB_MARKER_OUTPUT_FORMAT: Optional[str] = None
    EXTERNAL_DOCUMENT_LOADER_URL: Optional[str] = None
    EXTERNAL_DOCUMENT_LOADER_API_KEY: Optional[str] = None
    TIKA_SERVER_URL: Optional[str] = None
    DOCLING_SERVER_URL: Optional[str] = None
    DOCLING_OCR_ENGINE: Optional[str] = None
    DOCLING_OCR_LANG: Optional[str] = None
    DOCLING_DO_PICTURE_DESCRIPTION: Optional[bool] = None
    DOCLING_PICTURE_DESCRIPTION_MODE: Optional[str] = None
    DOCLING_PICTURE_DESCRIPTION_LOCAL: Optional[dict] = None
    DOCLING_PICTURE_DESCRIPTION_API: Optional[dict] = None
    DOCUMENT_INTELLIGENCE_ENDPOINT: Optional[str] = None
    DOCUMENT_INTELLIGENCE_KEY: Optional[str] = None
    MISTRAL_OCR_API_KEY: Optional[str] = None

    # --- Reranking settings ---
    RAG_RERANKING_MODEL: Optional[str] = None
    RAG_RERANKING_ENGINE: Optional[str] = None
    RAG_EXTERNAL_RERANKER_URL: Optional[str] = None
    RAG_EXTERNAL_RERANKER_API_KEY: Optional[str] = None

    # --- Chunking settings ---
    TEXT_SPLITTER: Optional[str] = None
    CHUNK_SIZE: Optional[int] = None
    CHUNK_OVERLAP: Optional[int] = None

    # --- File upload settings ---
    FILE_MAX_SIZE: Optional[int] = None
    FILE_MAX_COUNT: Optional[int] = None
    FILE_IMAGE_COMPRESSION_WIDTH: Optional[int] = None
    FILE_IMAGE_COMPRESSION_HEIGHT: Optional[int] = None
    ALLOWED_FILE_EXTENSIONS: Optional[List[str]] = None

    # --- Integration settings ---
    ENABLE_GOOGLE_DRIVE_INTEGRATION: Optional[bool] = None
    ENABLE_ONEDRIVE_INTEGRATION: Optional[bool] = None

    # --- Web search / web loader settings (nested model) ---
    web: Optional[WebConfig] = None

    # --- Target knowledge base (None means the global configuration) ---
    knowledge_id: Optional[str] = None
|
|
|
|
@router.post("/config/update")
async def update_rag_config(
    request: Request, form_data: ConfigForm, user=Depends(get_verified_user)
):
    """
    Update the RAG configuration.

    If ``form_data.knowledge_id`` names an existing knowledge base whose stored
    ``rag_config`` has ``DEFAULT_RAG_SETTINGS`` set to False, only that knowledge
    base's configuration is updated: the fields the client explicitly set are
    merged into it, it is persisted, and the merged dict is returned.

    Otherwise the global configuration on ``request.app.state.config`` is
    updated and a snapshot of the resulting global settings is returned.

    In both paths the current reranking model may be unloaded (when hybrid
    search ends up disabled and the model is not in use elsewhere), and the
    configured reranking model is loaded into ``request.app.state.rf`` if it is
    not loaded yet.

    Raises:
        HTTPException: 500 if storing the global reranking model fails.
    """
    knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)

    if knowledge_base and not knowledge_base.rag_config.get(
        "DEFAULT_RAG_SETTINGS", True
    ):
        return _update_knowledge_rag_config(request, form_data, knowledge_base)

    return _update_global_rag_config(request, form_data)


def _update_knowledge_rag_config(
    request: Request, form_data: ConfigForm, knowledge_base
) -> dict:
    """Merge the explicitly-set fields of *form_data* into a knowledge base's rag_config and persist it."""
    config = request.app.state.config
    rag_config = knowledge_base.rag_config

    # Values active for this knowledge base *before* the update is applied.
    previous_model = rag_config.get("RAG_RERANKING_MODEL")
    previous_engine = rag_config.get(
        "RAG_RERANKING_ENGINE", config.RAG_RERANKING_ENGINE
    )

    # Free up memory if hybrid search will be disabled for this knowledge base
    # and its reranking model is not in use elsewhere. An omitted
    # ENABLE_RAG_HYBRID_SEARCH keeps the knowledge base's current setting
    # instead of being treated as "disabled".
    hybrid_enabled = (
        form_data.ENABLE_RAG_HYBRID_SEARCH
        if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
        else rag_config.get("ENABLE_RAG_HYBRID_SEARCH")
    )
    in_use = Knowledges.is_model_in_use_elsewhere(
        model=previous_model,
        model_type="RAG_RERANKING_MODEL",
        id=form_data.knowledge_id,
    )
    if (
        not hybrid_enabled
        and not in_use
        and previous_model
        and request.app.state.rf.get(previous_model)
    ):
        # The model was registered under this knowledge base's engine, so it
        # must be removed from that engine's list (not the global engine's).
        _unload_reranking_model(request, previous_model, previous_engine)

    # Update only the provided fields in the rag_config; "web" is merged
    # key-by-key so a partial web payload keeps the other stored web settings.
    for field, value in form_data.model_dump(exclude_unset=True).items():
        if field == "web" and value is not None:
            rag_config["web"] = {**(rag_config.get("web") or {}), **value}
        else:
            rag_config[field] = value

    log.info(
        f"Updating reranking model: {previous_model} to {form_data.RAG_RERANKING_MODEL}"
    )
    try:
        if _load_reranking_model(
            request,
            rag_config.get("RAG_RERANKING_ENGINE"),
            rag_config.get("RAG_RERANKING_MODEL"),
            rag_config.get("RAG_EXTERNAL_RERANKER_URL"),
            rag_config.get("RAG_EXTERNAL_RERANKER_API_KEY"),
        ):
            rag_config["LOADED_RERANKING_MODELS"] = config.LOADED_RERANKING_MODELS
            rag_config["DOWNLOADED_RERANKING_MODELS"] = (
                config.DOWNLOADED_RERANKING_MODELS
            )
    except Exception as e:
        log.error(f"Error loading reranking model: {e}")
        # Only this knowledge base loses hybrid search; the global setting is untouched.
        rag_config["ENABLE_RAG_HYBRID_SEARCH"] = False

    Knowledges.update_rag_config_by_id(id=knowledge_base.id, rag_config=rag_config)

    return rag_config


def _update_global_rag_config(request: Request, form_data: ConfigForm) -> dict:
    """Apply *form_data* to request.app.state.config and return the resulting global settings."""
    config = request.app.state.config

    # RAG settings and the hybrid search toggle.
    _apply_if_set(config, form_data, (*_RAG_FIELDS, "ENABLE_RAG_HYBRID_SEARCH"))

    # Free up memory if hybrid search is disabled and the model is not in use
    # elsewhere. This deliberately runs before the reranking engine is updated,
    # so it uses the model/engine pair that was active before this request.
    in_use = Knowledges.is_model_in_use_elsewhere(
        model=config.RAG_RERANKING_MODEL, model_type="RAG_RERANKING_MODEL"
    )
    if (
        not config.ENABLE_RAG_HYBRID_SEARCH
        and not in_use
        and request.app.state.rf.get(config.RAG_RERANKING_MODEL)
    ):
        _unload_reranking_model(
            request, config.RAG_RERANKING_MODEL, config.RAG_RERANKING_ENGINE
        )

    # Remaining hybrid search, content extraction and reranker connection settings.
    _apply_if_set(
        config,
        form_data,
        (
            "TOP_K_RERANKER",
            "RELEVANCE_THRESHOLD",
            "HYBRID_BM25_WEIGHT",
            *_CONTENT_EXTRACTION_FIELDS,
            "RAG_RERANKING_ENGINE",
            "RAG_EXTERNAL_RERANKER_URL",
            "RAG_EXTERNAL_RERANKER_API_KEY",
        ),
    )

    log.info(
        f"Updating reranking model: {config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}"
    )
    try:
        # Stored verbatim, including None.
        config.RAG_RERANKING_MODEL = form_data.RAG_RERANKING_MODEL

        try:
            _load_reranking_model(
                request,
                config.RAG_RERANKING_ENGINE,
                config.RAG_RERANKING_MODEL,
                config.RAG_EXTERNAL_RERANKER_URL,
                config.RAG_EXTERNAL_RERANKER_API_KEY,
            )
        except Exception as e:
            log.error(f"Error loading reranking model: {e}")
            config.ENABLE_RAG_HYBRID_SEARCH = False
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

    # Chunking settings.
    _apply_if_set(config, form_data, _CHUNKING_FIELDS)

    # File size/count/compression limits are copied verbatim: a None from the
    # client is stored as None rather than meaning "keep the current value".
    for field in _FILE_LIMIT_FIELDS:
        setattr(config, field, getattr(form_data, field))

    _apply_if_set(
        config, form_data, ("ALLOWED_FILE_EXTENSIONS", *_INTEGRATION_FIELDS)
    )

    if form_data.web is not None:
        # Web settings are replaced wholesale: fields omitted from the web
        # payload take the WebConfig defaults (None, or [] for the domain list).
        for field in _WEB_CONFIG_FIELDS:
            setattr(config, field, getattr(form_data.web, field))
        # NOTE(review): stored on app.state rather than app.state.config, unlike
        # the other web settings; kept as-is because the readers use app.state.
        request.app.state.YOUTUBE_LOADER_TRANSLATION = (
            form_data.web.YOUTUBE_LOADER_TRANSLATION
        )

    return {
        "status": True,
        **{field: getattr(config, field) for field in _GLOBAL_RESPONSE_FIELDS},
        "web": {
            **{field: getattr(config, field) for field in _WEB_CONFIG_FIELDS},
            "YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION,
        },
        "DEFAULT_RAG_SETTINGS": config.DEFAULT_RAG_SETTINGS,
    }


def _apply_if_set(config, source, fields) -> None:
    """Copy each named attribute of *source* onto *config*, keeping the current value when the source value is None."""
    for field in fields:
        value = getattr(source, field)
        # The current value is re-assigned rather than skipped so that every
        # listed field is still written back to *config*, exactly as the
        # previous `x if x is not None else config.x` assignments did.
        setattr(config, field, value if value is not None else getattr(config, field))


def _load_reranking_model(request: Request, engine, model, url, api_key) -> bool:
    """Load *model* into request.app.state.rf and record it in the persisted model lists.

    Returns True if the model was loaded by this call, and False if *model* is
    empty/None or already present in request.app.state.rf. Exceptions from
    get_rf or from the persisted lists (e.g. an unknown *engine* key) propagate
    to the caller.
    """
    rf = request.app.state.rf
    if not model or model in rf:
        return False

    config = request.app.state.config
    rf[model] = get_rf(engine, model, url, api_key, True)

    # add model to state for reloading on startup
    config.LOADED_RERANKING_MODELS[engine].append(
        {
            "RAG_RERANKING_MODEL": model,
            "RAG_EXTERNAL_RERANKER_URL": url,
            "RAG_EXTERNAL_RERANKER_API_KEY": api_key,
        }
    )
    config._state["LOADED_RERANKING_MODELS"].save()

    # add model to state for selectable reranking models
    downloaded = config.DOWNLOADED_RERANKING_MODELS[engine]
    if model not in downloaded:
        downloaded.append(model)
        config._state["DOWNLOADED_RERANKING_MODELS"].save()

    return True


def _unload_reranking_model(request: Request, model, engine) -> None:
    """Drop *model* from request.app.state.rf and from the persisted LOADED list for *engine*, then free memory."""
    request.app.state.rf.pop(model, None)

    config = request.app.state.config
    models_list = config.LOADED_RERANKING_MODELS.get(engine)
    if models_list is not None:
        # Filter in place so the persisted dict keeps the same list object.
        models_list[:] = [
            entry
            for entry in models_list
            if entry.get("RAG_RERANKING_MODEL") != model
        ]
    config._state["LOADED_RERANKING_MODELS"].save()

    import gc

    import torch

    gc.collect()
    torch.cuda.empty_cache()


# ConfigForm fields grouped the way the global update path applies and returns them.
_RAG_FIELDS = (
    "RAG_TEMPLATE",
    "TOP_K",
    "BYPASS_EMBEDDING_AND_RETRIEVAL",
    "RAG_FULL_CONTEXT",
)
_HYBRID_SEARCH_FIELDS = (
    "ENABLE_RAG_HYBRID_SEARCH",
    "TOP_K_RERANKER",
    "RELEVANCE_THRESHOLD",
    "HYBRID_BM25_WEIGHT",
)
_CONTENT_EXTRACTION_FIELDS = (
    "CONTENT_EXTRACTION_ENGINE",
    "PDF_EXTRACT_IMAGES",
    "DATALAB_MARKER_API_KEY",
    "DATALAB_MARKER_LANGS",
    "DATALAB_MARKER_SKIP_CACHE",
    "DATALAB_MARKER_FORCE_OCR",
    "DATALAB_MARKER_PAGINATE",
    "DATALAB_MARKER_STRIP_EXISTING_OCR",
    "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION",
    "DATALAB_MARKER_USE_LLM",
    "DATALAB_MARKER_OUTPUT_FORMAT",
    "EXTERNAL_DOCUMENT_LOADER_URL",
    "EXTERNAL_DOCUMENT_LOADER_API_KEY",
    "TIKA_SERVER_URL",
    "DOCLING_SERVER_URL",
    "DOCLING_OCR_ENGINE",
    "DOCLING_OCR_LANG",
    "DOCLING_DO_PICTURE_DESCRIPTION",
    "DOCLING_PICTURE_DESCRIPTION_MODE",
    "DOCLING_PICTURE_DESCRIPTION_LOCAL",
    "DOCLING_PICTURE_DESCRIPTION_API",
    "DOCUMENT_INTELLIGENCE_ENDPOINT",
    "DOCUMENT_INTELLIGENCE_KEY",
    "MISTRAL_OCR_API_KEY",
)
_RERANKING_FIELDS = (
    "RAG_RERANKING_MODEL",
    "RAG_RERANKING_ENGINE",
    "RAG_EXTERNAL_RERANKER_URL",
    "RAG_EXTERNAL_RERANKER_API_KEY",
)
_CHUNKING_FIELDS = (
    "TEXT_SPLITTER",
    "CHUNK_SIZE",
    "CHUNK_OVERLAP",
)
_FILE_LIMIT_FIELDS = (
    "FILE_MAX_SIZE",
    "FILE_MAX_COUNT",
    "FILE_IMAGE_COMPRESSION_WIDTH",
    "FILE_IMAGE_COMPRESSION_HEIGHT",
)
_INTEGRATION_FIELDS = (
    "ENABLE_GOOGLE_DRIVE_INTEGRATION",
    "ENABLE_ONEDRIVE_INTEGRATION",
)
# Top-level keys (besides "status", "web" and "DEFAULT_RAG_SETTINGS") of the
# global update response, in response order.
_GLOBAL_RESPONSE_FIELDS = (
    *_RAG_FIELDS,
    *_HYBRID_SEARCH_FIELDS,
    *_CONTENT_EXTRACTION_FIELDS,
    *_RERANKING_FIELDS,
    *_CHUNKING_FIELDS,
    *_FILE_LIMIT_FIELDS,
    "ALLOWED_FILE_EXTENSIONS",
    *_INTEGRATION_FIELDS,
)
# WebConfig fields mirrored onto request.app.state.config. YOUTUBE_LOADER_TRANSLATION
# is handled separately because it lives on request.app.state.
_WEB_CONFIG_FIELDS = (
    "ENABLE_WEB_SEARCH",
    "WEB_SEARCH_ENGINE",
    "WEB_SEARCH_TRUST_ENV",
    "WEB_SEARCH_RESULT_COUNT",
    "WEB_SEARCH_CONCURRENT_REQUESTS",
    "WEB_SEARCH_DOMAIN_FILTER_LIST",
    "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL",
    "BYPASS_WEB_SEARCH_WEB_LOADER",
    "SEARXNG_QUERY_URL",
    "YACY_QUERY_URL",
    "YACY_USERNAME",
    "YACY_PASSWORD",
    "GOOGLE_PSE_API_KEY",
    "GOOGLE_PSE_ENGINE_ID",
    "BRAVE_SEARCH_API_KEY",
    "KAGI_SEARCH_API_KEY",
    "MOJEEK_SEARCH_API_KEY",
    "BOCHA_SEARCH_API_KEY",
    "SERPSTACK_API_KEY",
    "SERPSTACK_HTTPS",
    "SERPER_API_KEY",
    "SERPLY_API_KEY",
    "TAVILY_API_KEY",
    "SEARCHAPI_API_KEY",
    "SEARCHAPI_ENGINE",
    "SERPAPI_API_KEY",
    "SERPAPI_ENGINE",
    "JINA_API_KEY",
    "BING_SEARCH_V7_ENDPOINT",
    "BING_SEARCH_V7_SUBSCRIPTION_KEY",
    "EXA_API_KEY",
    "PERPLEXITY_API_KEY",
    "PERPLEXITY_MODEL",
    "PERPLEXITY_SEARCH_CONTEXT_USAGE",
    "SOUGOU_API_SID",
    "SOUGOU_API_SK",
    "WEB_LOADER_ENGINE",
    "ENABLE_WEB_LOADER_SSL_VERIFICATION",
    "PLAYWRIGHT_WS_URL",
    "PLAYWRIGHT_TIMEOUT",
    "FIRECRAWL_API_KEY",
    "FIRECRAWL_API_BASE_URL",
    "TAVILY_EXTRACT_DEPTH",
    "EXTERNAL_WEB_SEARCH_URL",
    "EXTERNAL_WEB_SEARCH_API_KEY",
    "EXTERNAL_WEB_LOADER_URL",
    "EXTERNAL_WEB_LOADER_API_KEY",
    "YOUTUBE_LOADER_LANGUAGE",
    "YOUTUBE_LOADER_PROXY_URL",
)
|
|
|
|
|
|
####################################
|
|
#
|
|
# Document process and retrieval
|
|
#
|
|
####################################
|
|
|
|
|
|
def save_docs_to_vector_db(
    request: Request,
    docs,
    collection_name,
    metadata: Optional[dict] = None,
    overwrite: bool = False,
    split: bool = True,
    add: bool = False,
    user=None,
    knowledge_id: Optional[str] = None,
) -> bool:
    """Split, embed and store ``docs`` in the vector collection ``collection_name``.

    Text-splitter, chunking and embedding settings come from the knowledge
    base's ``rag_config`` when ``knowledge_id`` names an existing knowledge
    base whose ``DEFAULT_RAG_SETTINGS`` flag is false; otherwise the global
    ``request.app.state.config`` values are used.

    Args:
        request: Incoming request; provides ``app.state.config`` and the
            per-model embedding functions in ``app.state.ef``.
        docs: Documents to store.
        collection_name: Target vector DB collection.
        metadata: Extra metadata merged into every chunk. When it contains a
            ``hash`` key, the collection is first checked for chunks with the
            same hash.
        overwrite: Delete an existing collection before inserting.
        split: Split ``docs`` into chunks before embedding.
        add: Append to an existing collection instead of leaving it untouched.
        user: Forwarded to the embedding function.
        knowledge_id: Knowledge base whose per-KB RAG settings may apply.

    Returns:
        True on success. This includes the case where the collection already
        exists and neither ``overwrite`` nor ``add`` is set; nothing is written
        then.

    Raises:
        ValueError: Content with the same hash already exists, the text
            splitter is unknown, or there is no content after splitting.
        Exception: Any vector DB or embedding error is logged and re-raised.
    """

    def _get_docs_info(docs: list[Document]) -> str:
        # Collect one identifier per document (name, else title, else
        # source). Used for logging only.
        docs_info = set()
        for doc in docs:
            doc_metadata = getattr(doc, "metadata", {})
            doc_name = (
                doc_metadata.get("name", "")
                or doc_metadata.get("title", "")
                or doc_metadata.get("source", "")
            )
            if doc_name:
                docs_info.add(doc_name)
        return ", ".join(docs_info)

    log.info(
        f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}"
    )

    config = request.app.state.config

    # Knowledge-base-specific settings apply only when the knowledge base
    # exists and has opted out of the default RAG settings. A missing
    # knowledge base (or missing rag_config) falls back to the defaults
    # instead of crashing.
    rag_config = {}
    if knowledge_id:
        knowledge_base = Knowledges.get_knowledge_by_id(knowledge_id)
        kb_rag_config = (knowledge_base.rag_config or {}) if knowledge_base else {}
        if not kb_rag_config.get("DEFAULT_RAG_SETTINGS", True):
            rag_config = kb_rag_config

    # Use knowledge-base-specific or default configurations.
    text_splitter_type = rag_config.get("TEXT_SPLITTER", config.TEXT_SPLITTER)
    chunk_size = rag_config.get("CHUNK_SIZE", config.CHUNK_SIZE)
    chunk_overlap = rag_config.get("CHUNK_OVERLAP", config.CHUNK_OVERLAP)
    embedding_engine = rag_config.get("embedding_engine", config.RAG_EMBEDDING_ENGINE)
    embedding_model = rag_config.get("embedding_model", config.RAG_EMBEDDING_MODEL)
    embedding_batch_size = rag_config.get(
        "embedding_batch_size", config.RAG_EMBEDDING_BATCH_SIZE
    )

    openai_config = rag_config.get("openai_config", {})
    ollama_config = rag_config.get("ollama_config", {})
    azure_openai_config = rag_config.get("azure_openai", {})

    openai_api_base_url = openai_config.get("url", config.RAG_OPENAI_API_BASE_URL)
    # The API key lives under "key", not "url".
    openai_api_key = openai_config.get("key", config.RAG_OPENAI_API_KEY)
    ollama_base_url = ollama_config.get("url", config.RAG_OLLAMA_BASE_URL)
    ollama_api_key = ollama_config.get("key", config.RAG_OLLAMA_API_KEY)
    azure_openai_url = azure_openai_config.get("url", config.RAG_AZURE_OPENAI_BASE_URL)
    # Key and version fall back to their own global settings, not the base URL.
    azure_openai_key = azure_openai_config.get("key", config.RAG_AZURE_OPENAI_API_KEY)
    azure_openai_version = azure_openai_config.get(
        "version", config.RAG_AZURE_OPENAI_API_VERSION
    )

    # Check if entries with the same hash (metadata.hash) already exist.
    if metadata and "hash" in metadata:
        result = VECTOR_DB_CLIENT.query(
            collection_name=collection_name,
            filter={"hash": metadata["hash"]},
        )
        if result is not None and result.ids[0]:
            log.info(f"Document with hash {metadata['hash']} already exists")
            raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)

    if split:
        if text_splitter_type in ["", "character"]:
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                add_start_index=True,
            )
        elif text_splitter_type == "token":
            encoding_name = str(config.TIKTOKEN_ENCODING_NAME)
            log.info(f"Using token text splitter: {config.TIKTOKEN_ENCODING_NAME}")
            # The return value is discarded. Presumably this loads the
            # encoding up front so an invalid name fails before splitting.
            tiktoken.get_encoding(encoding_name)
            text_splitter = TokenTextSplitter(
                encoding_name=encoding_name,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                add_start_index=True,
            )
        else:
            raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter"))

        docs = text_splitter.split_documents(docs)

    if not docs:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    texts = [doc.page_content for doc in docs]
    # Every chunk records which engine/model produced its embedding.
    embedding_config = json.dumps(
        {
            "engine": embedding_engine,
            "model": embedding_model,
        }
    )
    metadatas = [
        {
            **doc.metadata,
            **(metadata if metadata else {}),
            "embedding_config": embedding_config,
        }
        for doc in docs
    ]

    # ChromaDB does not like datetime formats
    # for meta-data so convert them to string.
    for chunk_metadata in metadatas:
        for key, value in chunk_metadata.items():
            if isinstance(value, (datetime, list, dict)):
                chunk_metadata[key] = str(value)

    try:
        if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
            log.info(f"collection {collection_name} already exists")

            if overwrite:
                VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
                log.info(f"deleting existing collection {collection_name}")
            elif add is False:
                log.info(
                    f"collection {collection_name} already exists, overwrite is False and add is False"
                )
                return True

        log.info(f"adding to collection {collection_name}")

        # Pick the connection details that match the embedding engine.
        if embedding_engine == "openai":
            embedding_url, embedding_key = openai_api_base_url, openai_api_key
        elif embedding_engine == "ollama":
            embedding_url, embedding_key = ollama_base_url, ollama_api_key
        else:
            embedding_url, embedding_key = azure_openai_url, azure_openai_key

        embedding_function = get_embedding_function(
            embedding_engine,
            embedding_model,
            request.app.state.ef[embedding_model],
            embedding_url,
            embedding_key,
            embedding_batch_size,
            azure_api_version=(
                azure_openai_version if embedding_engine == "azure_openai" else None
            ),
        )

        embeddings = embedding_function(
            [text.replace("\n", " ") for text in texts],
            prefix=RAG_EMBEDDING_CONTENT_PREFIX,
            user=user,
        )

        items = [
            {
                "id": str(uuid.uuid4()),
                "text": text,
                "vector": embeddings[idx],
                "metadata": metadatas[idx],
            }
            for idx, text in enumerate(texts)
        ]

        VECTOR_DB_CLIENT.insert(
            collection_name=collection_name,
            items=items,
        )

        return True
    except Exception as e:
        log.exception(e)
        raise
|
|
|
|
|
|
class ProcessFileForm(BaseModel):
    """Request body for the ``/process/file`` endpoint."""

    # Id of the stored file to process.
    file_id: str
    # Replacement text for the file; when non-empty it is embedded instead of
    # the file's stored or extracted content.
    content: Union[str, None] = None
    # Target vector collection; ``file-{file_id}`` is used when omitted.
    collection_name: Union[str, None] = None
    # Knowledge base whose RAG settings may apply to processing.
    knowledge_id: Union[str, None] = None
|
|
|
|
|
|
def _build_file_loader(request: Request, rag_config: dict) -> Loader:
    """Build a document ``Loader`` from per-knowledge-base or global settings.

    Each setting is taken from ``rag_config`` when present, otherwise from
    the matching attribute of ``request.app.state.config``.
    """
    config = request.app.state.config

    def _setting(key: str, config_attr: Optional[str] = None):
        # Per-KB value if present; else the global attribute, which shares
        # the key's name unless ``config_attr`` says otherwise.
        return rag_config.get(key, getattr(config, config_attr or key))

    return Loader(
        engine=_setting("CONTENT_EXTRACTION_ENGINE"),
        DATALAB_MARKER_API_KEY=_setting("DATALAB_MARKER_API_KEY"),
        DATALAB_MARKER_LANGS=_setting("DATALAB_MARKER_LANGS"),
        DATALAB_MARKER_SKIP_CACHE=_setting("DATALAB_MARKER_SKIP_CACHE"),
        DATALAB_MARKER_FORCE_OCR=_setting("DATALAB_MARKER_FORCE_OCR"),
        DATALAB_MARKER_PAGINATE=_setting("DATALAB_MARKER_PAGINATE"),
        DATALAB_MARKER_STRIP_EXISTING_OCR=_setting(
            "DATALAB_MARKER_STRIP_EXISTING_OCR"
        ),
        DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION=_setting(
            "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION"
        ),
        DATALAB_MARKER_USE_LLM=_setting("DATALAB_MARKER_USE_LLM"),
        DATALAB_MARKER_OUTPUT_FORMAT=_setting("DATALAB_MARKER_OUTPUT_FORMAT"),
        EXTERNAL_DOCUMENT_LOADER_URL=_setting("EXTERNAL_DOCUMENT_LOADER_URL"),
        EXTERNAL_DOCUMENT_LOADER_API_KEY=_setting("EXTERNAL_DOCUMENT_LOADER_API_KEY"),
        TIKA_SERVER_URL=_setting("TIKA_SERVER_URL"),
        DOCLING_SERVER_URL=_setting("DOCLING_SERVER_URL"),
        DOCLING_PARAMS={
            "ocr_engine": _setting("DOCLING_OCR_ENGINE"),
            "ocr_lang": _setting("DOCLING_OCR_LANG"),
            "do_picture_description": _setting("DOCLING_DO_PICTURE_DESCRIPTION"),
            "picture_description_mode": _setting(
                "PICTURE_DESCRIPTION_MODE", "DOCLING_PICTURE_DESCRIPTION_MODE"
            ),
            # Reads its own key (it previously read PICTURE_DESCRIPTION_MODE).
            "picture_description_local": _setting(
                "PICTURE_DESCRIPTION_LOCAL", "DOCLING_PICTURE_DESCRIPTION_LOCAL"
            ),
            "picture_description_api": _setting(
                "PICTURE_DESCRIPTION_API", "DOCLING_PICTURE_DESCRIPTION_API"
            ),
        },
        PDF_EXTRACT_IMAGES=_setting("PDF_EXTRACT_IMAGES"),
        DOCUMENT_INTELLIGENCE_ENDPOINT=_setting("DOCUMENT_INTELLIGENCE_ENDPOINT"),
        DOCUMENT_INTELLIGENCE_KEY=_setting("DOCUMENT_INTELLIGENCE_KEY"),
        MISTRAL_OCR_API_KEY=_setting("MISTRAL_OCR_API_KEY"),
    )


@router.post("/process/file")
def process_file(
    request: Request,
    form_data: ProcessFileForm,
    user=Depends(get_verified_user),
):
    """Extract a file's text, store it on the file record and embed it.

    The content source is chosen in this order:

    * ``form_data.content``: replaces the file's content (content-update
      endpoint and audio upload pipeline). The file's own ``file-{id}``
      collection is deleted first, best effort.
    * ``form_data.collection_name``: the file is being added to a knowledge
      collection. Reuse the chunks already stored in ``file-{id}`` if there
      are any, otherwise use the file's stored ``content``.
    * otherwise: load the stored file with the configured loader, or fall
      back to the file's stored ``content`` when it has no path.

    The text's SHA-256 is saved as the file hash. Unless embedding is
    bypassed, the documents are saved to ``collection_name`` (default
    ``file-{id}``), which is then recorded in the file's metadata.
    A knowledge base's ``rag_config`` overrides global settings when its
    ``DEFAULT_RAG_SETTINGS`` flag is false.

    Raises:
        HTTPException: 400 for any failure. The detail is the underlying
            error, or PANDOC_NOT_INSTALLED when pandoc is missing.
    """
    try:
        file = Files.get_file_by_id(form_data.file_id)
        if file is None:
            raise ValueError(ERROR_MESSAGES.NOT_FOUND)

        collection_name = form_data.collection_name
        if collection_name is None:
            collection_name = f"file-{file.id}"

        rag_config = {}
        # Retrieve the knowledge base using the collection id:
        # knowledge_id == collection_name (minimal working solution without
        # logic changes). A collection that is not a knowledge base, or a
        # missing rag_config, falls back to the global settings.
        if form_data.collection_name:
            knowledge_base = Knowledges.get_knowledge_by_id(form_data.collection_name)
            kb_rag_config = (knowledge_base.rag_config or {}) if knowledge_base else {}
            if not kb_rag_config.get("DEFAULT_RAG_SETTINGS", True):
                rag_config = kb_rag_config
                # Fallback so save_docs_to_vector_db picks up the same settings.
                form_data.knowledge_id = collection_name
        elif form_data.knowledge_id:
            knowledge_base = Knowledges.get_knowledge_by_id(form_data.knowledge_id)
            kb_rag_config = (knowledge_base.rag_config or {}) if knowledge_base else {}
            if not kb_rag_config.get("DEFAULT_RAG_SETTINGS", True):
                rag_config = kb_rag_config

        # Metadata that identifies the file; added to every document built here.
        file_fields = {
            "name": file.filename,
            "created_by": file.user_id,
            "file_id": file.id,
            "source": file.filename,
        }

        def _stored_file_document(page_content: str) -> Document:
            # A single document carrying the file's own metadata.
            return Document(
                page_content=page_content,
                metadata={**file.meta, **file_fields},
            )

        if form_data.content:
            # Update the content in the file.
            # Usage: /files/{file_id}/data/content/update, /files/ (audio file upload pipeline)
            try:
                # /files/{file_id}/data/content/update
                VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
            except Exception as e:
                # Audio file upload pipeline: the collection may not exist yet.
                log.debug(f"could not delete collection file-{file.id}: {e}")

            docs = [_stored_file_document(form_data.content.replace("<br/>", "\n"))]
            text_content = form_data.content
        elif form_data.collection_name:
            # Check if the file has already been processed and save the content.
            # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update
            result = VECTOR_DB_CLIENT.query(
                collection_name=f"file-{file.id}", filter={"file_id": file.id}
            )

            if result is not None and len(result.ids[0]) > 0:
                docs = [
                    Document(page_content=page_content, metadata=chunk_metadata)
                    for page_content, chunk_metadata in zip(
                        result.documents[0], result.metadatas[0]
                    )
                ]
            else:
                docs = [_stored_file_document(file.data.get("content", ""))]

            text_content = file.data.get("content", "")
        else:
            # Process the file and save the content.
            # Usage: /files/
            file_path = file.path
            if file_path:
                file_path = Storage.get_file(file_path)
                loader = _build_file_loader(request, rag_config)
                loaded_docs = loader.load(
                    file.filename, file.meta.get("content_type"), file_path
                )
                docs = [
                    Document(
                        page_content=doc.page_content,
                        metadata={**doc.metadata, **file_fields},
                    )
                    for doc in loaded_docs
                ]
            else:
                docs = [_stored_file_document(file.data.get("content", ""))]
            text_content = " ".join([doc.page_content for doc in docs])

        log.debug(f"text_content: {text_content}")
        Files.update_file_data_by_id(
            file.id,
            {"content": text_content},
        )

        file_hash = calculate_sha256_string(text_content)
        Files.update_file_hash_by_id(file.id, file_hash)

        if rag_config.get(
            "BYPASS_EMBEDDING_AND_RETRIEVAL",
            request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL,
        ):
            return {
                "status": True,
                "collection_name": None,
                "filename": file.filename,
                "content": text_content,
            }

        result = save_docs_to_vector_db(
            request,
            docs=docs,
            collection_name=collection_name,
            metadata={
                "file_id": file.id,
                "name": file.filename,
                "hash": file_hash,
            },
            add=bool(form_data.collection_name),
            user=user,
            knowledge_id=form_data.knowledge_id,
        )

        if result:
            Files.update_file_metadata_by_id(
                file.id,
                {
                    "collection_name": collection_name,
                },
            )

            return {
                "status": True,
                "collection_name": collection_name,
                "filename": file.filename,
                "content": text_content,
            }
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=str(e),
            )
|
|
|
|
|
|
class ProcessTextForm(BaseModel):
    """Request body for the ``/process/text`` endpoint."""

    # Stored as the document's ``name`` metadata.
    name: str
    # Raw text to embed.
    content: str
    # Target collection; the SHA-256 of ``content`` is used when omitted.
    collection_name: Union[str, None] = None
|
|
|
|
|
|
@router.post("/process/text")
def process_text(
    request: Request,
    form_data: ProcessTextForm,
    user=Depends(get_verified_user)
):
    """Embed a piece of raw text into a vector collection.

    The collection defaults to the SHA-256 of the text when
    ``form_data.collection_name`` is not given. Errors raised by
    ``save_docs_to_vector_db`` propagate unchanged; a falsy save result
    becomes a 500.
    """
    text_content = form_data.content
    if form_data.collection_name is not None:
        collection_name = form_data.collection_name
    else:
        collection_name = calculate_sha256_string(text_content)

    document = Document(
        page_content=text_content,
        metadata={"name": form_data.name, "created_by": user.id},
    )
    log.debug(f"text_content: {text_content}")

    saved = save_docs_to_vector_db(request, [document], collection_name, user=user)
    if not saved:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )

    return {
        "status": True,
        "collection_name": collection_name,
        "content": text_content,
    }
|
|
|
|
|
|
@router.post("/process/youtube")
def process_youtube_video(
    request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user)
):
    """Load a YouTube video's transcript and store it in a vector collection.

    The collection defaults to the first 63 characters of the URL's SHA-256
    and is overwritten if it already exists. Any failure becomes a 400.
    """
    try:
        config = request.app.state.config
        collection_name = (
            form_data.collection_name or calculate_sha256_string(form_data.url)[:63]
        )

        docs = YoutubeLoader(
            form_data.url,
            language=config.YOUTUBE_LOADER_LANGUAGE,
            proxy_url=config.YOUTUBE_LOADER_PROXY_URL,
        ).load()
        content = " ".join(doc.page_content for doc in docs)
        log.debug(f"text_content: {content}")

        save_docs_to_vector_db(request, docs, collection_name, overwrite=True, user=user)

        file_info = {
            "data": {"content": content},
            "meta": {"name": form_data.url},
        }
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
            "file": file_info,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
|
|
|
|
|
|
@router.post("/process/web")
def process_web(
    request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user)
):
    """Fetch a web page and, unless bypassed, store it in a vector collection.

    The collection defaults to the first 63 characters of the URL's SHA-256.
    With ``BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL`` enabled nothing is
    stored and ``collection_name`` is returned as None. Any failure becomes
    a 400.
    """
    try:
        config = request.app.state.config
        collection_name = (
            form_data.collection_name or calculate_sha256_string(form_data.url)[:63]
        )

        docs = get_web_loader(
            form_data.url,
            verify_ssl=config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
            requests_per_second=config.WEB_SEARCH_CONCURRENT_REQUESTS,
        ).load()
        content = " ".join(doc.page_content for doc in docs)

        log.debug(f"text_content: {content}")

        if config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
            collection_name = None
        else:
            save_docs_to_vector_db(
                request, docs, collection_name, overwrite=True, user=user
            )

        file_info = {
            "data": {"content": content},
            "meta": {"name": form_data.url, "source": form_data.url},
        }
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
            "file": file_info,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
|
|
|
|
|
|
def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
    """Search the web with the provider named by ``engine``.

    ``engine`` selects the provider explicitly; nothing is auto-detected.
    Credentials and options are read from ``request.app.state.config``.
    Every provider receives ``query`` and ``WEB_SEARCH_RESULT_COUNT``. All
    except ``jina`` also receive ``WEB_SEARCH_DOMAIN_FILTER_LIST``.

    Engines and the settings they use. Settings marked * must be non-empty,
    otherwise an exception is raised:

    - ``searxng``: SEARXNG_QUERY_URL*
    - ``yacy``: YACY_QUERY_URL*, YACY_USERNAME, YACY_PASSWORD
    - ``google_pse``: GOOGLE_PSE_API_KEY*, GOOGLE_PSE_ENGINE_ID*
    - ``brave``: BRAVE_SEARCH_API_KEY*
    - ``kagi``: KAGI_SEARCH_API_KEY*
    - ``mojeek``: MOJEEK_SEARCH_API_KEY*
    - ``bocha``: BOCHA_SEARCH_API_KEY*
    - ``serpstack``: SERPSTACK_API_KEY*, SERPSTACK_HTTPS
    - ``serper``: SERPER_API_KEY*
    - ``serply``: SERPLY_API_KEY*
    - ``duckduckgo``: no credentials
    - ``tavily``: TAVILY_API_KEY*
    - ``exa``: EXA_API_KEY*
    - ``searchapi``: SEARCHAPI_API_KEY*, SEARCHAPI_ENGINE
    - ``serpapi``: SERPAPI_API_KEY*, SERPAPI_ENGINE
    - ``jina``: JINA_API_KEY
    - ``bing``: BING_SEARCH_V7_SUBSCRIPTION_KEY, BING_SEARCH_V7_ENDPOINT,
      plus the module-level DEFAULT_LOCALE
    - ``perplexity``: PERPLEXITY_API_KEY, PERPLEXITY_MODEL,
      PERPLEXITY_SEARCH_CONTEXT_USAGE
    - ``sougou``: SOUGOU_API_SID*, SOUGOU_API_SK*
    - ``firecrawl``: FIRECRAWL_API_BASE_URL, FIRECRAWL_API_KEY
    - ``external``: EXTERNAL_WEB_SEARCH_URL, EXTERNAL_WEB_SEARCH_API_KEY

    Args:
        request: Incoming request; provides ``app.state.config``.
        engine: Name of the search provider (one of the keys above).
        query: The query to search for.

    Returns:
        The selected provider's results.

    Raises:
        Exception: A required setting for ``engine`` is missing, or
            ``engine`` is not recognized.
    """
    cfg = request.app.state.config
    count = cfg.WEB_SEARCH_RESULT_COUNT
    domains = cfg.WEB_SEARCH_DOMAIN_FILTER_LIST

    # TODO: add playwright to search the web
    if engine == "searxng":
        if cfg.SEARXNG_QUERY_URL:
            return search_searxng(cfg.SEARXNG_QUERY_URL, query, count, domains)
        raise Exception("No SEARXNG_QUERY_URL found in environment variables")
    elif engine == "yacy":
        if cfg.YACY_QUERY_URL:
            return search_yacy(
                cfg.YACY_QUERY_URL,
                cfg.YACY_USERNAME,
                cfg.YACY_PASSWORD,
                query,
                count,
                domains,
            )
        raise Exception("No YACY_QUERY_URL found in environment variables")
    elif engine == "google_pse":
        if cfg.GOOGLE_PSE_API_KEY and cfg.GOOGLE_PSE_ENGINE_ID:
            return search_google_pse(
                cfg.GOOGLE_PSE_API_KEY,
                cfg.GOOGLE_PSE_ENGINE_ID,
                query,
                count,
                domains,
            )
        raise Exception(
            "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
        )
    elif engine == "brave":
        if cfg.BRAVE_SEARCH_API_KEY:
            return search_brave(cfg.BRAVE_SEARCH_API_KEY, query, count, domains)
        raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
    elif engine == "kagi":
        if cfg.KAGI_SEARCH_API_KEY:
            return search_kagi(cfg.KAGI_SEARCH_API_KEY, query, count, domains)
        raise Exception("No KAGI_SEARCH_API_KEY found in environment variables")
    elif engine == "mojeek":
        if cfg.MOJEEK_SEARCH_API_KEY:
            return search_mojeek(cfg.MOJEEK_SEARCH_API_KEY, query, count, domains)
        raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables")
    elif engine == "bocha":
        if cfg.BOCHA_SEARCH_API_KEY:
            return search_bocha(cfg.BOCHA_SEARCH_API_KEY, query, count, domains)
        raise Exception("No BOCHA_SEARCH_API_KEY found in environment variables")
    elif engine == "serpstack":
        if cfg.SERPSTACK_API_KEY:
            return search_serpstack(
                cfg.SERPSTACK_API_KEY,
                query,
                count,
                domains,
                https_enabled=cfg.SERPSTACK_HTTPS,
            )
        raise Exception("No SERPSTACK_API_KEY found in environment variables")
    elif engine == "serper":
        if cfg.SERPER_API_KEY:
            return search_serper(cfg.SERPER_API_KEY, query, count, domains)
        raise Exception("No SERPER_API_KEY found in environment variables")
    elif engine == "serply":
        if cfg.SERPLY_API_KEY:
            return search_serply(cfg.SERPLY_API_KEY, query, count, domains)
        raise Exception("No SERPLY_API_KEY found in environment variables")
    elif engine == "duckduckgo":
        return search_duckduckgo(query, count, domains)
    elif engine == "tavily":
        if cfg.TAVILY_API_KEY:
            return search_tavily(cfg.TAVILY_API_KEY, query, count, domains)
        raise Exception("No TAVILY_API_KEY found in environment variables")
    elif engine == "exa":
        if cfg.EXA_API_KEY:
            return search_exa(cfg.EXA_API_KEY, query, count, domains)
        raise Exception("No EXA_API_KEY found in environment variables")
    elif engine == "searchapi":
        if cfg.SEARCHAPI_API_KEY:
            return search_searchapi(
                cfg.SEARCHAPI_API_KEY,
                cfg.SEARCHAPI_ENGINE,
                query,
                count,
                domains,
            )
        raise Exception("No SEARCHAPI_API_KEY found in environment variables")
    elif engine == "serpapi":
        if cfg.SERPAPI_API_KEY:
            return search_serpapi(
                cfg.SERPAPI_API_KEY,
                cfg.SERPAPI_ENGINE,
                query,
                count,
                domains,
            )
        raise Exception("No SERPAPI_API_KEY found in environment variables")
    elif engine == "jina":
        return search_jina(cfg.JINA_API_KEY, query, count)
    elif engine == "bing":
        return search_bing(
            cfg.BING_SEARCH_V7_SUBSCRIPTION_KEY,
            cfg.BING_SEARCH_V7_ENDPOINT,
            str(DEFAULT_LOCALE),
            query,
            count,
            domains,
        )
    elif engine == "perplexity":
        return search_perplexity(
            cfg.PERPLEXITY_API_KEY,
            query,
            count,
            domains,
            model=cfg.PERPLEXITY_MODEL,
            search_context_usage=cfg.PERPLEXITY_SEARCH_CONTEXT_USAGE,
        )
    elif engine == "sougou":
        if cfg.SOUGOU_API_SID and cfg.SOUGOU_API_SK:
            return search_sougou(
                cfg.SOUGOU_API_SID,
                cfg.SOUGOU_API_SK,
                query,
                count,
                domains,
            )
        raise Exception(
            "No SOUGOU_API_SID or SOUGOU_API_SK found in environment variables"
        )
    elif engine == "firecrawl":
        return search_firecrawl(
            cfg.FIRECRAWL_API_BASE_URL,
            cfg.FIRECRAWL_API_KEY,
            query,
            count,
            domains,
        )
    elif engine == "external":
        return search_external(
            cfg.EXTERNAL_WEB_SEARCH_URL,
            cfg.EXTERNAL_WEB_SEARCH_API_KEY,
            query,
            count,
            domains,
        )
    else:
        # Unknown engine name.
        raise Exception("No search engine API key found in environment variables")
|
|
|
|
|
|
@router.post("/process/web/search")
async def process_web_search(
    request: Request, form_data: SearchForm, user=Depends(get_verified_user)
):
    """Search the web for each query, then return or store the results.

    Queries run concurrently in the threadpool via ``search_web``. Result
    links are de-duplicated in first-seen order. With
    ``BYPASS_WEB_SEARCH_WEB_LOADER`` the documents are the engines' snippets;
    otherwise the pages are fetched. With
    ``BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL`` the documents are returned
    inline. Otherwise they are saved, best effort, to one collection named
    after the hash of the joined queries.

    Raises:
        HTTPException: 400 when searching or loading fails.
    """
    config = request.app.state.config
    urls = []
    try:
        # Use the module logger (not the root logger) like the rest of the file.
        log.info(
            f"trying to web search with {config.WEB_SEARCH_ENGINE, form_data.queries}"
        )

        search_tasks = [
            run_in_threadpool(
                search_web,
                request,
                config.WEB_SEARCH_ENGINE,
                query,
            )
            for query in form_data.queries
        ]

        search_results = await asyncio.gather(*search_tasks)

        for result in search_results:
            for item in result or []:
                if item and item.link:
                    urls.append(item.link)

        # De-duplicate while preserving first-seen order.
        urls = list(dict.fromkeys(urls))
        log.debug(f"urls: {urls}")

    except Exception as e:
        log.exception(e)

        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
        )

    try:
        if config.BYPASS_WEB_SEARCH_WEB_LOADER:
            # The `if result` guard must come before `for item in result`;
            # otherwise a None result raises TypeError before being filtered.
            flat_results = [
                item for result in search_results if result for item in result
            ]

            # Skip results without a snippet: a None page_content is not a
            # valid Document.
            docs = [
                Document(
                    page_content=item.snippet,
                    metadata={
                        "source": item.link,
                        "title": item.title,
                        "snippet": item.snippet,
                        "link": item.link,
                    },
                )
                for item in flat_results
                if getattr(item, "snippet", None) is not None
            ]
        else:
            loader = get_web_loader(
                urls,
                verify_ssl=config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
                requests_per_second=config.WEB_SEARCH_CONCURRENT_REQUESTS,
                trust_env=config.WEB_SEARCH_TRUST_ENV,
            )
            docs = await loader.aload()

        urls = [
            doc.metadata.get("source") for doc in docs if doc.metadata.get("source")
        ]  # only keep the urls returned by the loader

        if config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
            return {
                "status": True,
                "collection_name": None,
                "filenames": urls,
                "docs": [
                    {
                        "content": doc.page_content,
                        "metadata": doc.metadata,
                    }
                    for doc in docs
                ],
                "loaded_count": len(docs),
            }

        # Create a single collection for all documents.
        collection_name = (
            f"web-search-{calculate_sha256_string('-'.join(form_data.queries))}"[:63]
        )

        try:
            await run_in_threadpool(
                save_docs_to_vector_db,
                request,
                docs,
                collection_name,
                overwrite=True,
                user=user,
            )
        except Exception as e:
            # Best effort: the search results are still returned. Logged at
            # warning so the failure is visible at default log levels.
            log.warning(f"error saving docs: {e}")

        return {
            "status": True,
            "collection_names": [collection_name],
            "filenames": urls,
            "loaded_count": len(docs),
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
|
|
|
|
|
|
class QueryDocForm(BaseModel):
    """Request body for querying a single document collection.

    ``k``, ``k_reranker``, ``r`` and ``hybrid_bm25_weight`` fall back to the
    collection's or the global RAG settings when not provided.
    """

    collection_name: str
    query: str
    k: Optional[int] = None
    k_reranker: Optional[int] = None
    r: Optional[float] = None
    hybrid: Optional[bool] = None
    # Weight of the BM25 component in hybrid search. The handler reads it via
    # getattr(form_data, "hybrid_bm25_weight", None); without this field the
    # value could never be supplied and always fell back to HYBRID_BM25_WEIGHT.
    hybrid_bm25_weight: Optional[float] = None
|
|
|
|
|
|
@router.post("/query/doc")
def query_doc_handler(
    request: Request,
    form_data: QueryDocForm,
    user=Depends(get_verified_user),
):
    """Query a single collection, honouring a knowledge base's own RAG settings.

    If ``form_data.collection_name`` is the id of a knowledge base whose
    ``rag_config`` has ``DEFAULT_RAG_SETTINGS`` set to a falsy value, values in
    that ``rag_config`` override the global ``request.app.state.config``.
    Explicit request fields (``k``, ``k_reranker``, ``r``,
    ``hybrid_bm25_weight``) take precedence over both.

    Returns:
        The result of ``query_doc_with_hybrid_search`` when hybrid search is
        enabled, otherwise the result of ``query_doc``.

    Raises:
        HTTPException: 400 wrapping any error raised while querying.
    """
    try:
        config = request.app.state.config

        # Per-knowledge-base overrides; an empty dict means "use global config".
        # A knowledge base with rag_config=None is treated as using defaults
        # instead of failing on None.get(...).
        rag_config = {}
        knowledge_base = Knowledges.get_knowledge_by_id(form_data.collection_name)
        kb_rag_config = (knowledge_base.rag_config if knowledge_base else None) or {}
        if not kb_rag_config.get("DEFAULT_RAG_SETTINGS", True):
            rag_config = kb_rag_config

        # Use config from rag_config if present, else fall back to global config.
        enable_hybrid = rag_config.get(
            "ENABLE_RAG_HYBRID_SEARCH", config.ENABLE_RAG_HYBRID_SEARCH
        )
        # NOTE(review): lowercase key, unlike the other rag_config keys - confirm.
        embedding_model = rag_config.get("embedding_model", config.RAG_EMBEDDING_MODEL)
        reranking_model = rag_config.get("RAG_RERANKING_MODEL", config.RAG_RERANKING_MODEL)
        top_k = form_data.k if form_data.k else rag_config.get("TOP_K", config.TOP_K)
        top_k_reranker = (
            form_data.k_reranker
            if form_data.k_reranker
            else rag_config.get("TOP_K_RERANKER", config.TOP_K_RERANKER)
        )
        # 0.0 is a meaningful threshold, so only fall back when r is omitted.
        relevance_threshold = (
            form_data.r
            if form_data.r is not None
            else rag_config.get("RELEVANCE_THRESHOLD", config.RELEVANCE_THRESHOLD)
        )
        hybrid_bm25_weight = getattr(form_data, "hybrid_bm25_weight", None)
        if hybrid_bm25_weight is None:
            hybrid_bm25_weight = rag_config.get(
                "HYBRID_BM25_WEIGHT", config.HYBRID_BM25_WEIGHT
            )

        if enable_hybrid:
            return query_doc_with_hybrid_search(
                collection_name=form_data.collection_name,
                collection_result=VECTOR_DB_CLIENT.get(
                    collection_name=form_data.collection_name
                ),
                query=form_data.query,
                embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION[
                    embedding_model
                ](query, prefix=prefix, user=user),
                k=top_k,
                reranking_function=request.app.state.rf[reranking_model],
                k_reranker=top_k_reranker,
                r=relevance_threshold,
                hybrid_bm25_weight=hybrid_bm25_weight,
                user=user,
            )
        else:
            return query_doc(
                collection_name=form_data.collection_name,
                query_embedding=request.app.state.EMBEDDING_FUNCTION[embedding_model](
                    form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
                ),
                k=top_k,
                user=user,
            )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
|
|
|
|
|
|
class QueryCollectionsForm(BaseModel):
    """Request body for ``POST /query/collection``: one query across several collections."""

    collection_names: List[str]
    query: str

    # Optional overrides; when left unset the handler uses the global config.
    k: Union[int, None] = None
    k_reranker: Union[int, None] = None
    r: Union[float, None] = None
    hybrid: Union[bool, None] = None  # NOTE(review): not read by query_collection_handler
    hybrid_bm25_weight: Union[float, None] = None
|
|
|
|
|
|
@router.post("/query/collection")
def query_collection_handler(
    request: Request,
    form_data: QueryCollectionsForm,
    user=Depends(get_verified_user),
):
    """Run ``form_data.query`` against every collection in ``collection_names``.

    Uses ``query_collection_with_hybrid_search`` when the global
    ``ENABLE_RAG_HYBRID_SEARCH`` flag is set, otherwise ``query_collection``.
    Only the global ``request.app.state.config`` is consulted for defaults;
    explicit request fields override it.

    Raises:
        HTTPException: 400 wrapping any error raised while querying.
    """
    try:
        config = request.app.state.config
        top_k = form_data.k if form_data.k else config.TOP_K

        if config.ENABLE_RAG_HYBRID_SEARCH:
            return query_collection_with_hybrid_search(
                collection_names=form_data.collection_names,
                queries=[form_data.query],
                user=user,
                ef=request.app.state.EMBEDDING_FUNCTION,
                k=top_k,
                reranking_function=request.app.state.rf,
                k_reranker=form_data.k_reranker or config.TOP_K_RERANKER,
                # 0.0 is a valid threshold / weight, so fall back only when the
                # client omitted the field (a truthiness test would discard 0.0).
                r=(
                    form_data.r
                    if form_data.r is not None
                    else config.RELEVANCE_THRESHOLD
                ),
                hybrid_bm25_weight=(
                    form_data.hybrid_bm25_weight
                    if form_data.hybrid_bm25_weight is not None
                    else config.HYBRID_BM25_WEIGHT
                ),
                embedding_model=config.RAG_EMBEDDING_MODEL,
                reranking_model=config.RAG_RERANKING_MODEL,
            )
        else:
            return query_collection(
                collection_names=form_data.collection_names,
                queries=[form_data.query],
                user=user,
                ef=request.app.state.EMBEDDING_FUNCTION,
                embedding_model=config.RAG_EMBEDDING_MODEL,
                k=top_k,
            )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
|
|
|
|
|
|
####################################
|
|
#
|
|
# Vector DB operations
|
|
#
|
|
####################################
|
|
|
|
|
|
class DeleteForm(BaseModel):
    """Identifies one file whose vectors should be removed from a collection."""

    collection_name: str  # vector-DB collection to delete from
    file_id: str  # id of the file record; its stored hash selects the entries
|
|
|
|
|
|
@router.post("/delete")
def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)):
    """Delete a file's entries from a vector-DB collection, matched by content hash.

    Returns:
        ``{"status": True}`` when the delete was issued. ``{"status": False}``
        when the collection does not exist, the file record is missing or has
        no hash, or any error occurs (errors are logged, not raised).
    """
    try:
        if not VECTOR_DB_CLIENT.has_collection(
            collection_name=form_data.collection_name
        ):
            return {"status": False}

        file = Files.get_file_by_id(form_data.file_id)
        # Previously a missing file surfaced as an AttributeError caught below;
        # a missing hash would issue a delete filtered on {"hash": None}.
        if file is None or not file.hash:
            log.warning(
                f"delete_entries_from_collection: no file or hash for {form_data.file_id}"
            )
            return {"status": False}

        VECTOR_DB_CLIENT.delete(
            collection_name=form_data.collection_name,
            metadata={"hash": file.hash},
        )
        return {"status": True}
    except Exception as e:
        log.exception(e)
        return {"status": False}
|
|
|
|
|
|
@router.post("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    """Reset the vector database, then delete every knowledge record.

    Guarded by the ``get_admin_user`` dependency. Returns ``None``.
    """
    # Order matters: the vector store is reset before knowledge rows are dropped.
    VECTOR_DB_CLIENT.reset()
    Knowledges.delete_all_knowledge()
|
|
|
|
|
|
@router.post("/reset/uploads")
def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
    """Remove every file, symlink and subdirectory inside ``UPLOAD_DIR``.

    Best-effort: individual deletion failures and directory-level errors are
    logged and skipped. Always returns ``True``.
    """
    upload_root = f"{UPLOAD_DIR}"
    try:
        if not os.path.exists(upload_root):
            log.warning(f"The directory {upload_root} does not exist")
        else:
            for entry_name in os.listdir(upload_root):
                entry_path = os.path.join(upload_root, entry_name)
                try:
                    # Plain files and links are unlinked; real dirs removed recursively.
                    if os.path.isfile(entry_path) or os.path.islink(entry_path):
                        os.unlink(entry_path)
                    elif os.path.isdir(entry_path):
                        shutil.rmtree(entry_path)
                except Exception as exc:
                    log.exception(f"Failed to delete {entry_path}. Reason: {exc}")
    except Exception as exc:
        log.exception(f"Failed to process the directory {upload_root}. Reason: {exc}")
    return True
|
|
|
|
|
|
if ENV == "dev":

    @router.get("/ef/{text}")
    async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
        """Dev-only endpoint: embed ``text`` with the configured embedding model.

        Returns ``{"result": <embedding>}``.
        """
        embedding_function = request.app.state.EMBEDDING_FUNCTION
        # Elsewhere in this module EMBEDDING_FUNCTION is indexed by model name,
        # so calling it directly raised TypeError. Resolve the globally
        # configured model when it is a mapping; still accept a plain callable.
        if not callable(embedding_function):
            embedding_function = embedding_function[
                request.app.state.config.RAG_EMBEDDING_MODEL
            ]
        return {
            "result": embedding_function(text, prefix=RAG_EMBEDDING_QUERY_PREFIX)
        }
|
|
|
|
|
|
class BatchProcessFilesForm(BaseModel):
    """Request body for ``POST /process/files/batch``."""

    files: list[FileModel]  # file records whose extracted content is indexed
    collection_name: str  # destination vector-DB collection
|
|
|
|
|
|
class BatchProcessFilesResult(BaseModel):
    """Per-file outcome reported by ``process_files_batch``."""

    file_id: str
    status: str  # required; process_files_batch uses "prepared", "completed", "failed"
    error: Union[str, None] = None  # failure message, if any
|
|
|
|
|
|
class BatchProcessFilesResponse(BaseModel):
    """Response of ``POST /process/files/batch``: per-file results plus failures."""

    results: list[BatchProcessFilesResult]
    errors: list[BatchProcessFilesResult]
|
|
|
|
|
|
@router.post("/process/files/batch")
def process_files_batch(
    request: Request,
    form_data: BatchProcessFilesForm,
    user=Depends(get_verified_user),
) -> BatchProcessFilesResponse:
    """
    Process a batch of files and save them to the vector database.

    Each file's ``data["content"]`` becomes one ``Document`` (``<br/>`` turned
    into newlines). The file's content hash and content are written back via
    ``Files``. All successfully prepared documents are then saved to
    ``form_data.collection_name`` in a single ``save_docs_to_vector_db`` call
    with ``add=True``.

    Returns:
        ``BatchProcessFilesResponse`` where ``results`` has one entry per
        prepared file (status ``"completed"`` on success, ``"failed"`` if the
        vector-DB save failed). ``errors`` has one ``"failed"`` entry per file
        that could not be prepared or saved.
    """
    results: List[BatchProcessFilesResult] = []
    errors: List[BatchProcessFilesResult] = []
    collection_name = form_data.collection_name

    # Prepare all documents first; a failure only affects that one file.
    all_docs: List[Document] = []
    for file in form_data.files:
        try:
            text_content = file.data.get("content", "")

            docs: List[Document] = [
                Document(
                    page_content=text_content.replace("<br/>", "\n"),
                    metadata={
                        **file.meta,
                        "name": file.filename,
                        "created_by": file.user_id,
                        "file_id": file.id,
                        "source": file.filename,
                    },
                )
            ]

            content_hash = calculate_sha256_string(text_content)
            Files.update_file_hash_by_id(file.id, content_hash)
            Files.update_file_data_by_id(file.id, {"content": text_content})

            all_docs.extend(docs)
            results.append(BatchProcessFilesResult(file_id=file.id, status="prepared"))

        except Exception as e:
            log.error(f"process_files_batch: Error processing file {file.id}: {str(e)}")
            errors.append(
                BatchProcessFilesResult(file_id=file.id, status="failed", error=str(e))
            )

    # Save all documents in one batch
    if all_docs:
        try:
            save_docs_to_vector_db(
                request=request,
                docs=all_docs,
                collection_name=collection_name,
                add=True,
                user=user,
            )

            # Update all files with collection name
            for result in results:
                Files.update_file_metadata_by_id(
                    result.file_id, {"collection_name": collection_name}
                )
                result.status = "completed"

        except Exception as e:
            log.error(
                f"process_files_batch: Error saving documents to vector DB: {str(e)}"
            )
            for result in results:
                result.status = "failed"
                # `status` is a required field; omitting it raised a pydantic
                # ValidationError here and turned this handled failure into a 500.
                errors.append(
                    BatchProcessFilesResult(
                        file_id=result.file_id, status="failed", error=str(e)
                    )
                )

    return BatchProcessFilesResponse(results=results, errors=errors)
|