sort and fix backend imports

This commit is contained in:
Pascal Lim 2024-08-28 00:10:27 +02:00
parent 08efabc696
commit c386d0b1a5
63 changed files with 598 additions and 1023 deletions

View File

@ -7,46 +7,33 @@ from functools import lru_cache
from pathlib import Path from pathlib import Path
import requests import requests
from fastapi import ( from config import (
FastAPI, AUDIO_STT_ENGINE,
Request, AUDIO_STT_MODEL,
Depends, AUDIO_STT_OPENAI_API_BASE_URL,
HTTPException, AUDIO_STT_OPENAI_API_KEY,
status, AUDIO_TTS_API_KEY,
UploadFile, AUDIO_TTS_ENGINE,
File, AUDIO_TTS_MODEL,
AUDIO_TTS_OPENAI_API_BASE_URL,
AUDIO_TTS_OPENAI_API_KEY,
AUDIO_TTS_SPLIT_ON,
AUDIO_TTS_VOICE,
CACHE_DIR,
CORS_ALLOW_ORIGIN,
DEVICE_TYPE,
WHISPER_MODEL,
WHISPER_MODEL_AUTO_UPDATE,
WHISPER_MODEL_DIR,
AppConfig,
) )
from constants import ERROR_MESSAGES
from env import SRC_LOG_LEVELS
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from pydantic import BaseModel from pydantic import BaseModel
from utils.utils import get_admin_user, get_current_user, get_verified_user
from config import (
SRC_LOG_LEVELS,
CACHE_DIR,
WHISPER_MODEL,
WHISPER_MODEL_DIR,
WHISPER_MODEL_AUTO_UPDATE,
DEVICE_TYPE,
AUDIO_STT_OPENAI_API_BASE_URL,
AUDIO_STT_OPENAI_API_KEY,
AUDIO_TTS_OPENAI_API_BASE_URL,
AUDIO_TTS_OPENAI_API_KEY,
AUDIO_TTS_API_KEY,
AUDIO_STT_ENGINE,
AUDIO_STT_MODEL,
AUDIO_TTS_ENGINE,
AUDIO_TTS_MODEL,
AUDIO_TTS_VOICE,
AUDIO_TTS_SPLIT_ON,
AppConfig,
CORS_ALLOW_ORIGIN,
)
from constants import ERROR_MESSAGES
from utils.utils import (
get_current_user,
get_verified_user,
get_admin_user,
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"]) log.setLevel(SRC_LOG_LEVELS["AUDIO"])
@ -211,7 +198,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
body = json.loads(body) body = json.loads(body)
body["model"] = app.state.config.TTS_MODEL body["model"] = app.state.config.TTS_MODEL
body = json.dumps(body).encode("utf-8") body = json.dumps(body).encode("utf-8")
except Exception as e: except Exception:
pass pass
r = None r = None
@ -488,7 +475,7 @@ def get_available_voices() -> dict:
elif app.state.config.TTS_ENGINE == "elevenlabs": elif app.state.config.TTS_ENGINE == "elevenlabs":
try: try:
ret = get_elevenlabs_voices() ret = get_elevenlabs_voices()
except Exception as e: except Exception:
# Avoided @lru_cache with exception # Avoided @lru_cache with exception
pass pass

View File

@ -1,52 +1,42 @@
from fastapi import ( import asyncio
FastAPI,
Request,
Depends,
HTTPException,
)
from fastapi.middleware.cors import CORSMiddleware
from typing import Optional
from pydantic import BaseModel
from pathlib import Path
import mimetypes
import uuid
import base64 import base64
import json import json
import logging import logging
import mimetypes
import re import re
import uuid
from pathlib import Path
from typing import Optional
import requests import requests
import asyncio
from utils.utils import (
get_verified_user,
get_admin_user,
)
from apps.images.utils.comfyui import ( from apps.images.utils.comfyui import (
ComfyUIWorkflow,
ComfyUIGenerateImageForm, ComfyUIGenerateImageForm,
ComfyUIWorkflow,
comfyui_generate_image, comfyui_generate_image,
) )
from constants import ERROR_MESSAGES
from config import ( from config import (
SRC_LOG_LEVELS,
CACHE_DIR,
IMAGE_GENERATION_ENGINE,
ENABLE_IMAGE_GENERATION,
AUTOMATIC1111_BASE_URL,
AUTOMATIC1111_API_AUTH, AUTOMATIC1111_API_AUTH,
AUTOMATIC1111_BASE_URL,
CACHE_DIR,
COMFYUI_BASE_URL, COMFYUI_BASE_URL,
COMFYUI_WORKFLOW, COMFYUI_WORKFLOW,
COMFYUI_WORKFLOW_NODES, COMFYUI_WORKFLOW_NODES,
IMAGES_OPENAI_API_BASE_URL, CORS_ALLOW_ORIGIN,
IMAGES_OPENAI_API_KEY, ENABLE_IMAGE_GENERATION,
IMAGE_GENERATION_ENGINE,
IMAGE_GENERATION_MODEL, IMAGE_GENERATION_MODEL,
IMAGE_SIZE, IMAGE_SIZE,
IMAGE_STEPS, IMAGE_STEPS,
CORS_ALLOW_ORIGIN, IMAGES_OPENAI_API_BASE_URL,
IMAGES_OPENAI_API_KEY,
AppConfig, AppConfig,
) )
from constants import ERROR_MESSAGES
from env import SRC_LOG_LEVELS
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from utils.utils import get_admin_user, get_verified_user
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"]) log.setLevel(SRC_LOG_LEVELS["IMAGES"])
@ -186,7 +176,7 @@ async def verify_url(user=Depends(get_admin_user)):
) )
r.raise_for_status() r.raise_for_status()
return True return True
except Exception as e: except Exception:
app.state.config.ENABLED = False app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
elif app.state.config.ENGINE == "comfyui": elif app.state.config.ENGINE == "comfyui":
@ -194,7 +184,7 @@ async def verify_url(user=Depends(get_admin_user)):
r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info") r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
r.raise_for_status() r.raise_for_status()
return True return True
except Exception as e: except Exception:
app.state.config.ENABLED = False app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
else: else:
@ -397,7 +387,6 @@ def save_url_image(url):
r = requests.get(url) r = requests.get(url)
r.raise_for_status() r.raise_for_status()
if r.headers["content-type"].split("/")[0] == "image": if r.headers["content-type"].split("/")[0] == "image":
mime_type = r.headers["content-type"] mime_type = r.headers["content-type"]
image_format = mimetypes.guess_extension(mime_type) image_format = mimetypes.guess_extension(mime_type)
@ -412,7 +401,7 @@ def save_url_image(url):
image_file.write(chunk) image_file.write(chunk)
return image_filename return image_filename
else: else:
log.error(f"Url does not point to an image.") log.error("Url does not point to an image.")
return None return None
except Exception as e: except Exception as e:
@ -430,7 +419,6 @@ async def image_generations(
r = None r = None
try: try:
if app.state.config.ENGINE == "openai": if app.state.config.ENGINE == "openai":
headers = {} headers = {}
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}" headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"

View File

@ -1,20 +1,18 @@
import asyncio import asyncio
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import json import json
import urllib.request
import urllib.parse
import random
import logging import logging
import random
import urllib.parse
import urllib.request
from typing import Optional
from config import SRC_LOG_LEVELS import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
from env import SRC_LOG_LEVELS
from pydantic import BaseModel
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["COMFYUI"]) log.setLevel(SRC_LOG_LEVELS["COMFYUI"])
from pydantic import BaseModel
from typing import Optional
default_headers = {"User-Agent": "Mozilla/5.0"} default_headers = {"User-Agent": "Mozilla/5.0"}

View File

@ -1,54 +1,40 @@
from fastapi import (
FastAPI,
Request,
HTTPException,
Depends,
UploadFile,
File,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict
import os
import re
import random
import requests
import json
import aiohttp
import asyncio import asyncio
import json
import logging import logging
import os
import random
import re
import time import time
from urllib.parse import urlparse
from typing import Optional, Union from typing import Optional, Union
from urllib.parse import urlparse
from starlette.background import BackgroundTask import aiohttp
import requests
from apps.webui.models.models import Models from apps.webui.models.models import Models
from constants import ERROR_MESSAGES
from utils.utils import (
get_verified_user,
get_admin_user,
)
from config import ( from config import (
SRC_LOG_LEVELS,
OLLAMA_BASE_URLS,
ENABLE_OLLAMA_API,
AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT,
CORS_ALLOW_ORIGIN,
ENABLE_MODEL_FILTER, ENABLE_MODEL_FILTER,
ENABLE_OLLAMA_API,
MODEL_FILTER_LIST, MODEL_FILTER_LIST,
OLLAMA_BASE_URLS,
UPLOAD_DIR, UPLOAD_DIR,
AppConfig, AppConfig,
CORS_ALLOW_ORIGIN,
) )
from constants import ERROR_MESSAGES
from env import SRC_LOG_LEVELS
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict
from starlette.background import BackgroundTask
from utils.misc import ( from utils.misc import (
calculate_sha256,
apply_model_params_to_body_ollama, apply_model_params_to_body_ollama,
apply_model_params_to_body_openai, apply_model_params_to_body_openai,
apply_model_system_prompt_to_body, apply_model_system_prompt_to_body,
calculate_sha256,
) )
from utils.utils import get_admin_user, get_verified_user
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) log.setLevel(SRC_LOG_LEVELS["OLLAMA"])

View File

@ -1,44 +1,36 @@
from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
import requests
import aiohttp
import asyncio import asyncio
import hashlib
import json import json
import logging import logging
from pathlib import Path
from typing import Literal, Optional, overload
import aiohttp
import requests
from apps.webui.models.models import Models
from config import (
AIOHTTP_CLIENT_TIMEOUT,
CACHE_DIR,
CORS_ALLOW_ORIGIN,
ENABLE_MODEL_FILTER,
ENABLE_OPENAI_API,
MODEL_FILTER_LIST,
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
AppConfig,
)
from constants import ERROR_MESSAGES
from env import SRC_LOG_LEVELS
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
from starlette.background import BackgroundTask from starlette.background import BackgroundTask
from apps.webui.models.models import Models
from constants import ERROR_MESSAGES
from utils.utils import (
get_verified_user,
get_admin_user,
)
from utils.misc import ( from utils.misc import (
apply_model_params_to_body_openai, apply_model_params_to_body_openai,
apply_model_system_prompt_to_body, apply_model_system_prompt_to_body,
) )
from utils.utils import get_admin_user, get_verified_user
from config import (
SRC_LOG_LEVELS,
ENABLE_OPENAI_API,
AIOHTTP_CLIENT_TIMEOUT,
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
CACHE_DIR,
ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST,
AppConfig,
CORS_ALLOW_ORIGIN,
)
from typing import Optional, Literal, overload
import hashlib
from pathlib import Path
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"]) log.setLevel(SRC_LOG_LEVELS["OPENAI"])

View File

@ -1,143 +1,118 @@
from fastapi import (
FastAPI,
Depends,
HTTPException,
status,
UploadFile,
File,
Form,
)
from fastapi.middleware.cors import CORSMiddleware
import requests
import os, shutil, logging, re
from datetime import datetime
from pathlib import Path
from typing import Union, Sequence, Iterator, Any
from chromadb.utils.batch_utils import create_batches
from langchain_core.documents import Document
from langchain_community.document_loaders import (
WebBaseLoader,
TextLoader,
PyPDFLoader,
CSVLoader,
BSHTMLLoader,
Docx2txtLoader,
UnstructuredEPubLoader,
UnstructuredWordDocumentLoader,
UnstructuredMarkdownLoader,
UnstructuredXMLLoader,
UnstructuredRSTLoader,
UnstructuredExcelLoader,
UnstructuredPowerPointLoader,
YoutubeLoader,
OutlookMessageLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
import validators
import urllib.parse
import socket
from pydantic import BaseModel
from typing import Optional
import mimetypes
import uuid
import json import json
import logging
import mimetypes
import os
import shutil
import socket
import urllib.parse
import uuid
from datetime import datetime
from pathlib import Path
from typing import Iterator, Optional, Sequence, Union
from apps.webui.models.documents import ( import requests
Documents, import validators
DocumentForm,
DocumentResponse,
)
from apps.webui.models.files import (
Files,
)
from apps.rag.utils import (
get_model_path,
get_embedding_function,
query_doc,
query_doc_with_hybrid_search,
query_collection,
query_collection_with_hybrid_search,
)
from apps.rag.search.brave import search_brave from apps.rag.search.brave import search_brave
from apps.rag.search.duckduckgo import search_duckduckgo
from apps.rag.search.google_pse import search_google_pse from apps.rag.search.google_pse import search_google_pse
from apps.rag.search.jina_search import search_jina
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult
from apps.rag.search.searchapi import search_searchapi
from apps.rag.search.searxng import search_searxng from apps.rag.search.searxng import search_searxng
from apps.rag.search.serper import search_serper from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack
from apps.rag.search.serply import search_serply from apps.rag.search.serply import search_serply
from apps.rag.search.duckduckgo import search_duckduckgo from apps.rag.search.serpstack import search_serpstack
from apps.rag.search.tavily import search_tavily from apps.rag.search.tavily import search_tavily
from apps.rag.search.jina_search import search_jina from apps.rag.utils import (
from apps.rag.search.searchapi import search_searchapi get_embedding_function,
get_model_path,
from utils.misc import ( query_collection,
calculate_sha256, query_collection_with_hybrid_search,
calculate_sha256_string, query_doc,
sanitize_filename, query_doc_with_hybrid_search,
extract_folders_after_data_docs,
) )
from utils.utils import get_verified_user, get_admin_user from apps.webui.models.documents import DocumentForm, Documents
from apps.webui.models.files import Files
from chromadb.utils.batch_utils import create_batches
from config import ( from config import (
AppConfig, BRAVE_SEARCH_API_KEY,
ENV, CHROMA_CLIENT,
SRC_LOG_LEVELS, CHUNK_OVERLAP,
UPLOAD_DIR, CHUNK_SIZE,
DOCS_DIR,
CONTENT_EXTRACTION_ENGINE, CONTENT_EXTRACTION_ENGINE,
TIKA_SERVER_URL, CORS_ALLOW_ORIGIN,
RAG_TOP_K, DEVICE_TYPE,
RAG_RELEVANCE_THRESHOLD, DOCS_DIR,
RAG_FILE_MAX_SIZE, ENABLE_RAG_HYBRID_SEARCH,
RAG_FILE_MAX_COUNT, ENABLE_RAG_LOCAL_WEB_FETCH,
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
ENABLE_RAG_WEB_SEARCH,
ENV,
GOOGLE_PSE_API_KEY,
GOOGLE_PSE_ENGINE_ID,
PDF_EXTRACT_IMAGES,
RAG_EMBEDDING_ENGINE, RAG_EMBEDDING_ENGINE,
RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_AUTO_UPDATE,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
ENABLE_RAG_HYBRID_SEARCH, RAG_EMBEDDING_OPENAI_BATCH_SIZE,
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, RAG_FILE_MAX_COUNT,
RAG_RERANKING_MODEL, RAG_FILE_MAX_SIZE,
PDF_EXTRACT_IMAGES,
RAG_RERANKING_MODEL_AUTO_UPDATE,
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
RAG_OPENAI_API_BASE_URL, RAG_OPENAI_API_BASE_URL,
RAG_OPENAI_API_KEY, RAG_OPENAI_API_KEY,
DEVICE_TYPE, RAG_RELEVANCE_THRESHOLD,
CHROMA_CLIENT, RAG_RERANKING_MODEL,
CHUNK_SIZE, RAG_RERANKING_MODEL_AUTO_UPDATE,
CHUNK_OVERLAP, RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
RAG_TEMPLATE, RAG_TEMPLATE,
ENABLE_RAG_LOCAL_WEB_FETCH, RAG_TOP_K,
YOUTUBE_LOADER_LANGUAGE, RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
ENABLE_RAG_WEB_SEARCH,
RAG_WEB_SEARCH_ENGINE,
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
SEARXNG_QUERY_URL, RAG_WEB_SEARCH_ENGINE,
GOOGLE_PSE_API_KEY, RAG_WEB_SEARCH_RESULT_COUNT,
GOOGLE_PSE_ENGINE_ID,
BRAVE_SEARCH_API_KEY,
SERPSTACK_API_KEY,
SERPSTACK_HTTPS,
SERPER_API_KEY,
SERPLY_API_KEY,
TAVILY_API_KEY,
SEARCHAPI_API_KEY, SEARCHAPI_API_KEY,
SEARCHAPI_ENGINE, SEARCHAPI_ENGINE,
RAG_WEB_SEARCH_RESULT_COUNT, SEARXNG_QUERY_URL,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS, SERPER_API_KEY,
RAG_EMBEDDING_OPENAI_BATCH_SIZE, SERPLY_API_KEY,
CORS_ALLOW_ORIGIN, SERPSTACK_API_KEY,
SERPSTACK_HTTPS,
TAVILY_API_KEY,
TIKA_SERVER_URL,
UPLOAD_DIR,
YOUTUBE_LOADER_LANGUAGE,
AppConfig,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from env import SRC_LOG_LEVELS
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
BSHTMLLoader,
CSVLoader,
Docx2txtLoader,
OutlookMessageLoader,
PyPDFLoader,
TextLoader,
UnstructuredEPubLoader,
UnstructuredExcelLoader,
UnstructuredMarkdownLoader,
UnstructuredPowerPointLoader,
UnstructuredRSTLoader,
UnstructuredXMLLoader,
WebBaseLoader,
YoutubeLoader,
)
from langchain_core.documents import Document
from pydantic import BaseModel
from utils.misc import (
calculate_sha256,
calculate_sha256_string,
extract_folders_after_data_docs,
sanitize_filename,
)
from utils.utils import get_admin_user, get_verified_user
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -539,9 +514,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key
app.state.config.SEARCHAPI_ENGINE = ( app.state.config.SEARCHAPI_ENGINE = form_data.web.search.searchapi_engine
form_data.web.search.searchapi_engine
)
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = ( app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
form_data.web.search.concurrent_requests form_data.web.search.concurrent_requests
@ -981,7 +954,6 @@ def store_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
def store_data_in_vector_db( def store_data_in_vector_db(
data, collection_name, metadata: Optional[dict] = None, overwrite: bool = False data, collection_name, metadata: Optional[dict] = None, overwrite: bool = False
) -> bool: ) -> bool:
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.config.CHUNK_SIZE, chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.config.CHUNK_OVERLAP, chunk_overlap=app.state.config.CHUNK_OVERLAP,
@ -1342,7 +1314,6 @@ def store_text(
form_data: TextRAGForm, form_data: TextRAGForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
collection_name = form_data.collection_name collection_name = form_data.collection_name
if collection_name is None: if collection_name is None:
collection_name = calculate_sha256_string(form_data.content) collection_name = calculate_sha256_string(form_data.content)

View File

@ -1,9 +1,9 @@
import logging import logging
from typing import Optional from typing import Optional
import requests
import requests
from apps.rag.search.main import SearchResult, get_filtered_results from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])

View File

@ -1,8 +1,9 @@
import logging import logging
from typing import Optional from typing import Optional
from apps.rag.search.main import SearchResult, get_filtered_results from apps.rag.search.main import SearchResult, get_filtered_results
from duckduckgo_search import DDGS from duckduckgo_search import DDGS
from config import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])

View File

@ -1,10 +1,9 @@
import json
import logging import logging
from typing import Optional from typing import Optional
import requests
import requests
from apps.rag.search.main import SearchResult, get_filtered_results from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])

View File

@ -1,9 +1,9 @@
import logging import logging
import requests
from yarl import URL
import requests
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
from yarl import URL
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])

View File

@ -1,5 +1,6 @@
from typing import Optional from typing import Optional
from urllib.parse import urlparse from urllib.parse import urlparse
from pydantic import BaseModel from pydantic import BaseModel

View File

@ -1,10 +1,9 @@
import logging import logging
import requests
from typing import Optional from typing import Optional
import requests
from apps.rag.search.main import SearchResult, get_filtered_results from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])

View File

@ -1,10 +1,10 @@
import json import json
import logging import logging
from typing import Optional from typing import Optional
import requests
import requests
from apps.rag.search.main import SearchResult, get_filtered_results from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])

View File

@ -1,11 +1,10 @@
import json
import logging import logging
from typing import Optional from typing import Optional
import requests
from urllib.parse import urlencode from urllib.parse import urlencode
import requests
from apps.rag.search.main import SearchResult, get_filtered_results from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])

View File

@ -1,10 +1,9 @@
import json
import logging import logging
from typing import Optional from typing import Optional
import requests
import requests
from apps.rag.search.main import SearchResult, get_filtered_results from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])

View File

@ -1,9 +1,8 @@
import logging import logging
import requests import requests
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])

View File

@ -1,27 +1,16 @@
import os
import logging import logging
import os
from typing import Optional, Union
import requests import requests
from apps.ollama.main import GenerateEmbeddingsForm, generate_ollama_embeddings
from typing import Union from config import CHROMA_CLIENT
from env import SRC_LOG_LEVELS
from apps.ollama.main import (
generate_ollama_embeddings,
GenerateEmbeddingsForm,
)
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_core.documents import Document
from langchain_community.retrievers import BM25Retriever from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import ( from langchain_core.documents import Document
ContextualCompressionRetriever, from utils.misc import get_last_user_message
EnsembleRetriever,
)
from typing import Optional
from utils.misc import get_last_user_message, add_or_update_system_message
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -261,7 +250,9 @@ def get_rag_context(
collection_names = ( collection_names = (
file["collection_names"] file["collection_names"]
if file["type"] == "collection" if file["type"] == "collection"
else [file["collection_name"]] if file["collection_name"] else [] else [file["collection_name"]]
if file["collection_name"]
else []
) )
collection_names = set(collection_names).difference(extracted_collections) collection_names = set(collection_names).difference(extracted_collections)
@ -401,8 +392,8 @@ def generate_openai_batch_embeddings(
from typing import Any from typing import Any
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever
class ChromaRetriever(BaseRetriever): class ChromaRetriever(BaseRetriever):
@ -439,11 +430,10 @@ class ChromaRetriever(BaseRetriever):
import operator import operator
from typing import Optional, Sequence from typing import Optional, Sequence
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.callbacks import Callbacks from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.pydantic_v1 import Extra from langchain_core.pydantic_v1 import Extra

View File

@ -1,7 +1,6 @@
import socketio
import asyncio import asyncio
import socketio
from apps.webui.models.users import Users from apps.webui.models.users import Users
from utils.utils import decode_token from utils.utils import decode_token

View File

@ -1,21 +1,16 @@
import os
import logging
import json import json
import logging
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Optional
from typing import Optional, Any
from typing_extensions import Self
from sqlalchemy import create_engine, types, Dialect
from sqlalchemy.sql.type_api import _T
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session
from peewee_migrate import Router
from apps.webui.internal.wrappers import register_connection from apps.webui.internal.wrappers import register_connection
from env import SRC_LOG_LEVELS, BACKEND_DIR, DATABASE_URL from env import BACKEND_DIR, DATABASE_URL, SRC_LOG_LEVELS
from peewee_migrate import Router
from sqlalchemy import Dialect, create_engine, types
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.sql.type_api import _T
from typing_extensions import Self
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"]) log.setLevel(SRC_LOG_LEVELS["DB"])

View File

@ -1,12 +1,12 @@
from contextvars import ContextVar
from peewee import *
from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError
import logging import logging
from playhouse.db_url import connect, parse from contextvars import ContextVar
from playhouse.shortcuts import ReconnectMixin
from env import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
from peewee import *
from peewee import InterfaceError as PeeWeeInterfaceError
from peewee import PostgresqlDatabase
from playhouse.db_url import connect, parse
from playhouse.shortcuts import ReconnectMixin
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"]) log.setLevel(SRC_LOG_LEVELS["DB"])

View File

@ -1,65 +1,59 @@
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from apps.webui.routers import (
auths,
users,
chats,
documents,
tools,
models,
prompts,
configs,
memories,
utils,
files,
functions,
)
from apps.webui.models.functions import Functions
from apps.webui.models.models import Models
from apps.webui.utils import load_function_module_by_id
from utils.misc import (
openai_chat_chunk_message_template,
openai_chat_completion_message_template,
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
from utils.tools import get_tools
from config import (
SHOW_ADMIN_DETAILS,
ADMIN_EMAIL,
WEBUI_AUTH,
DEFAULT_MODELS,
DEFAULT_PROMPT_SUGGESTIONS,
DEFAULT_USER_ROLE,
ENABLE_SIGNUP,
ENABLE_LOGIN_FORM,
USER_PERMISSIONS,
WEBHOOK_URL,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
JWT_EXPIRES_IN,
WEBUI_BANNERS,
ENABLE_COMMUNITY_SHARING,
ENABLE_MESSAGE_RATING,
AppConfig,
OAUTH_USERNAME_CLAIM,
OAUTH_PICTURE_CLAIM,
OAUTH_EMAIL_CLAIM,
CORS_ALLOW_ORIGIN,
)
from apps.socket.main import get_event_call, get_event_emitter
import inspect import inspect
import json import json
import logging import logging
from typing import AsyncGenerator, Generator, Iterator
from typing import Iterator, Generator, AsyncGenerator from apps.socket.main import get_event_call, get_event_emitter
from apps.webui.models.functions import Functions
from apps.webui.models.models import Models
from apps.webui.routers import (
auths,
chats,
configs,
documents,
files,
functions,
memories,
models,
prompts,
tools,
users,
utils,
)
from apps.webui.utils import load_function_module_by_id
from config import (
ADMIN_EMAIL,
CORS_ALLOW_ORIGIN,
DEFAULT_MODELS,
DEFAULT_PROMPT_SUGGESTIONS,
DEFAULT_USER_ROLE,
ENABLE_COMMUNITY_SHARING,
ENABLE_LOGIN_FORM,
ENABLE_MESSAGE_RATING,
ENABLE_SIGNUP,
JWT_EXPIRES_IN,
OAUTH_EMAIL_CLAIM,
OAUTH_PICTURE_CLAIM,
OAUTH_USERNAME_CLAIM,
SHOW_ADMIN_DETAILS,
USER_PERMISSIONS,
WEBHOOK_URL,
WEBUI_AUTH,
WEBUI_BANNERS,
AppConfig,
)
from env import WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
from utils.misc import (
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
openai_chat_chunk_message_template,
openai_chat_completion_message_template,
)
from utils.tools import get_tools
app = FastAPI() app = FastAPI()

View File

@ -1,15 +1,13 @@
from pydantic import BaseModel
from typing import Optional
import uuid
import logging import logging
from sqlalchemy import String, Column, Boolean, Text import uuid
from typing import Optional
from utils.utils import verify_password
from apps.webui.models.users import UserModel, Users
from apps.webui.internal.db import Base, get_db from apps.webui.internal.db import Base, get_db
from apps.webui.models.users import UserModel, Users
from env import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
from pydantic import BaseModel
from sqlalchemy import Boolean, Column, String, Text
from utils.utils import verify_password
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -92,7 +90,6 @@ class AddUserForm(SignupForm):
class AuthsTable: class AuthsTable:
def insert_new_auth( def insert_new_auth(
self, self,
email: str, email: str,
@ -103,7 +100,6 @@ class AuthsTable:
oauth_sub: Optional[str] = None, oauth_sub: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_db() as db: with get_db() as db:
log.info("insert_new_auth") log.info("insert_new_auth")
id = str(uuid.uuid4()) id = str(uuid.uuid4())
@ -130,7 +126,6 @@ class AuthsTable:
log.info(f"authenticate_user: {email}") log.info(f"authenticate_user: {email}")
try: try:
with get_db() as db: with get_db() as db:
auth = db.query(Auth).filter_by(email=email, active=True).first() auth = db.query(Auth).filter_by(email=email, active=True).first()
if auth: if auth:
if verify_password(password, auth.password): if verify_password(password, auth.password):
@ -189,7 +184,6 @@ class AuthsTable:
def delete_auth_by_id(self, id: str) -> bool: def delete_auth_by_id(self, id: str) -> bool:
try: try:
with get_db() as db: with get_db() as db:
# Delete User # Delete User
result = Users.delete_user_by_id(id) result = Users.delete_user_by_id(id)

View File

@ -1,14 +1,11 @@
from pydantic import BaseModel, ConfigDict
from typing import Union, Optional
import json import json
import uuid
import time import time
import uuid
from sqlalchemy import Column, String, BigInteger, Boolean, Text from typing import Optional
from apps.webui.internal.db import Base, get_db from apps.webui.internal.db import Base, get_db
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text
#################### ####################
# Chat DB Schema # Chat DB Schema
@ -77,10 +74,8 @@ class ChatTitleIdResponse(BaseModel):
class ChatTable: class ChatTable:
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
with get_db() as db: with get_db() as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
chat = ChatModel( chat = ChatModel(
**{ **{
@ -106,7 +101,6 @@ class ChatTable:
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db() as db:
chat_obj = db.get(Chat, id) chat_obj = db.get(Chat, id)
chat_obj.chat = json.dumps(chat) chat_obj.chat = json.dumps(chat)
chat_obj.title = chat["title"] if "title" in chat else "New Chat" chat_obj.title = chat["title"] if "title" in chat else "New Chat"
@ -115,12 +109,11 @@ class ChatTable:
db.refresh(chat_obj) db.refresh(chat_obj)
return ChatModel.model_validate(chat_obj) return ChatModel.model_validate(chat_obj)
except Exception as e: except Exception:
return None return None
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_db() as db: with get_db() as db:
# Get the existing chat to share # Get the existing chat to share
chat = db.get(Chat, chat_id) chat = db.get(Chat, chat_id)
# Check if the chat is already shared # Check if the chat is already shared
@ -154,7 +147,6 @@ class ChatTable:
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db() as db:
print("update_shared_chat_by_id") print("update_shared_chat_by_id")
chat = db.get(Chat, chat_id) chat = db.get(Chat, chat_id)
print(chat) print(chat)
@ -170,7 +162,6 @@ class ChatTable:
def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
try: try:
with get_db() as db: with get_db() as db:
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
db.commit() db.commit()
@ -183,7 +174,6 @@ class ChatTable:
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db() as db:
chat = db.get(Chat, id) chat = db.get(Chat, id)
chat.share_id = share_id chat.share_id = share_id
db.commit() db.commit()
@ -195,7 +185,6 @@ class ChatTable:
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db() as db:
chat = db.get(Chat, id) chat = db.get(Chat, id)
chat.archived = not chat.archived chat.archived = not chat.archived
db.commit() db.commit()
@ -217,7 +206,6 @@ class ChatTable:
self, user_id: str, skip: int = 0, limit: int = 50 self, user_id: str, skip: int = 0, limit: int = 50
) -> list[ChatModel]: ) -> list[ChatModel]:
with get_db() as db: with get_db() as db:
all_chats = ( all_chats = (
db.query(Chat) db.query(Chat)
.filter_by(user_id=user_id, archived=True) .filter_by(user_id=user_id, archived=True)
@ -297,7 +285,6 @@ class ChatTable:
def get_chat_by_id(self, id: str) -> Optional[ChatModel]: def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db() as db:
chat = db.get(Chat, id) chat = db.get(Chat, id)
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
except Exception: except Exception:
@ -306,20 +293,18 @@ class ChatTable:
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db() as db:
chat = db.query(Chat).filter_by(share_id=id).first() chat = db.query(Chat).filter_by(share_id=id).first()
if chat: if chat:
return self.get_chat_by_id(id) return self.get_chat_by_id(id)
else: else:
return None return None
except Exception as e: except Exception:
return None return None
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db() as db:
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
except Exception: except Exception:
@ -327,7 +312,6 @@ class ChatTable:
def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]: def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]:
with get_db() as db: with get_db() as db:
all_chats = ( all_chats = (
db.query(Chat) db.query(Chat)
# .limit(limit).offset(skip) # .limit(limit).offset(skip)
@ -337,7 +321,6 @@ class ChatTable:
def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]: def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db: with get_db() as db:
all_chats = ( all_chats = (
db.query(Chat) db.query(Chat)
.filter_by(user_id=user_id) .filter_by(user_id=user_id)
@ -347,7 +330,6 @@ class ChatTable:
def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]: def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db: with get_db() as db:
all_chats = ( all_chats = (
db.query(Chat) db.query(Chat)
.filter_by(user_id=user_id, archived=True) .filter_by(user_id=user_id, archived=True)
@ -358,7 +340,6 @@ class ChatTable:
def delete_chat_by_id(self, id: str) -> bool: def delete_chat_by_id(self, id: str) -> bool:
try: try:
with get_db() as db: with get_db() as db:
db.query(Chat).filter_by(id=id).delete() db.query(Chat).filter_by(id=id).delete()
db.commit() db.commit()
@ -369,7 +350,6 @@ class ChatTable:
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try: try:
with get_db() as db: with get_db() as db:
db.query(Chat).filter_by(id=id, user_id=user_id).delete() db.query(Chat).filter_by(id=id, user_id=user_id).delete()
db.commit() db.commit()
@ -379,9 +359,7 @@ class ChatTable:
def delete_chats_by_user_id(self, user_id: str) -> bool: def delete_chats_by_user_id(self, user_id: str) -> bool:
try: try:
with get_db() as db: with get_db() as db:
self.delete_shared_chats_by_user_id(user_id) self.delete_shared_chats_by_user_id(user_id)
db.query(Chat).filter_by(user_id=user_id).delete() db.query(Chat).filter_by(user_id=user_id).delete()
@ -393,9 +371,7 @@ class ChatTable:
def delete_shared_chats_by_user_id(self, user_id: str) -> bool: def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
try: try:
with get_db() as db: with get_db() as db:
chats_by_user = db.query(Chat).filter_by(user_id=user_id).all() chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]

View File

@ -1,15 +1,12 @@
from pydantic import BaseModel, ConfigDict import json
from typing import Optional
import time
import logging import logging
import time
from sqlalchemy import String, Column, BigInteger, Text from typing import Optional
from apps.webui.internal.db import Base, get_db from apps.webui.internal.db import Base, get_db
import json
from env import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -70,12 +67,10 @@ class DocumentForm(DocumentUpdateForm):
class DocumentsTable: class DocumentsTable:
def insert_new_doc( def insert_new_doc(
self, user_id: str, form_data: DocumentForm self, user_id: str, form_data: DocumentForm
) -> Optional[DocumentModel]: ) -> Optional[DocumentModel]:
with get_db() as db: with get_db() as db:
document = DocumentModel( document = DocumentModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
@ -99,7 +94,6 @@ class DocumentsTable:
def get_doc_by_name(self, name: str) -> Optional[DocumentModel]: def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
try: try:
with get_db() as db: with get_db() as db:
document = db.query(Document).filter_by(name=name).first() document = db.query(Document).filter_by(name=name).first()
return DocumentModel.model_validate(document) if document else None return DocumentModel.model_validate(document) if document else None
except Exception: except Exception:
@ -107,7 +101,6 @@ class DocumentsTable:
def get_docs(self) -> list[DocumentModel]: def get_docs(self) -> list[DocumentModel]:
with get_db() as db: with get_db() as db:
return [ return [
DocumentModel.model_validate(doc) for doc in db.query(Document).all() DocumentModel.model_validate(doc) for doc in db.query(Document).all()
] ]
@ -117,7 +110,6 @@ class DocumentsTable:
) -> Optional[DocumentModel]: ) -> Optional[DocumentModel]:
try: try:
with get_db() as db: with get_db() as db:
db.query(Document).filter_by(name=name).update( db.query(Document).filter_by(name=name).update(
{ {
"title": form_data.title, "title": form_data.title,
@ -140,7 +132,6 @@ class DocumentsTable:
doc_content = {**doc_content, **updated} doc_content = {**doc_content, **updated}
with get_db() as db: with get_db() as db:
db.query(Document).filter_by(name=name).update( db.query(Document).filter_by(name=name).update(
{ {
"content": json.dumps(doc_content), "content": json.dumps(doc_content),
@ -156,7 +147,6 @@ class DocumentsTable:
def delete_doc_by_name(self, name: str) -> bool: def delete_doc_by_name(self, name: str) -> bool:
try: try:
with get_db() as db: with get_db() as db:
db.query(Document).filter_by(name=name).delete() db.query(Document).filter_by(name=name).delete()
db.commit() db.commit()
return True return True

View File

@ -1,15 +1,11 @@
from pydantic import BaseModel, ConfigDict
from typing import Union, Optional
import time
import logging import logging
import time
from typing import Optional
from sqlalchemy import Column, String, BigInteger, Text from apps.webui.internal.db import Base, JSONField, get_db
from apps.webui.internal.db import JSONField, Base, get_db
import json
from env import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -59,10 +55,8 @@ class FileForm(BaseModel):
class FilesTable: class FilesTable:
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]: def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
with get_db() as db: with get_db() as db:
file = FileModel( file = FileModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
@ -86,7 +80,6 @@ class FilesTable:
def get_file_by_id(self, id: str) -> Optional[FileModel]: def get_file_by_id(self, id: str) -> Optional[FileModel]:
with get_db() as db: with get_db() as db:
try: try:
file = db.get(File, id) file = db.get(File, id)
return FileModel.model_validate(file) return FileModel.model_validate(file)
@ -95,7 +88,6 @@ class FilesTable:
def get_files(self) -> list[FileModel]: def get_files(self) -> list[FileModel]:
with get_db() as db: with get_db() as db:
return [FileModel.model_validate(file) for file in db.query(File).all()] return [FileModel.model_validate(file) for file in db.query(File).all()]
def get_files_by_user_id(self, user_id: str) -> list[FileModel]: def get_files_by_user_id(self, user_id: str) -> list[FileModel]:
@ -106,9 +98,7 @@ class FilesTable:
] ]
def delete_file_by_id(self, id: str) -> bool: def delete_file_by_id(self, id: str) -> bool:
with get_db() as db: with get_db() as db:
try: try:
db.query(File).filter_by(id=id).delete() db.query(File).filter_by(id=id).delete()
db.commit() db.commit()
@ -118,9 +108,7 @@ class FilesTable:
return False return False
def delete_all_files(self) -> bool: def delete_all_files(self) -> bool:
with get_db() as db: with get_db() as db:
try: try:
db.query(File).delete() db.query(File).delete()
db.commit() db.commit()

View File

@ -1,18 +1,12 @@
from pydantic import BaseModel, ConfigDict
from typing import Union, Optional
import time
import logging import logging
import time
from typing import Optional
from sqlalchemy import Column, String, Text, BigInteger, Boolean from apps.webui.internal.db import Base, JSONField, get_db
from apps.webui.internal.db import JSONField, Base, get_db
from apps.webui.models.users import Users from apps.webui.models.users import Users
import json
import copy
from env import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -87,11 +81,9 @@ class FunctionValves(BaseModel):
class FunctionsTable: class FunctionsTable:
def insert_new_function( def insert_new_function(
self, user_id: str, type: str, form_data: FunctionForm self, user_id: str, type: str, form_data: FunctionForm
) -> Optional[FunctionModel]: ) -> Optional[FunctionModel]:
function = FunctionModel( function = FunctionModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
@ -119,7 +111,6 @@ class FunctionsTable:
def get_function_by_id(self, id: str) -> Optional[FunctionModel]: def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
try: try:
with get_db() as db: with get_db() as db:
function = db.get(Function, id) function = db.get(Function, id)
return FunctionModel.model_validate(function) return FunctionModel.model_validate(function)
except Exception: except Exception:
@ -127,7 +118,6 @@ class FunctionsTable:
def get_functions(self, active_only=False) -> list[FunctionModel]: def get_functions(self, active_only=False) -> list[FunctionModel]:
with get_db() as db: with get_db() as db:
if active_only: if active_only:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function)
@ -143,7 +133,6 @@ class FunctionsTable:
self, type: str, active_only=False self, type: str, active_only=False
) -> list[FunctionModel]: ) -> list[FunctionModel]:
with get_db() as db: with get_db() as db:
if active_only: if active_only:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function)
@ -159,7 +148,6 @@ class FunctionsTable:
def get_global_filter_functions(self) -> list[FunctionModel]: def get_global_filter_functions(self) -> list[FunctionModel]:
with get_db() as db: with get_db() as db:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function)
for function in db.query(Function) for function in db.query(Function)
@ -178,7 +166,6 @@ class FunctionsTable:
def get_function_valves_by_id(self, id: str) -> Optional[dict]: def get_function_valves_by_id(self, id: str) -> Optional[dict]:
with get_db() as db: with get_db() as db:
try: try:
function = db.get(Function, id) function = db.get(Function, id)
return function.valves if function.valves else {} return function.valves if function.valves else {}
@ -190,7 +177,6 @@ class FunctionsTable:
self, id: str, valves: dict self, id: str, valves: dict
) -> Optional[FunctionValves]: ) -> Optional[FunctionValves]:
with get_db() as db: with get_db() as db:
try: try:
function = db.get(Function, id) function = db.get(Function, id)
function.valves = valves function.valves = valves
@ -204,7 +190,6 @@ class FunctionsTable:
def get_user_valves_by_id_and_user_id( def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str self, id: str, user_id: str
) -> Optional[dict]: ) -> Optional[dict]:
try: try:
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump() if user.settings else {} user_settings = user.settings.model_dump() if user.settings else {}
@ -223,7 +208,6 @@ class FunctionsTable:
def update_user_valves_by_id_and_user_id( def update_user_valves_by_id_and_user_id(
self, id: str, user_id: str, valves: dict self, id: str, user_id: str, valves: dict
) -> Optional[dict]: ) -> Optional[dict]:
try: try:
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump() if user.settings else {} user_settings = user.settings.model_dump() if user.settings else {}
@ -246,7 +230,6 @@ class FunctionsTable:
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
with get_db() as db: with get_db() as db:
try: try:
db.query(Function).filter_by(id=id).update( db.query(Function).filter_by(id=id).update(
{ {
@ -261,7 +244,6 @@ class FunctionsTable:
def deactivate_all_functions(self) -> Optional[bool]: def deactivate_all_functions(self) -> Optional[bool]:
with get_db() as db: with get_db() as db:
try: try:
db.query(Function).update( db.query(Function).update(
{ {

View File

@ -1,12 +1,10 @@
from pydantic import BaseModel, ConfigDict
from typing import Union, Optional
from sqlalchemy import Column, String, BigInteger, Text
from apps.webui.internal.db import Base, get_db
import time import time
import uuid import uuid
from typing import Optional
from apps.webui.internal.db import Base, get_db
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
#################### ####################
# Memory DB Schema # Memory DB Schema
@ -39,13 +37,11 @@ class MemoryModel(BaseModel):
class MemoriesTable: class MemoriesTable:
def insert_new_memory( def insert_new_memory(
self, self,
user_id: str, user_id: str,
content: str, content: str,
) -> Optional[MemoryModel]: ) -> Optional[MemoryModel]:
with get_db() as db: with get_db() as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
@ -73,7 +69,6 @@ class MemoriesTable:
content: str, content: str,
) -> Optional[MemoryModel]: ) -> Optional[MemoryModel]:
with get_db() as db: with get_db() as db:
try: try:
db.query(Memory).filter_by(id=id).update( db.query(Memory).filter_by(id=id).update(
{"content": content, "updated_at": int(time.time())} {"content": content, "updated_at": int(time.time())}
@ -85,7 +80,6 @@ class MemoriesTable:
def get_memories(self) -> list[MemoryModel]: def get_memories(self) -> list[MemoryModel]:
with get_db() as db: with get_db() as db:
try: try:
memories = db.query(Memory).all() memories = db.query(Memory).all()
return [MemoryModel.model_validate(memory) for memory in memories] return [MemoryModel.model_validate(memory) for memory in memories]
@ -94,7 +88,6 @@ class MemoriesTable:
def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]: def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]:
with get_db() as db: with get_db() as db:
try: try:
memories = db.query(Memory).filter_by(user_id=user_id).all() memories = db.query(Memory).filter_by(user_id=user_id).all()
return [MemoryModel.model_validate(memory) for memory in memories] return [MemoryModel.model_validate(memory) for memory in memories]
@ -103,7 +96,6 @@ class MemoriesTable:
def get_memory_by_id(self, id: str) -> Optional[MemoryModel]: def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
with get_db() as db: with get_db() as db:
try: try:
memory = db.get(Memory, id) memory = db.get(Memory, id)
return MemoryModel.model_validate(memory) return MemoryModel.model_validate(memory)
@ -112,7 +104,6 @@ class MemoriesTable:
def delete_memory_by_id(self, id: str) -> bool: def delete_memory_by_id(self, id: str) -> bool:
with get_db() as db: with get_db() as db:
try: try:
db.query(Memory).filter_by(id=id).delete() db.query(Memory).filter_by(id=id).delete()
db.commit() db.commit()
@ -124,7 +115,6 @@ class MemoriesTable:
def delete_memories_by_user_id(self, user_id: str) -> bool: def delete_memories_by_user_id(self, user_id: str) -> bool:
with get_db() as db: with get_db() as db:
try: try:
db.query(Memory).filter_by(user_id=user_id).delete() db.query(Memory).filter_by(user_id=user_id).delete()
db.commit() db.commit()
@ -135,7 +125,6 @@ class MemoriesTable:
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
with get_db() as db: with get_db() as db:
try: try:
db.query(Memory).filter_by(id=id, user_id=user_id).delete() db.query(Memory).filter_by(id=id, user_id=user_id).delete()
db.commit() db.commit()

View File

@ -1,14 +1,11 @@
import logging import logging
from typing import Optional, List import time
from typing import Optional
from pydantic import BaseModel, ConfigDict
from sqlalchemy import Column, BigInteger, Text
from apps.webui.internal.db import Base, JSONField, get_db from apps.webui.internal.db import Base, JSONField, get_db
from env import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
import time from sqlalchemy import BigInteger, Column, Text
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])

View File

@ -1,12 +1,9 @@
from pydantic import BaseModel, ConfigDict
from typing import Optional
import time import time
from typing import Optional
from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import Base, get_db from apps.webui.internal.db import Base, get_db
from pydantic import BaseModel, ConfigDict
import json from sqlalchemy import BigInteger, Column, String, Text
#################### ####################
# Prompts DB Schema # Prompts DB Schema
@ -45,7 +42,6 @@ class PromptForm(BaseModel):
class PromptsTable: class PromptsTable:
def insert_new_prompt( def insert_new_prompt(
self, user_id: str, form_data: PromptForm self, user_id: str, form_data: PromptForm
) -> Optional[PromptModel]: ) -> Optional[PromptModel]:
@ -61,7 +57,6 @@ class PromptsTable:
try: try:
with get_db() as db: with get_db() as db:
result = Prompt(**prompt.dict()) result = Prompt(**prompt.dict())
db.add(result) db.add(result)
db.commit() db.commit()
@ -70,13 +65,12 @@ class PromptsTable:
return PromptModel.model_validate(result) return PromptModel.model_validate(result)
else: else:
return None return None
except Exception as e: except Exception:
return None return None
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
try: try:
with get_db() as db: with get_db() as db:
prompt = db.query(Prompt).filter_by(command=command).first() prompt = db.query(Prompt).filter_by(command=command).first()
return PromptModel.model_validate(prompt) return PromptModel.model_validate(prompt)
except Exception: except Exception:
@ -84,7 +78,6 @@ class PromptsTable:
def get_prompts(self) -> list[PromptModel]: def get_prompts(self) -> list[PromptModel]:
with get_db() as db: with get_db() as db:
return [ return [
PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all() PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
] ]
@ -94,7 +87,6 @@ class PromptsTable:
) -> Optional[PromptModel]: ) -> Optional[PromptModel]:
try: try:
with get_db() as db: with get_db() as db:
prompt = db.query(Prompt).filter_by(command=command).first() prompt = db.query(Prompt).filter_by(command=command).first()
prompt.title = form_data.title prompt.title = form_data.title
prompt.content = form_data.content prompt.content = form_data.content
@ -107,7 +99,6 @@ class PromptsTable:
def delete_prompt_by_command(self, command: str) -> bool: def delete_prompt_by_command(self, command: str) -> bool:
try: try:
with get_db() as db: with get_db() as db:
db.query(Prompt).filter_by(command=command).delete() db.query(Prompt).filter_by(command=command).delete()
db.commit() db.commit()

View File

@ -1,16 +1,12 @@
from pydantic import BaseModel, ConfigDict import logging
import time
import uuid
from typing import Optional from typing import Optional
import json
import uuid
import time
import logging
from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import Base, get_db from apps.webui.internal.db import Base, get_db
from env import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -77,10 +73,8 @@ class ChatTagsResponse(BaseModel):
class TagTable: class TagTable:
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
with get_db() as db: with get_db() as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try: try:
@ -92,7 +86,7 @@ class TagTable:
return TagModel.model_validate(result) return TagModel.model_validate(result)
else: else:
return None return None
except Exception as e: except Exception:
return None return None
def get_tag_by_name_and_user_id( def get_tag_by_name_and_user_id(
@ -102,7 +96,7 @@ class TagTable:
with get_db() as db: with get_db() as db:
tag = db.query(Tag).filter_by(name=name, user_id=user_id).first() tag = db.query(Tag).filter_by(name=name, user_id=user_id).first()
return TagModel.model_validate(tag) return TagModel.model_validate(tag)
except Exception as e: except Exception:
return None return None
def add_tag_to_chat( def add_tag_to_chat(
@ -161,7 +155,6 @@ class TagTable:
self, chat_id: str, user_id: str self, chat_id: str, user_id: str
) -> list[TagModel]: ) -> list[TagModel]:
with get_db() as db: with get_db() as db:
tag_names = [ tag_names = [
chat_id_tag.tag_name chat_id_tag.tag_name
for chat_id_tag in ( for chat_id_tag in (
@ -186,7 +179,6 @@ class TagTable:
self, tag_name: str, user_id: str self, tag_name: str, user_id: str
) -> list[ChatIdTagModel]: ) -> list[ChatIdTagModel]:
with get_db() as db: with get_db() as db:
return [ return [
ChatIdTagModel.model_validate(chat_id_tag) ChatIdTagModel.model_validate(chat_id_tag)
for chat_id_tag in ( for chat_id_tag in (
@ -201,7 +193,6 @@ class TagTable:
self, tag_name: str, user_id: str self, tag_name: str, user_id: str
) -> int: ) -> int:
with get_db() as db: with get_db() as db:
return ( return (
db.query(ChatIdTag) db.query(ChatIdTag)
.filter_by(tag_name=tag_name, user_id=user_id) .filter_by(tag_name=tag_name, user_id=user_id)
@ -236,7 +227,6 @@ class TagTable:
) -> bool: ) -> bool:
try: try:
with get_db() as db: with get_db() as db:
res = ( res = (
db.query(ChatIdTag) db.query(ChatIdTag)
.filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)

View File

@ -1,17 +1,12 @@
from pydantic import BaseModel, ConfigDict
from typing import Optional
import time
import logging import logging
from sqlalchemy import String, Column, BigInteger, Text import time
from typing import Optional
from apps.webui.internal.db import Base, JSONField, get_db from apps.webui.internal.db import Base, JSONField, get_db
from apps.webui.models.users import Users from apps.webui.models.users import Users
import json
import copy
from env import SRC_LOG_LEVELS from env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -79,13 +74,10 @@ class ToolValves(BaseModel):
class ToolsTable: class ToolsTable:
def insert_new_tool( def insert_new_tool(
self, user_id: str, form_data: ToolForm, specs: list[dict] self, user_id: str, form_data: ToolForm, specs: list[dict]
) -> Optional[ToolModel]: ) -> Optional[ToolModel]:
with get_db() as db: with get_db() as db:
tool = ToolModel( tool = ToolModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
@ -112,7 +104,6 @@ class ToolsTable:
def get_tool_by_id(self, id: str) -> Optional[ToolModel]: def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try: try:
with get_db() as db: with get_db() as db:
tool = db.get(Tool, id) tool = db.get(Tool, id)
return ToolModel.model_validate(tool) return ToolModel.model_validate(tool)
except Exception: except Exception:
@ -125,7 +116,6 @@ class ToolsTable:
def get_tool_valves_by_id(self, id: str) -> Optional[dict]: def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try: try:
with get_db() as db: with get_db() as db:
tool = db.get(Tool, id) tool = db.get(Tool, id)
return tool.valves if tool.valves else {} return tool.valves if tool.valves else {}
except Exception as e: except Exception as e:
@ -135,7 +125,6 @@ class ToolsTable:
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
try: try:
with get_db() as db: with get_db() as db:
db.query(Tool).filter_by(id=id).update( db.query(Tool).filter_by(id=id).update(
{"valves": valves, "updated_at": int(time.time())} {"valves": valves, "updated_at": int(time.time())}
) )

View File

@ -1,11 +1,10 @@
from pydantic import BaseModel, ConfigDict
from typing import Optional
import time import time
from typing import Optional
from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import Base, JSONField, get_db from apps.webui.internal.db import Base, JSONField, get_db
from apps.webui.models.chats import Chats from apps.webui.models.chats import Chats
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
#################### ####################
# User DB Schema # User DB Schema
@ -113,7 +112,7 @@ class UsersTable:
with get_db() as db: with get_db() as db:
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception as e: except Exception:
return None return None
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
@ -221,7 +220,7 @@ class UsersTable:
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
# return UserModel(**user.dict()) # return UserModel(**user.dict())
except Exception as e: except Exception:
return None return None
def delete_user_by_id(self, id: str) -> bool: def delete_user_by_id(self, id: str) -> bool:
@ -255,7 +254,7 @@ class UsersTable:
with get_db() as db: with get_db() as db:
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
return user.api_key return user.api_key
except Exception as e: except Exception:
return None return None

View File

@ -1,43 +1,33 @@
import logging
from fastapi import Request, UploadFile, File
from fastapi import Depends, HTTPException, status
from fastapi.responses import Response
from fastapi import APIRouter
from pydantic import BaseModel
import re import re
import uuid import uuid
import csv
from apps.webui.models.auths import ( from apps.webui.models.auths import (
SigninForm,
SignupForm,
AddUserForm, AddUserForm,
UpdateProfileForm,
UpdatePasswordForm,
UserResponse,
SigninResponse,
Auths,
ApiKey, ApiKey,
Auths,
SigninForm,
SigninResponse,
SignupForm,
UpdatePasswordForm,
UpdateProfileForm,
UserResponse,
) )
from apps.webui.models.users import Users from apps.webui.models.users import Users
from config import WEBUI_AUTH
from utils.utils import (
get_password_hash,
get_current_user,
get_admin_user,
create_token,
create_api_key,
)
from utils.misc import parse_duration, validate_email_format
from utils.webhook import post_webhook
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from config import ( from env import WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER
WEBUI_AUTH, from fastapi import APIRouter, Depends, HTTPException, Request, status
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, from fastapi.responses import Response
WEBUI_AUTH_TRUSTED_NAME_HEADER, from pydantic import BaseModel
from utils.misc import parse_duration, validate_email_format
from utils.utils import (
create_api_key,
create_token,
get_admin_user,
get_current_user,
get_password_hash,
) )
from utils.webhook import post_webhook
router = APIRouter() router = APIRouter()
@ -273,7 +263,6 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
@router.post("/add", response_model=SigninResponse) @router.post("/add", response_model=SigninResponse)
async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)): async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
if not validate_email_format(form_data.email.lower()): if not validate_email_format(form_data.email.lower()):
raise HTTPException( raise HTTPException(
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
@ -283,7 +272,6 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
try: try:
print(form_data) print(form_data)
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth( user = Auths.insert_new_auth(

View File

@ -1,34 +1,15 @@
from fastapi import Depends, Request, HTTPException, status
from datetime import datetime, timedelta
from typing import Union, Optional
from utils.utils import get_verified_user, get_admin_user
from fastapi import APIRouter
from pydantic import BaseModel
import json import json
import logging import logging
from typing import Optional
from apps.webui.models.users import Users from apps.webui.models.chats import ChatForm, ChatResponse, Chats, ChatTitleIdResponse
from apps.webui.models.chats import ( from apps.webui.models.tags import ChatIdTagForm, ChatIdTagModel, TagModel, Tags
ChatModel, from config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
ChatResponse,
ChatTitleForm,
ChatForm,
ChatTitleIdResponse,
Chats,
)
from apps.webui.models.tags import (
TagModel,
ChatIdTagModel,
ChatIdTagForm,
ChatTagsResponse,
Tags,
)
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from env import SRC_LOG_LEVELS
from config import SRC_LOG_LEVELS, ENABLE_ADMIN_EXPORT, ENABLE_ADMIN_CHAT_ACCESS from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel
from utils.utils import get_admin_user, get_verified_user
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -61,7 +42,6 @@ async def get_session_user_chat_list(
@router.delete("/", response_model=bool) @router.delete("/", response_model=bool)
async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)): async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
if ( if (
user.role == "user" user.role == "user"
and not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"] and not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]
@ -220,7 +200,6 @@ class TagNameForm(BaseModel):
async def get_user_chat_list_by_tag_name( async def get_user_chat_list_by_tag_name(
form_data: TagNameForm, user=Depends(get_verified_user) form_data: TagNameForm, user=Depends(get_verified_user)
): ):
chat_ids = [ chat_ids = [
chat_id_tag.chat_id chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id( for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
@ -299,7 +278,6 @@ async def update_chat_by_id(
@router.delete("/{id}", response_model=bool) @router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)): async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
if user.role == "admin": if user.role == "admin":
result = Chats.delete_chat_by_id(id) result = Chats.delete_chat_by_id(id)
return result return result
@ -323,7 +301,6 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
async def clone_chat_by_id(id: str, user=Depends(get_verified_user)): async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
chat_body = json.loads(chat.chat) chat_body = json.loads(chat.chat)
updated_chat = { updated_chat = {
**chat_body, **chat_body,

View File

@ -1,25 +1,7 @@
from fastapi import Response, Request
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import Union
from fastapi import APIRouter
from pydantic import BaseModel
import time
import uuid
from config import BannerModel from config import BannerModel
from fastapi import APIRouter, Depends, Request
from apps.webui.models.users import Users from pydantic import BaseModel
from utils.utils import get_admin_user, get_verified_user
from utils.utils import (
get_password_hash,
get_verified_user,
get_admin_user,
create_token,
)
from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()

View File

@ -1,21 +1,16 @@
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json import json
from typing import Optional
from apps.webui.models.documents import ( from apps.webui.models.documents import (
Documents,
DocumentForm, DocumentForm,
DocumentUpdateForm,
DocumentModel,
DocumentResponse, DocumentResponse,
Documents,
DocumentUpdateForm,
) )
from utils.utils import get_verified_user, get_admin_user
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from utils.utils import get_admin_user, get_verified_user
router = APIRouter() router = APIRouter()

View File

@ -1,42 +1,17 @@
from fastapi import ( import logging
Depends,
FastAPI,
HTTPException,
status,
Request,
UploadFile,
File,
Form,
)
from datetime import datetime, timedelta
from typing import Union, Optional
from pathlib import Path
from fastapi import APIRouter
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from pydantic import BaseModel
import json
from apps.webui.models.files import (
Files,
FileForm,
FileModel,
FileModelResponse,
)
from utils.utils import get_verified_user, get_admin_user
from constants import ERROR_MESSAGES
from importlib import util
import os import os
import shutil
import uuid import uuid
import os, shutil, logging, re from pathlib import Path
from typing import Optional
from config import SRC_LOG_LEVELS, UPLOAD_DIR
from apps.webui.models.files import FileForm, FileModel, Files
from config import UPLOAD_DIR
from constants import ERROR_MESSAGES
from env import SRC_LOG_LEVELS
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
from fastapi.responses import FileResponse
from utils.utils import get_admin_user, get_verified_user
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])

View File

@ -1,27 +1,18 @@
from fastapi import Depends, FastAPI, HTTPException, status, Request import os
from datetime import datetime, timedelta from pathlib import Path
from typing import Union, Optional from typing import Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.functions import ( from apps.webui.models.functions import (
Functions,
FunctionForm, FunctionForm,
FunctionModel, FunctionModel,
FunctionResponse, FunctionResponse,
Functions,
) )
from apps.webui.utils import load_function_module_by_id from apps.webui.utils import load_function_module_by_id
from utils.utils import get_verified_user, get_admin_user from config import CACHE_DIR, FUNCTIONS_DIR
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from importlib import util from utils.utils import get_admin_user, get_verified_user
import os
from pathlib import Path
from config import DATA_DIR, CACHE_DIR, FUNCTIONS_DIR
router = APIRouter() router = APIRouter()
@ -304,7 +295,6 @@ async def update_function_valves_by_id(
): ):
function = Functions.get_function_by_id(id) function = Functions.get_function_by_id(id)
if function: if function:
if id in request.app.state.FUNCTIONS: if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id] function_module = request.app.state.FUNCTIONS[id]
else: else:

View File

@ -1,18 +1,12 @@
from fastapi import Response, Request
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import logging import logging
from typing import Optional
from apps.webui.models.memories import Memories, MemoryModel from apps.webui.models.memories import Memories, MemoryModel
from config import CHROMA_CLIENT
from env import SRC_LOG_LEVELS
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel
from utils.utils import get_verified_user from utils.utils import get_verified_user
from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])

View File

@ -1,15 +1,9 @@
from fastapi import Depends, FastAPI, HTTPException, status, Request from typing import Optional
from datetime import datetime, timedelta
from typing import Union, Optional
from fastapi import APIRouter from apps.webui.models.models import ModelForm, ModelModel, ModelResponse, Models
from pydantic import BaseModel
import json
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
from utils.utils import get_verified_user, get_admin_user
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from utils.utils import get_admin_user, get_verified_user
router = APIRouter() router = APIRouter()

View File

@ -1,15 +1,9 @@
from fastapi import Depends, FastAPI, HTTPException, status from typing import Optional
from datetime import datetime, timedelta
from typing import Union, Optional
from fastapi import APIRouter from apps.webui.models.prompts import PromptForm, PromptModel, Prompts
from pydantic import BaseModel
import json
from apps.webui.models.prompts import Prompts, PromptForm, PromptModel
from utils.utils import get_verified_user, get_admin_user
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, status
from utils.utils import get_admin_user, get_verified_user
router = APIRouter() router = APIRouter()

View File

@ -1,20 +1,14 @@
from fastapi import Depends, HTTPException, status, Request
from typing import Optional
from fastapi import APIRouter
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id
from utils.utils import get_admin_user, get_verified_user
from utils.tools import get_tools_specs
from constants import ERROR_MESSAGES
import os import os
from pathlib import Path from pathlib import Path
from typing import Optional
from config import DATA_DIR, CACHE_DIR from apps.webui.models.tools import ToolForm, ToolModel, ToolResponse, Tools
from apps.webui.utils import load_toolkit_module_by_id
from config import CACHE_DIR, DATA_DIR
from constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from utils.tools import get_tools_specs
from utils.utils import get_admin_user, get_verified_user
TOOLS_DIR = f"{DATA_DIR}/tools" TOOLS_DIR = f"{DATA_DIR}/tools"
os.makedirs(TOOLS_DIR, exist_ok=True) os.makedirs(TOOLS_DIR, exist_ok=True)

View File

@ -1,33 +1,20 @@
from fastapi import Response, Request
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import time
import uuid
import logging import logging
from typing import Optional
from apps.webui.models.users import (
UserModel,
UserUpdateForm,
UserRoleUpdateForm,
UserSettings,
Users,
)
from apps.webui.models.auths import Auths from apps.webui.models.auths import Auths
from apps.webui.models.chats import Chats from apps.webui.models.chats import Chats
from apps.webui.models.users import (
from utils.utils import ( UserModel,
get_verified_user, UserRoleUpdateForm,
get_password_hash, Users,
get_current_user, UserSettings,
get_admin_user, UserUpdateForm,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from env import SRC_LOG_LEVELS
from config import SRC_LOG_LEVELS from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel
from utils.utils import get_admin_user, get_password_hash, get_verified_user
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -69,7 +56,6 @@ async def update_user_permissions(
@router.post("/update/role", response_model=Optional[UserModel]) @router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)): async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
if user.id != form_data.id and form_data.id != Users.get_first_user().id: if user.id != form_data.id and form_data.id != Users.get_first_user().id:
return Users.update_user_role_by_id(form_data.id, form_data.role) return Users.update_user_role_by_id(form_data.id, form_data.role)
@ -173,7 +159,6 @@ class UserResponse(BaseModel):
@router.get("/{user_id}", response_model=UserResponse) @router.get("/{user_id}", response_model=UserResponse)
async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
# Check if user_id is a shared chat # Check if user_id is a shared chat
# If it is, get the user_id from the chat # If it is, get the user_id from the chat
if user_id.startswith("shared-"): if user_id.startswith("shared-"):

View File

@ -1,23 +1,16 @@
from pathlib import Path
import site import site
from pathlib import Path
from fastapi import APIRouter, UploadFile, File, Response
from fastapi import Depends, HTTPException, status
from starlette.responses import StreamingResponse, FileResponse
from pydantic import BaseModel
from fpdf import FPDF
import markdown
import black import black
import markdown
from config import DATA_DIR, ENABLE_ADMIN_EXPORT
from utils.utils import get_admin_user
from utils.misc import calculate_sha256, get_gravatar_url
from config import OLLAMA_BASE_URLS, DATA_DIR, UPLOAD_DIR, ENABLE_ADMIN_EXPORT
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Response, status
from fpdf import FPDF
from pydantic import BaseModel
from starlette.responses import FileResponse
from utils.misc import get_gravatar_url
from utils.utils import get_admin_user
router = APIRouter() router = APIRouter()
@ -115,7 +108,7 @@ async def download_chat_as_pdf(
return Response( return Response(
content=bytes(pdf_bytes), content=bytes(pdf_bytes),
media_type="application/pdf", media_type="application/pdf",
headers={"Content-Disposition": f"attachment;filename=chat.pdf"}, headers={"Content-Disposition": "attachment;filename=chat.pdf"},
) )

View File

@ -1,13 +1,12 @@
from importlib import util
import os import os
import re import re
import sys
import subprocess import subprocess
import sys
from importlib import util
from apps.webui.models.tools import Tools
from apps.webui.models.functions import Functions from apps.webui.models.functions import Functions
from config import TOOLS_DIR, FUNCTIONS_DIR from apps.webui.models.tools import Tools
from config import FUNCTIONS_DIR, TOOLS_DIR
def extract_frontmatter(file_path): def extract_frontmatter(file_path):

View File

@ -1,58 +1,30 @@
from sqlalchemy import create_engine, Column, Integer, DateTime, JSON, func import json
from contextlib import contextmanager
import os
import sys
import logging import logging
import importlib.metadata import os
import pkgutil import shutil
from urllib.parse import urlparse
from datetime import datetime from datetime import datetime
from pathlib import Path
from typing import Generic, Optional, TypeVar
from urllib.parse import urlparse
import chromadb import chromadb
from chromadb import Settings
from typing import TypeVar, Generic
from pydantic import BaseModel
from typing import Optional
from pathlib import Path
import json
import yaml
import requests import requests
import shutil import yaml
from apps.webui.internal.db import Base, get_db from apps.webui.internal.db import Base, get_db
from chromadb import Settings
from constants import ERROR_MESSAGES
from env import ( from env import (
ENV,
VERSION,
SAFE_MODE,
GLOBAL_LOG_LEVEL,
SRC_LOG_LEVELS,
BASE_DIR,
DATA_DIR,
BACKEND_DIR, BACKEND_DIR,
FRONTEND_BUILD_DIR,
WEBUI_NAME,
WEBUI_URL,
WEBUI_FAVICON_URL,
WEBUI_BUILD_HASH,
CONFIG_DATA, CONFIG_DATA,
DATABASE_URL, DATA_DIR,
CHANGELOG, ENV,
FRONTEND_BUILD_DIR,
WEBUI_AUTH, WEBUI_AUTH,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_FAVICON_URL,
WEBUI_AUTH_TRUSTED_NAME_HEADER, WEBUI_NAME,
WEBUI_SECRET_KEY,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
log, log,
) )
from pydantic import BaseModel
from sqlalchemy import JSON, Column, DateTime, Integer, func
class EndpointFilter(logging.Filter): class EndpointFilter(logging.Filter):
@ -72,8 +44,8 @@ logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
def run_migrations(): def run_migrations():
print("Running migrations") print("Running migrations")
try: try:
from alembic.config import Config
from alembic import command from alembic import command
from alembic.config import Config
alembic_cfg = Config(BACKEND_DIR / "alembic.ini") alembic_cfg = Config(BACKEND_DIR / "alembic.ini")
command.upgrade(alembic_cfg, "head") command.upgrade(alembic_cfg, "head")

View File

@ -1,19 +1,13 @@
from pathlib import Path
import os
import logging
import sys
import json
import importlib.metadata import importlib.metadata
import json
import logging
import os
import pkgutil import pkgutil
from urllib.parse import urlparse import sys
from datetime import datetime from pathlib import Path
import markdown import markdown
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
#################################### ####################################
@ -26,7 +20,7 @@ BASE_DIR = BACKEND_DIR.parent # the path containing the backend/
print(BASE_DIR) print(BASE_DIR)
try: try:
from dotenv import load_dotenv, find_dotenv from dotenv import find_dotenv, load_dotenv
load_dotenv(find_dotenv(str(BASE_DIR / ".env"))) load_dotenv(find_dotenv(str(BASE_DIR / ".env")))
except ImportError: except ImportError:

View File

@ -1,130 +1,124 @@
import base64 import base64
import inspect
import json
import logging
import mimetypes
import os
import shutil
import sys
import time
import uuid import uuid
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
import json
import time
import os
import sys
import logging
import aiohttp
import requests
import mimetypes
import shutil
import inspect
from typing import Optional from typing import Optional
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form import aiohttp
from fastapi.staticfiles import StaticFiles import requests
from fastapi.responses import JSONResponse from apps.audio.main import app as audio_app
from fastapi import HTTPException from apps.images.main import app as images_app
from apps.ollama.main import app as ollama_app
from apps.ollama.main import (
generate_openai_chat_completion as generate_ollama_chat_completion,
)
from apps.ollama.main import get_all_models as get_ollama_models
from apps.openai.main import app as openai_app
from apps.openai.main import generate_chat_completion as generate_openai_chat_completion
from apps.openai.main import get_all_models as get_openai_models
from apps.rag.main import app as rag_app
from apps.rag.utils import get_rag_context, rag_template
from apps.socket.main import app as socket_app
from apps.socket.main import get_event_call, get_event_emitter
from apps.webui.internal.db import Session
from apps.webui.main import app as webui_app
from apps.webui.main import generate_function_chat_completion, get_pipe_models
from apps.webui.models.auths import Auths
from apps.webui.models.functions import Functions
from apps.webui.models.models import Models
from apps.webui.models.users import UserModel, Users
from apps.webui.utils import load_function_module_by_id
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from config import (
CACHE_DIR,
CORS_ALLOW_ORIGIN,
DEFAULT_LOCALE,
ENABLE_ADMIN_CHAT_ACCESS,
ENABLE_ADMIN_EXPORT,
ENABLE_MODEL_FILTER,
ENABLE_OAUTH_SIGNUP,
ENABLE_OLLAMA_API,
ENABLE_OPENAI_API,
ENV,
FRONTEND_BUILD_DIR,
MODEL_FILTER_LIST,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
OAUTH_PROVIDERS,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
STATIC_DIR,
TASK_MODEL,
TASK_MODEL_EXTERNAL,
TITLE_GENERATION_PROMPT_TEMPLATE,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
WEBHOOK_URL,
WEBUI_AUTH,
WEBUI_NAME,
AppConfig,
run_migrations,
)
from constants import ERROR_MESSAGES, TASKS, WEBHOOK_MESSAGES
from env import (
CHANGELOG,
GLOBAL_LOG_LEVEL,
SAFE_MODE,
SRC_LOG_LEVELS,
VERSION,
WEBUI_BUILD_HASH,
WEBUI_SECRET_KEY,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
WEBUI_URL,
)
from fastapi import (
Depends,
FastAPI,
File,
Form,
HTTPException,
Request,
UploadFile,
status,
)
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from sqlalchemy import text from sqlalchemy import text
from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import StreamingResponse, Response, RedirectResponse from starlette.responses import RedirectResponse, Response, StreamingResponse
from utils.misc import (
add_or_update_system_message,
from apps.socket.main import app as socket_app, get_event_emitter, get_event_call get_last_user_message,
from apps.ollama.main import ( parse_duration,
app as ollama_app, prepend_to_first_user_message_content,
get_all_models as get_ollama_models,
generate_openai_chat_completion as generate_ollama_chat_completion,
) )
from apps.openai.main import ( from utils.task import (
app as openai_app, moa_response_generation_template,
get_all_models as get_openai_models, search_query_generation_template,
generate_chat_completion as generate_openai_chat_completion, title_generation_template,
tools_function_calling_generation_template,
) )
from utils.tools import get_tools
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
from apps.webui.main import (
app as webui_app,
get_pipe_models,
generate_function_chat_completion,
)
from apps.webui.internal.db import Session
from pydantic import BaseModel
from apps.webui.models.auths import Auths
from apps.webui.models.models import Models
from apps.webui.models.functions import Functions
from apps.webui.models.users import Users, UserModel
from apps.webui.utils import load_function_module_by_id
from utils.utils import ( from utils.utils import (
create_token,
decode_token,
get_admin_user, get_admin_user,
get_verified_user,
get_current_user, get_current_user,
get_http_authorization_cred, get_http_authorization_cred,
get_password_hash, get_password_hash,
create_token, get_verified_user,
decode_token,
) )
from utils.task import (
title_generation_template,
search_query_generation_template,
tools_function_calling_generation_template,
moa_response_generation_template,
)
from utils.tools import get_tools
from utils.misc import (
get_last_user_message,
add_or_update_system_message,
prepend_to_first_user_message_content,
parse_duration,
)
from apps.rag.utils import get_rag_context, rag_template
from config import (
run_migrations,
WEBUI_NAME,
WEBUI_URL,
WEBUI_AUTH,
ENV,
VERSION,
CHANGELOG,
FRONTEND_BUILD_DIR,
CACHE_DIR,
STATIC_DIR,
DEFAULT_LOCALE,
ENABLE_OPENAI_API,
ENABLE_OLLAMA_API,
ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST,
GLOBAL_LOG_LEVEL,
SRC_LOG_LEVELS,
WEBHOOK_URL,
ENABLE_ADMIN_EXPORT,
WEBUI_BUILD_HASH,
TASK_MODEL,
TASK_MODEL_EXTERNAL,
TITLE_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
SAFE_MODE,
OAUTH_PROVIDERS,
ENABLE_OAUTH_SIGNUP,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
WEBUI_SECRET_KEY,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
ENABLE_ADMIN_CHAT_ACCESS,
AppConfig,
CORS_ALLOW_ORIGIN,
)
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
from utils.webhook import post_webhook from utils.webhook import post_webhook
if SAFE_MODE: if SAFE_MODE:

View File

@ -1,24 +1,9 @@
import os
from logging.config import fileConfig from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context from alembic import context
from apps.webui.models.auths import Auth from apps.webui.models.auths import Auth
from apps.webui.models.chats import Chat
from apps.webui.models.documents import Document
from apps.webui.models.memories import Memory
from apps.webui.models.models import Model
from apps.webui.models.prompts import Prompt
from apps.webui.models.tags import Tag, ChatIdTag
from apps.webui.models.tools import Tool
from apps.webui.models.users import User
from apps.webui.models.files import File
from apps.webui.models.functions import Function
from env import DATABASE_URL from env import DATABASE_URL
from sqlalchemy import engine_from_config, pool
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
# access to the values within the .ini file in use. # access to the values within the .ini file in use.

View File

@ -1,16 +1,16 @@
"""init """init
Revision ID: 7e5b5dc7342b Revision ID: 7e5b5dc7342b
Revises: Revises:
Create Date: 2024-06-24 13:15:33.808998 Create Date: 2024-06-24 13:15:33.808998
""" """
from typing import Sequence, Union from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import apps.webui.internal.db import apps.webui.internal.db
import sqlalchemy as sa
from alembic import op
from migrations.util import get_existing_tables from migrations.util import get_existing_tables
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.

View File

@ -8,10 +8,8 @@ Create Date: 2024-08-25 15:26:35.241684
from typing import Sequence, Union from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
import apps.webui.internal.db from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = "ca81bd47c050" revision: str = "ca81bd47c050"

View File

@ -1,5 +1,3 @@
import pytest
from test.util.abstract_integration_test import AbstractPostgresTest from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user from test.util.mock_user import mock_webui_user
@ -9,8 +7,8 @@ class TestAuths(AbstractPostgresTest):
def setup_class(cls): def setup_class(cls):
super().setup_class() super().setup_class()
from apps.webui.models.users import Users
from apps.webui.models.auths import Auths from apps.webui.models.auths import Auths
from apps.webui.models.users import Users
cls.users = Users cls.users = Users
cls.auths = Auths cls.auths = Auths

View File

@ -5,7 +5,6 @@ from test.util.mock_user import mock_webui_user
class TestChats(AbstractPostgresTest): class TestChats(AbstractPostgresTest):
BASE_PATH = "/api/v1/chats" BASE_PATH = "/api/v1/chats"
def setup_class(cls): def setup_class(cls):
@ -13,8 +12,7 @@ class TestChats(AbstractPostgresTest):
def setup_method(self): def setup_method(self):
super().setup_method() super().setup_method()
from apps.webui.models.chats import ChatForm from apps.webui.models.chats import ChatForm, Chats
from apps.webui.models.chats import Chats
self.chats = Chats self.chats = Chats
self.chats.insert_new_chat( self.chats.insert_new_chat(

View File

@ -3,7 +3,6 @@ from test.util.mock_user import mock_webui_user
class TestDocuments(AbstractPostgresTest): class TestDocuments(AbstractPostgresTest):
BASE_PATH = "/api/v1/documents" BASE_PATH = "/api/v1/documents"
def setup_class(cls): def setup_class(cls):

View File

@ -3,7 +3,6 @@ from test.util.mock_user import mock_webui_user
class TestModels(AbstractPostgresTest): class TestModels(AbstractPostgresTest):
BASE_PATH = "/api/v1/models" BASE_PATH = "/api/v1/models"
def setup_class(cls): def setup_class(cls):

View File

@ -3,7 +3,6 @@ from test.util.mock_user import mock_webui_user
class TestPrompts(AbstractPostgresTest): class TestPrompts(AbstractPostgresTest):
BASE_PATH = "/api/v1/prompts" BASE_PATH = "/api/v1/prompts"
def test_prompts(self): def test_prompts(self):

View File

@ -21,7 +21,6 @@ def _assert_user(data, id, **kwargs):
class TestUsers(AbstractPostgresTest): class TestUsers(AbstractPostgresTest):
BASE_PATH = "/api/v1/users" BASE_PATH = "/api/v1/users"
def setup_class(cls): def setup_class(cls):

View File

@ -1,10 +1,10 @@
from pathlib import Path
import hashlib import hashlib
import re import re
from datetime import timedelta
from typing import Optional, Callable
import uuid
import time import time
import uuid
from datetime import timedelta
from pathlib import Path
from typing import Callable, Optional
from utils.task import prompt_template from utils.task import prompt_template

View File

@ -1,7 +1,7 @@
from ast import literal_eval from ast import literal_eval
from typing import Any, Literal, Optional, Type
from pydantic import BaseModel, Field, create_model from pydantic import BaseModel, Field, create_model
from typing import Any, Optional, Type, Literal
def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]: def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]:

View File

@ -1,6 +1,5 @@
import re
import math import math
import re
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional

View File

@ -5,7 +5,6 @@ from typing import Awaitable, Callable, get_type_hints
from apps.webui.models.tools import Tools from apps.webui.models.tools import Tools
from apps.webui.models.users import UserModel from apps.webui.models.users import UserModel
from apps.webui.utils import load_toolkit_module_by_id from apps.webui.utils import load_toolkit_module_by_id
from utils.schemas import json_schema_to_model from utils.schemas import json_schema_to_model
log = logging.getLogger(__name__) log = logging.getLogger(__name__)

View File

@ -1,16 +1,15 @@
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends, Request
from apps.webui.models.users import Users
from typing import Union, Optional
from constants import ERROR_MESSAGES
from passlib.context import CryptContext
from datetime import datetime, timedelta, UTC
import jwt
import uuid
import logging import logging
import uuid
from datetime import UTC, datetime, timedelta
from typing import Optional, Union
import jwt
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES
from env import WEBUI_SECRET_KEY from env import WEBUI_SECRET_KEY
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from passlib.context import CryptContext
logging.getLogger("passlib").setLevel(logging.ERROR) logging.getLogger("passlib").setLevel(logging.ERROR)

View File

@ -1,8 +1,9 @@
import json import json
import requests
import logging import logging
from config import SRC_LOG_LEVELS, VERSION, WEBUI_FAVICON_URL, WEBUI_NAME import requests
from config import WEBUI_FAVICON_URL, WEBUI_NAME
from env import SRC_LOG_LEVELS, VERSION
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["WEBHOOK"]) log.setLevel(SRC_LOG_LEVELS["WEBHOOK"])