diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py index 46be15364..5a928a8b6 100644 --- a/backend/apps/audio/main.py +++ b/backend/apps/audio/main.py @@ -7,46 +7,33 @@ from functools import lru_cache from pathlib import Path import requests -from fastapi import ( - FastAPI, - Request, - Depends, - HTTPException, - status, - UploadFile, - File, +from config import ( + AUDIO_STT_ENGINE, + AUDIO_STT_MODEL, + AUDIO_STT_OPENAI_API_BASE_URL, + AUDIO_STT_OPENAI_API_KEY, + AUDIO_TTS_API_KEY, + AUDIO_TTS_ENGINE, + AUDIO_TTS_MODEL, + AUDIO_TTS_OPENAI_API_BASE_URL, + AUDIO_TTS_OPENAI_API_KEY, + AUDIO_TTS_SPLIT_ON, + AUDIO_TTS_VOICE, + CACHE_DIR, + CORS_ALLOW_ORIGIN, + DEVICE_TYPE, + WHISPER_MODEL, + WHISPER_MODEL_AUTO_UPDATE, + WHISPER_MODEL_DIR, + AppConfig, ) +from constants import ERROR_MESSAGES +from env import SRC_LOG_LEVELS +from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse from pydantic import BaseModel - -from config import ( - SRC_LOG_LEVELS, - CACHE_DIR, - WHISPER_MODEL, - WHISPER_MODEL_DIR, - WHISPER_MODEL_AUTO_UPDATE, - DEVICE_TYPE, - AUDIO_STT_OPENAI_API_BASE_URL, - AUDIO_STT_OPENAI_API_KEY, - AUDIO_TTS_OPENAI_API_BASE_URL, - AUDIO_TTS_OPENAI_API_KEY, - AUDIO_TTS_API_KEY, - AUDIO_STT_ENGINE, - AUDIO_STT_MODEL, - AUDIO_TTS_ENGINE, - AUDIO_TTS_MODEL, - AUDIO_TTS_VOICE, - AUDIO_TTS_SPLIT_ON, - AppConfig, - CORS_ALLOW_ORIGIN, -) -from constants import ERROR_MESSAGES -from utils.utils import ( - get_current_user, - get_verified_user, - get_admin_user, -) +from utils.utils import get_admin_user, get_current_user, get_verified_user log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["AUDIO"]) @@ -211,7 +198,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): body = json.loads(body) body["model"] = app.state.config.TTS_MODEL body = json.dumps(body).encode("utf-8") - except Exception as e: + except Exception: pass r = None @@ -488,7 +475,7 @@ def get_available_voices() -> dict: elif app.state.config.TTS_ENGINE == "elevenlabs": try: ret = get_elevenlabs_voices() - except Exception as e: + except Exception: # Avoided @lru_cache with exception pass diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index ed7543106..ec2b963a8 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -1,52 +1,42 @@ -from fastapi import ( - FastAPI, - Request, - Depends, - HTTPException, -) -from fastapi.middleware.cors import CORSMiddleware -from typing import Optional -from pydantic import BaseModel -from pathlib import Path -import mimetypes -import uuid +import asyncio import base64 import json import logging +import mimetypes import re +import uuid +from pathlib import Path +from typing import Optional + import requests -import asyncio - -from utils.utils import ( - get_verified_user, - get_admin_user, -) - from apps.images.utils.comfyui import ( - ComfyUIWorkflow, ComfyUIGenerateImageForm, + ComfyUIWorkflow, comfyui_generate_image, ) - -from constants import ERROR_MESSAGES from config import ( - SRC_LOG_LEVELS, - CACHE_DIR, - IMAGE_GENERATION_ENGINE, - ENABLE_IMAGE_GENERATION, - AUTOMATIC1111_BASE_URL, AUTOMATIC1111_API_AUTH, + AUTOMATIC1111_BASE_URL, + CACHE_DIR, COMFYUI_BASE_URL, COMFYUI_WORKFLOW, COMFYUI_WORKFLOW_NODES, - IMAGES_OPENAI_API_BASE_URL, - IMAGES_OPENAI_API_KEY, + CORS_ALLOW_ORIGIN, + ENABLE_IMAGE_GENERATION, + IMAGE_GENERATION_ENGINE, IMAGE_GENERATION_MODEL, IMAGE_SIZE, IMAGE_STEPS, - CORS_ALLOW_ORIGIN, + IMAGES_OPENAI_API_BASE_URL, + IMAGES_OPENAI_API_KEY, AppConfig, ) +from constants import ERROR_MESSAGES +from env import SRC_LOG_LEVELS +from fastapi import Depends, FastAPI, HTTPException, Request +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel +from utils.utils import get_admin_user, get_verified_user log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["IMAGES"]) @@ -186,7 +176,7 @@ async def verify_url(user=Depends(get_admin_user)): ) r.raise_for_status() return True - except Exception as e: + except Exception: app.state.config.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) elif app.state.config.ENGINE == "comfyui": @@ -194,7 +184,7 @@ async def verify_url(user=Depends(get_admin_user)): r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info") r.raise_for_status() return True - except Exception as e: + except Exception: app.state.config.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) else: @@ -397,7 +387,6 @@ def save_url_image(url): r = requests.get(url) r.raise_for_status() if r.headers["content-type"].split("/")[0] == "image": - mime_type = r.headers["content-type"] image_format = mimetypes.guess_extension(mime_type) @@ -412,7 +401,7 @@ def save_url_image(url): image_file.write(chunk) return image_filename else: - log.error(f"Url does not point to an image.") + log.error("Url does not point to an image.") return None except Exception as e: @@ -430,7 +419,6 @@ async def image_generations( r = None try: if app.state.config.ENGINE == "openai": - headers = {} headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}" headers["Content-Type"] = "application/json" diff --git a/backend/apps/images/utils/comfyui.py b/backend/apps/images/utils/comfyui.py index fa4736131..00121d99d 100644 --- a/backend/apps/images/utils/comfyui.py +++ b/backend/apps/images/utils/comfyui.py @@ -1,20 +1,18 @@ import asyncio -import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client) import json -import urllib.request -import urllib.parse -import random import logging +import random +import urllib.parse +import urllib.request +from typing import Optional -from config import SRC_LOG_LEVELS +import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client) +from env import SRC_LOG_LEVELS +from pydantic import BaseModel log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["COMFYUI"]) -from pydantic import BaseModel - -from typing import Optional - default_headers = {"User-Agent": "Mozilla/5.0"} diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index db677e84c..a887d2981 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -1,54 +1,40 @@ -from fastapi import ( - FastAPI, - Request, - HTTPException, - Depends, - UploadFile, - File, -) -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse - -from pydantic import BaseModel, ConfigDict - -import os -import re -import random -import requests -import json -import aiohttp import asyncio +import json import logging +import os +import random +import re import time -from urllib.parse import urlparse from typing import Optional, Union +from urllib.parse import urlparse -from starlette.background import BackgroundTask - +import aiohttp +import requests from apps.webui.models.models import Models -from constants import ERROR_MESSAGES -from utils.utils import ( - get_verified_user, - get_admin_user, -) - from config import ( - SRC_LOG_LEVELS, - OLLAMA_BASE_URLS, - ENABLE_OLLAMA_API, AIOHTTP_CLIENT_TIMEOUT, + CORS_ALLOW_ORIGIN, ENABLE_MODEL_FILTER, + ENABLE_OLLAMA_API, MODEL_FILTER_LIST, + OLLAMA_BASE_URLS, UPLOAD_DIR, AppConfig, - CORS_ALLOW_ORIGIN, ) +from constants import ERROR_MESSAGES +from env import SRC_LOG_LEVELS +from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse +from pydantic import BaseModel, ConfigDict +from starlette.background import BackgroundTask from utils.misc import ( - calculate_sha256, apply_model_params_to_body_ollama, apply_model_params_to_body_openai, apply_model_system_prompt_to_body, + calculate_sha256, ) +from utils.utils import get_admin_user, get_verified_user log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 9ad67c40c..9ac4ee0ac 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -1,44 +1,36 @@ -from fastapi import FastAPI, Request, HTTPException, Depends -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse, FileResponse - -import requests -import aiohttp import asyncio +import hashlib import json import logging +from pathlib import Path +from typing import Literal, Optional, overload +import aiohttp +import requests +from apps.webui.models.models import Models +from config import ( + AIOHTTP_CLIENT_TIMEOUT, + CACHE_DIR, + CORS_ALLOW_ORIGIN, + ENABLE_MODEL_FILTER, + ENABLE_OPENAI_API, + MODEL_FILTER_LIST, + OPENAI_API_BASE_URLS, + OPENAI_API_KEYS, + AppConfig, +) +from constants import ERROR_MESSAGES +from env import SRC_LOG_LEVELS +from fastapi import Depends, FastAPI, HTTPException, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse, StreamingResponse from pydantic import BaseModel from starlette.background import BackgroundTask - -from apps.webui.models.models import Models -from constants import ERROR_MESSAGES -from utils.utils import ( - get_verified_user, - get_admin_user, -) from utils.misc import ( apply_model_params_to_body_openai, apply_model_system_prompt_to_body, ) - -from config import ( - SRC_LOG_LEVELS, - ENABLE_OPENAI_API, - AIOHTTP_CLIENT_TIMEOUT, - OPENAI_API_BASE_URLS, - OPENAI_API_KEYS, - CACHE_DIR, - ENABLE_MODEL_FILTER, - MODEL_FILTER_LIST, - AppConfig, - CORS_ALLOW_ORIGIN, -) -from typing import Optional, Literal, overload - - -import hashlib -from pathlib import Path +from utils.utils import get_admin_user, get_verified_user log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OPENAI"]) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index bb6149e05..e1576574f 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -1,143 +1,118 @@ -from fastapi import ( - FastAPI, - Depends, - HTTPException, - status, - UploadFile, - File, - Form, -) -from fastapi.middleware.cors import CORSMiddleware -import requests -import os, shutil, logging, re -from datetime import datetime - -from pathlib import Path -from typing import Union, Sequence, Iterator, Any - -from chromadb.utils.batch_utils import create_batches -from langchain_core.documents import Document - -from langchain_community.document_loaders import ( - WebBaseLoader, - TextLoader, - PyPDFLoader, - CSVLoader, - BSHTMLLoader, - Docx2txtLoader, - UnstructuredEPubLoader, - UnstructuredWordDocumentLoader, - UnstructuredMarkdownLoader, - UnstructuredXMLLoader, - UnstructuredRSTLoader, - UnstructuredExcelLoader, - UnstructuredPowerPointLoader, - YoutubeLoader, - OutlookMessageLoader, -) -from langchain.text_splitter import RecursiveCharacterTextSplitter - -import validators -import urllib.parse -import socket - - -from pydantic import BaseModel -from typing import Optional -import mimetypes -import uuid import json +import logging +import mimetypes +import os +import shutil +import socket +import urllib.parse +import uuid +from datetime import datetime +from pathlib import Path +from typing import Iterator, Optional, Sequence, Union -from apps.webui.models.documents import ( - Documents, - DocumentForm, - DocumentResponse, -) -from apps.webui.models.files import ( - Files, -) - -from apps.rag.utils import ( - get_model_path, - get_embedding_function, - query_doc, - query_doc_with_hybrid_search, - query_collection, - query_collection_with_hybrid_search, -) - +import requests +import validators from apps.rag.search.brave import search_brave +from apps.rag.search.duckduckgo import search_duckduckgo from apps.rag.search.google_pse import search_google_pse +from apps.rag.search.jina_search import search_jina from apps.rag.search.main import SearchResult +from apps.rag.search.searchapi import search_searchapi from apps.rag.search.searxng import search_searxng from apps.rag.search.serper import search_serper -from apps.rag.search.serpstack import search_serpstack from apps.rag.search.serply import search_serply -from apps.rag.search.duckduckgo import search_duckduckgo +from apps.rag.search.serpstack import search_serpstack from apps.rag.search.tavily import search_tavily -from apps.rag.search.jina_search import search_jina -from apps.rag.search.searchapi import search_searchapi - -from utils.misc import ( - calculate_sha256, - calculate_sha256_string, - sanitize_filename, - extract_folders_after_data_docs, +from apps.rag.utils import ( + get_embedding_function, + get_model_path, + query_collection, + query_collection_with_hybrid_search, + query_doc, + query_doc_with_hybrid_search, ) -from utils.utils import get_verified_user, get_admin_user - +from apps.webui.models.documents import DocumentForm, Documents +from apps.webui.models.files import Files +from chromadb.utils.batch_utils import create_batches from config import ( - AppConfig, - ENV, - SRC_LOG_LEVELS, - UPLOAD_DIR, - DOCS_DIR, + BRAVE_SEARCH_API_KEY, + CHROMA_CLIENT, + CHUNK_OVERLAP, + CHUNK_SIZE, CONTENT_EXTRACTION_ENGINE, - TIKA_SERVER_URL, - RAG_TOP_K, - RAG_RELEVANCE_THRESHOLD, - RAG_FILE_MAX_SIZE, - RAG_FILE_MAX_COUNT, + CORS_ALLOW_ORIGIN, + DEVICE_TYPE, + DOCS_DIR, + ENABLE_RAG_HYBRID_SEARCH, + ENABLE_RAG_LOCAL_WEB_FETCH, + ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + ENABLE_RAG_WEB_SEARCH, + ENV, + GOOGLE_PSE_API_KEY, + GOOGLE_PSE_ENGINE_ID, + PDF_EXTRACT_IMAGES, RAG_EMBEDDING_ENGINE, RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, - ENABLE_RAG_HYBRID_SEARCH, - ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, - RAG_RERANKING_MODEL, - PDF_EXTRACT_IMAGES, - RAG_RERANKING_MODEL_AUTO_UPDATE, - RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, + RAG_EMBEDDING_OPENAI_BATCH_SIZE, + RAG_FILE_MAX_COUNT, + RAG_FILE_MAX_SIZE, RAG_OPENAI_API_BASE_URL, RAG_OPENAI_API_KEY, - DEVICE_TYPE, - CHROMA_CLIENT, - CHUNK_SIZE, - CHUNK_OVERLAP, + RAG_RELEVANCE_THRESHOLD, + RAG_RERANKING_MODEL, + RAG_RERANKING_MODEL_AUTO_UPDATE, + RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, RAG_TEMPLATE, - ENABLE_RAG_LOCAL_WEB_FETCH, - YOUTUBE_LOADER_LANGUAGE, - ENABLE_RAG_WEB_SEARCH, - RAG_WEB_SEARCH_ENGINE, + RAG_TOP_K, + RAG_WEB_SEARCH_CONCURRENT_REQUESTS, RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, - SEARXNG_QUERY_URL, - GOOGLE_PSE_API_KEY, - GOOGLE_PSE_ENGINE_ID, - BRAVE_SEARCH_API_KEY, - SERPSTACK_API_KEY, - SERPSTACK_HTTPS, - SERPER_API_KEY, - SERPLY_API_KEY, - TAVILY_API_KEY, + RAG_WEB_SEARCH_ENGINE, + RAG_WEB_SEARCH_RESULT_COUNT, SEARCHAPI_API_KEY, SEARCHAPI_ENGINE, - RAG_WEB_SEARCH_RESULT_COUNT, - RAG_WEB_SEARCH_CONCURRENT_REQUESTS, - RAG_EMBEDDING_OPENAI_BATCH_SIZE, - CORS_ALLOW_ORIGIN, + SEARXNG_QUERY_URL, + SERPER_API_KEY, + SERPLY_API_KEY, + SERPSTACK_API_KEY, + SERPSTACK_HTTPS, + TAVILY_API_KEY, + TIKA_SERVER_URL, + UPLOAD_DIR, + YOUTUBE_LOADER_LANGUAGE, + AppConfig, ) - from constants import ERROR_MESSAGES +from env import SRC_LOG_LEVELS +from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status +from fastapi.middleware.cors import CORSMiddleware +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_community.document_loaders import ( + BSHTMLLoader, + CSVLoader, + Docx2txtLoader, + OutlookMessageLoader, + PyPDFLoader, + TextLoader, + UnstructuredEPubLoader, + UnstructuredExcelLoader, + UnstructuredMarkdownLoader, + UnstructuredPowerPointLoader, + UnstructuredRSTLoader, + UnstructuredXMLLoader, + WebBaseLoader, + YoutubeLoader, +) +from langchain_core.documents import Document +from pydantic import BaseModel +from utils.misc import ( + calculate_sha256, + calculate_sha256_string, + extract_folders_after_data_docs, + sanitize_filename, +) +from utils.utils import get_admin_user, get_verified_user log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -539,9 +514,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key - app.state.config.SEARCHAPI_ENGINE = ( - form_data.web.search.searchapi_engine - ) + app.state.config.SEARCHAPI_ENGINE = form_data.web.search.searchapi_engine app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = ( form_data.web.search.concurrent_requests @@ -981,7 +954,6 @@ def store_web_search(form_data: SearchForm, user=Depends(get_verified_user)): def store_data_in_vector_db( data, collection_name, metadata: Optional[dict] = None, overwrite: bool = False ) -> bool: - text_splitter = RecursiveCharacterTextSplitter( chunk_size=app.state.config.CHUNK_SIZE, chunk_overlap=app.state.config.CHUNK_OVERLAP, @@ -1342,7 +1314,6 @@ def store_text( form_data: TextRAGForm, user=Depends(get_verified_user), ): - collection_name = form_data.collection_name if collection_name is None: collection_name = calculate_sha256_string(form_data.content) diff --git a/backend/apps/rag/search/brave.py b/backend/apps/rag/search/brave.py index 681caa976..0da55506b 100644 --- a/backend/apps/rag/search/brave.py +++ b/backend/apps/rag/search/brave.py @@ -1,9 +1,9 @@ import logging from typing import Optional -import requests +import requests from apps.rag.search.main import SearchResult, get_filtered_results -from config import SRC_LOG_LEVELS +from env import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) diff --git a/backend/apps/rag/search/duckduckgo.py b/backend/apps/rag/search/duckduckgo.py index e994ef47a..ed47c5689 100644 --- a/backend/apps/rag/search/duckduckgo.py +++ b/backend/apps/rag/search/duckduckgo.py @@ -1,8 +1,9 @@ import logging from typing import Optional + from apps.rag.search.main import SearchResult, get_filtered_results from duckduckgo_search import DDGS -from config import SRC_LOG_LEVELS +from env import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) diff --git a/backend/apps/rag/search/google_pse.py b/backend/apps/rag/search/google_pse.py index 7fedb3dad..f2cbc68fd 100644 --- a/backend/apps/rag/search/google_pse.py +++ b/backend/apps/rag/search/google_pse.py @@ -1,10 +1,9 @@ -import json import logging from typing import Optional -import requests +import requests from apps.rag.search.main import SearchResult, get_filtered_results -from config import SRC_LOG_LEVELS +from env import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) diff --git a/backend/apps/rag/search/jina_search.py b/backend/apps/rag/search/jina_search.py index 8d1c582a1..a7d56581a 100644 --- a/backend/apps/rag/search/jina_search.py +++ b/backend/apps/rag/search/jina_search.py @@ -1,9 +1,9 @@ import logging -import requests -from yarl import URL +import requests from apps.rag.search.main import SearchResult -from config import SRC_LOG_LEVELS +from env import SRC_LOG_LEVELS +from yarl import URL log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) diff --git a/backend/apps/rag/search/main.py b/backend/apps/rag/search/main.py index f06b30505..1af8a70aa 100644 --- a/backend/apps/rag/search/main.py +++ b/backend/apps/rag/search/main.py @@ -1,5 +1,6 @@ from typing import Optional from urllib.parse import urlparse + from pydantic import BaseModel diff --git a/backend/apps/rag/search/searxng.py b/backend/apps/rag/search/searxng.py index 94bed2857..ca26e9c00 100644 --- a/backend/apps/rag/search/searxng.py +++ b/backend/apps/rag/search/searxng.py @@ -1,10 +1,9 @@ import logging -import requests - from typing import Optional +import requests from apps.rag.search.main import SearchResult, get_filtered_results -from config import SRC_LOG_LEVELS +from env import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) diff --git a/backend/apps/rag/search/serper.py b/backend/apps/rag/search/serper.py index e71fbb628..702ab5500 100644 --- a/backend/apps/rag/search/serper.py +++ b/backend/apps/rag/search/serper.py @@ -1,10 +1,10 @@ import json import logging from typing import Optional -import requests +import requests from apps.rag.search.main import SearchResult, get_filtered_results -from config import SRC_LOG_LEVELS +from env import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) diff --git a/backend/apps/rag/search/serply.py b/backend/apps/rag/search/serply.py index 28c15fd78..a65f62d61 100644 --- a/backend/apps/rag/search/serply.py +++ b/backend/apps/rag/search/serply.py @@ -1,11 +1,10 @@ -import json import logging from typing import Optional -import requests from urllib.parse import urlencode +import requests from apps.rag.search.main import SearchResult, get_filtered_results -from config import SRC_LOG_LEVELS +from env import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) diff --git a/backend/apps/rag/search/serpstack.py b/backend/apps/rag/search/serpstack.py index 5c19bd134..643be5aaf 100644 --- a/backend/apps/rag/search/serpstack.py +++ b/backend/apps/rag/search/serpstack.py @@ -1,10 +1,9 @@ -import json import logging from typing import Optional -import requests +import requests from apps.rag.search.main import SearchResult, get_filtered_results -from config import SRC_LOG_LEVELS +from env import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) diff --git a/backend/apps/rag/search/tavily.py b/backend/apps/rag/search/tavily.py index ed4ab6e08..2c45f9801 100644 --- a/backend/apps/rag/search/tavily.py +++ b/backend/apps/rag/search/tavily.py @@ -1,9 +1,8 @@ import logging import requests - from apps.rag.search.main import SearchResult -from config import SRC_LOG_LEVELS +from env import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index 82bead012..14e558ae5 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -1,27 +1,16 @@ -import os import logging +import os +from typing import Optional, Union + import requests - -from typing import Union - -from apps.ollama.main import ( - generate_ollama_embeddings, - GenerateEmbeddingsForm, -) - +from apps.ollama.main import GenerateEmbeddingsForm, generate_ollama_embeddings +from config import CHROMA_CLIENT +from env import SRC_LOG_LEVELS from huggingface_hub import snapshot_download - -from langchain_core.documents import Document +from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever from langchain_community.retrievers import BM25Retriever -from langchain.retrievers import ( - ContextualCompressionRetriever, - EnsembleRetriever, -) - -from typing import Optional - -from utils.misc import get_last_user_message, add_or_update_system_message -from config import SRC_LOG_LEVELS, CHROMA_CLIENT +from langchain_core.documents import Document +from utils.misc import get_last_user_message log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -261,7 +250,9 @@ def get_rag_context( collection_names = ( file["collection_names"] if file["type"] == "collection" - else [file["collection_name"]] if file["collection_name"] else [] + else [file["collection_name"]] + if file["collection_name"] + else [] ) collection_names = set(collection_names).difference(extracted_collections) @@ -401,8 +392,8 @@ def generate_openai_batch_embeddings( from typing import Any -from langchain_core.retrievers import BaseRetriever from langchain_core.callbacks import CallbackManagerForRetrieverRun +from langchain_core.retrievers import BaseRetriever class ChromaRetriever(BaseRetriever): @@ -439,11 +430,10 @@ class ChromaRetriever(BaseRetriever): import operator - from typing import Optional, Sequence -from langchain_core.documents import BaseDocumentCompressor, Document from langchain_core.callbacks import Callbacks +from langchain_core.documents import BaseDocumentCompressor, Document from langchain_core.pydantic_v1 import Extra diff --git a/backend/apps/socket/main.py b/backend/apps/socket/main.py index fcffca420..6b6282e61 100644 --- a/backend/apps/socket/main.py +++ b/backend/apps/socket/main.py @@ -1,7 +1,6 @@ -import socketio import asyncio - +import socketio from apps.webui.models.users import Users from utils.utils import decode_token diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py index db8df5ee5..2ab2e1e0a 100644 --- a/backend/apps/webui/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -1,21 +1,16 @@ -import os -import logging import json +import logging from contextlib import contextmanager +from typing import Any, Optional - -from typing import Optional, Any -from typing_extensions import Self - -from sqlalchemy import create_engine, types, Dialect -from sqlalchemy.sql.type_api import _T -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker, scoped_session - - -from peewee_migrate import Router from apps.webui.internal.wrappers import register_connection -from env import SRC_LOG_LEVELS, BACKEND_DIR, DATABASE_URL +from env import BACKEND_DIR, DATABASE_URL, SRC_LOG_LEVELS +from peewee_migrate import Router +from sqlalchemy import Dialect, create_engine, types +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import scoped_session, sessionmaker +from sqlalchemy.sql.type_api import _T +from typing_extensions import Self log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["DB"]) diff --git a/backend/apps/webui/internal/wrappers.py b/backend/apps/webui/internal/wrappers.py index 19523064a..0a36cdce3 100644 --- a/backend/apps/webui/internal/wrappers.py +++ b/backend/apps/webui/internal/wrappers.py @@ -1,12 +1,12 @@ -from contextvars import ContextVar -from peewee import * -from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError - import logging -from playhouse.db_url import connect, parse -from playhouse.shortcuts import ReconnectMixin +from contextvars import ContextVar from env import SRC_LOG_LEVELS +from peewee import * +from peewee import InterfaceError as PeeWeeInterfaceError +from peewee import PostgresqlDatabase +from playhouse.db_url import connect, parse +from playhouse.shortcuts import ReconnectMixin log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["DB"]) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 00963def6..10d265ad1 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -1,65 +1,59 @@ -from fastapi import FastAPI -from fastapi.responses import StreamingResponse -from fastapi.middleware.cors import CORSMiddleware -from apps.webui.routers import ( - auths, - users, - chats, - documents, - tools, - models, - prompts, - configs, - memories, - utils, - files, - functions, -) -from apps.webui.models.functions import Functions -from apps.webui.models.models import Models -from apps.webui.utils import load_function_module_by_id - -from utils.misc import ( - openai_chat_chunk_message_template, - openai_chat_completion_message_template, - apply_model_params_to_body_openai, - apply_model_system_prompt_to_body, -) - -from utils.tools import get_tools - -from config import ( - SHOW_ADMIN_DETAILS, - ADMIN_EMAIL, - WEBUI_AUTH, - DEFAULT_MODELS, - DEFAULT_PROMPT_SUGGESTIONS, - DEFAULT_USER_ROLE, - ENABLE_SIGNUP, - ENABLE_LOGIN_FORM, - USER_PERMISSIONS, - WEBHOOK_URL, - WEBUI_AUTH_TRUSTED_EMAIL_HEADER, - WEBUI_AUTH_TRUSTED_NAME_HEADER, - JWT_EXPIRES_IN, - WEBUI_BANNERS, - ENABLE_COMMUNITY_SHARING, - ENABLE_MESSAGE_RATING, - AppConfig, - OAUTH_USERNAME_CLAIM, - OAUTH_PICTURE_CLAIM, - OAUTH_EMAIL_CLAIM, - CORS_ALLOW_ORIGIN, -) - -from apps.socket.main import get_event_call, get_event_emitter - import inspect import json import logging +from typing import AsyncGenerator, Generator, Iterator -from typing import Iterator, Generator, AsyncGenerator +from apps.socket.main import get_event_call, get_event_emitter +from apps.webui.models.functions import Functions +from apps.webui.models.models import Models +from apps.webui.routers import ( + auths, + chats, + configs, + documents, + files, + functions, + memories, + models, + prompts, + tools, + users, + utils, +) +from apps.webui.utils import load_function_module_by_id +from config import ( + ADMIN_EMAIL, + CORS_ALLOW_ORIGIN, + DEFAULT_MODELS, + DEFAULT_PROMPT_SUGGESTIONS, + DEFAULT_USER_ROLE, + ENABLE_COMMUNITY_SHARING, + ENABLE_LOGIN_FORM, + ENABLE_MESSAGE_RATING, + ENABLE_SIGNUP, + JWT_EXPIRES_IN, + OAUTH_EMAIL_CLAIM, + OAUTH_PICTURE_CLAIM, + OAUTH_USERNAME_CLAIM, + SHOW_ADMIN_DETAILS, + USER_PERMISSIONS, + WEBHOOK_URL, + WEBUI_AUTH, + WEBUI_BANNERS, + AppConfig, +) +from env import WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse from pydantic import BaseModel +from utils.misc import ( + apply_model_params_to_body_openai, + apply_model_system_prompt_to_body, + openai_chat_chunk_message_template, + openai_chat_completion_message_template, +) +from utils.tools import get_tools app = FastAPI() diff --git a/backend/apps/webui/models/auths.py b/backend/apps/webui/models/auths.py index 601c7c9a4..8c57b9546 100644 --- a/backend/apps/webui/models/auths.py +++ b/backend/apps/webui/models/auths.py @@ -1,15 +1,13 @@ -from pydantic import BaseModel -from typing import Optional -import uuid import logging -from sqlalchemy import String, Column, Boolean, Text +import uuid +from typing import Optional -from utils.utils import verify_password - -from apps.webui.models.users import UserModel, Users from apps.webui.internal.db import Base, get_db - +from apps.webui.models.users import UserModel, Users from env import SRC_LOG_LEVELS +from pydantic import BaseModel +from sqlalchemy import Boolean, Column, String, Text +from utils.utils import verify_password log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -92,7 +90,6 @@ class AddUserForm(SignupForm): class AuthsTable: - def insert_new_auth( self, email: str, @@ -103,7 +100,6 @@ class AuthsTable: oauth_sub: Optional[str] = None, ) -> Optional[UserModel]: with get_db() as db: - log.info("insert_new_auth") id = str(uuid.uuid4()) @@ -130,7 +126,6 @@ class AuthsTable: log.info(f"authenticate_user: {email}") try: with get_db() as db: - auth = db.query(Auth).filter_by(email=email, active=True).first() if auth: if verify_password(password, auth.password): @@ -189,7 +184,6 @@ class AuthsTable: def delete_auth_by_id(self, id: str) -> bool: try: with get_db() as db: - # Delete User result = Users.delete_user_by_id(id) diff --git a/backend/apps/webui/models/chats.py b/backend/apps/webui/models/chats.py index 164be0646..5e08c92ce 100644 --- a/backend/apps/webui/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -1,14 +1,11 @@ -from pydantic import BaseModel, ConfigDict -from typing import Union, Optional - import json -import uuid import time - -from sqlalchemy import Column, String, BigInteger, Boolean, Text +import uuid +from typing import Optional from apps.webui.internal.db import Base, get_db - +from pydantic import BaseModel, ConfigDict +from sqlalchemy import BigInteger, Boolean, Column, String, Text #################### # Chat DB Schema @@ -77,10 +74,8 @@ class ChatTitleIdResponse(BaseModel): class ChatTable: - def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: with get_db() as db: - id = str(uuid.uuid4()) chat = ChatModel( **{ @@ -106,7 +101,6 @@ class ChatTable: def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: try: with get_db() as db: - chat_obj = db.get(Chat, id) chat_obj.chat = json.dumps(chat) chat_obj.title = chat["title"] if "title" in chat else "New Chat" @@ -115,12 +109,11 @@ class ChatTable: db.refresh(chat_obj) return ChatModel.model_validate(chat_obj) - except Exception as e: + except Exception: return None def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: with get_db() as db: - # Get the existing chat to share chat = db.get(Chat, chat_id) # Check if the chat is already shared @@ -154,7 +147,6 @@ class ChatTable: def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: try: with get_db() as db: - print("update_shared_chat_by_id") chat = db.get(Chat, chat_id) print(chat) @@ -170,7 +162,6 @@ class ChatTable: def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: try: with get_db() as db: - db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() db.commit() @@ -183,7 +174,6 @@ class ChatTable: ) -> Optional[ChatModel]: try: with get_db() as db: - chat = db.get(Chat, id) chat.share_id = share_id db.commit() @@ -195,7 +185,6 @@ class ChatTable: def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: try: with get_db() as db: - chat = db.get(Chat, id) chat.archived = not chat.archived db.commit() @@ -217,7 +206,6 @@ class ChatTable: self, user_id: str, skip: int = 0, limit: int = 50 ) -> list[ChatModel]: with get_db() as db: - all_chats = ( db.query(Chat) .filter_by(user_id=user_id, archived=True) @@ -297,7 +285,6 @@ class ChatTable: def get_chat_by_id(self, id: str) -> Optional[ChatModel]: try: with get_db() as db: - chat = db.get(Chat, id) return ChatModel.model_validate(chat) except Exception: @@ -306,20 +293,18 @@ class ChatTable: def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: try: with get_db() as db: - chat = db.query(Chat).filter_by(share_id=id).first() if chat: return self.get_chat_by_id(id) else: return None - except Exception as e: + except Exception: return None def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: try: with get_db() as db: - chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() return ChatModel.model_validate(chat) except Exception: @@ -327,7 +312,6 @@ class ChatTable: def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]: with get_db() as db: - all_chats = ( db.query(Chat) # .limit(limit).offset(skip) @@ -337,7 +321,6 @@ class ChatTable: def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]: with get_db() as db: - all_chats = ( db.query(Chat) .filter_by(user_id=user_id) @@ -347,7 +330,6 @@ class ChatTable: def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]: with get_db() as db: - all_chats = ( db.query(Chat) .filter_by(user_id=user_id, archived=True) @@ -358,7 +340,6 @@ class ChatTable: def delete_chat_by_id(self, id: str) -> bool: try: with get_db() as db: - db.query(Chat).filter_by(id=id).delete() db.commit() @@ -369,7 +350,6 @@ class ChatTable: def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: try: with get_db() as db: - db.query(Chat).filter_by(id=id, user_id=user_id).delete() db.commit() @@ -379,9 +359,7 @@ class ChatTable: def delete_chats_by_user_id(self, user_id: str) -> bool: try: - with get_db() as db: - self.delete_shared_chats_by_user_id(user_id) db.query(Chat).filter_by(user_id=user_id).delete() @@ -393,9 +371,7 @@ class ChatTable: def delete_shared_chats_by_user_id(self, user_id: str) -> bool: try: - with get_db() as db: - chats_by_user = db.query(Chat).filter_by(user_id=user_id).all() shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] diff --git a/backend/apps/webui/models/documents.py b/backend/apps/webui/models/documents.py index 15dd63663..0738716a0 100644 --- a/backend/apps/webui/models/documents.py +++ b/backend/apps/webui/models/documents.py @@ -1,15 +1,12 @@ -from pydantic import BaseModel, ConfigDict -from typing import Optional -import time +import json import logging - -from sqlalchemy import String, Column, BigInteger, Text +import time +from typing import Optional from apps.webui.internal.db import Base, get_db - -import json - from env import SRC_LOG_LEVELS +from pydantic import BaseModel, ConfigDict +from sqlalchemy import BigInteger, Column, String, Text log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -70,12 +67,10 @@ class DocumentForm(DocumentUpdateForm): class DocumentsTable: - def insert_new_doc( self, user_id: str, form_data: DocumentForm ) -> Optional[DocumentModel]: with get_db() as db: - document = DocumentModel( **{ **form_data.model_dump(), @@ -99,7 +94,6 @@ class DocumentsTable: def get_doc_by_name(self, name: str) -> Optional[DocumentModel]: try: with get_db() as db: - document = db.query(Document).filter_by(name=name).first() return DocumentModel.model_validate(document) if document else None except Exception: @@ -107,7 +101,6 @@ class DocumentsTable: def get_docs(self) -> list[DocumentModel]: with get_db() as db: - return [ DocumentModel.model_validate(doc) for doc in db.query(Document).all() ] @@ -117,7 +110,6 @@ class DocumentsTable: ) -> Optional[DocumentModel]: try: with get_db() as db: - db.query(Document).filter_by(name=name).update( { "title": form_data.title, @@ -140,7 +132,6 @@ class DocumentsTable: doc_content = {**doc_content, **updated} with get_db() as db: - db.query(Document).filter_by(name=name).update( { "content": json.dumps(doc_content), @@ -156,7 +147,6 @@ class DocumentsTable: def delete_doc_by_name(self, name: str) -> bool: try: with get_db() as db: - db.query(Document).filter_by(name=name).delete() db.commit() return True diff --git a/backend/apps/webui/models/files.py b/backend/apps/webui/models/files.py index 1b7175124..794b7070c 100644 --- a/backend/apps/webui/models/files.py +++ b/backend/apps/webui/models/files.py @@ -1,15 +1,11 @@ -from pydantic import BaseModel, ConfigDict -from typing import Union, Optional -import time import logging +import time +from typing import Optional -from sqlalchemy import Column, String, BigInteger, Text - -from apps.webui.internal.db import JSONField, Base, get_db - -import json - +from apps.webui.internal.db import Base, JSONField, get_db from env import SRC_LOG_LEVELS +from pydantic import BaseModel, ConfigDict +from sqlalchemy import BigInteger, Column, String, Text log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -59,10 +55,8 @@ class FileForm(BaseModel): class FilesTable: - def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]: with get_db() as db: - file = FileModel( **{ **form_data.model_dump(), @@ -86,7 +80,6 @@ class FilesTable: def get_file_by_id(self, id: str) -> Optional[FileModel]: with get_db() as db: - try: file = db.get(File, id) return FileModel.model_validate(file) @@ -95,7 +88,6 @@ class FilesTable: def get_files(self) -> list[FileModel]: with get_db() as db: - return [FileModel.model_validate(file) for file in db.query(File).all()] def get_files_by_user_id(self, user_id: str) -> list[FileModel]: @@ -106,9 +98,7 @@ class FilesTable: ] def delete_file_by_id(self, id: str) -> bool: - with get_db() as db: - try: db.query(File).filter_by(id=id).delete() db.commit() @@ -118,9 +108,7 @@ class FilesTable: return False def delete_all_files(self) -> bool: - with get_db() as db: - try: db.query(File).delete() db.commit() diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 10d811148..bb85c83e5 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -1,18 +1,12 @@ -from pydantic import BaseModel, ConfigDict -from typing import Union, Optional -import time import logging +import time +from typing import Optional -from sqlalchemy import Column, String, Text, BigInteger, Boolean - -from apps.webui.internal.db import JSONField, Base, get_db +from apps.webui.internal.db import Base, JSONField, get_db from apps.webui.models.users import Users - -import json -import copy - - from env import SRC_LOG_LEVELS +from pydantic import BaseModel, ConfigDict +from sqlalchemy import BigInteger, Boolean, Column, String, Text log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -87,11 +81,9 @@ class FunctionValves(BaseModel): class FunctionsTable: - def insert_new_function( self, user_id: str, type: str, form_data: FunctionForm ) -> Optional[FunctionModel]: - function = FunctionModel( **{ **form_data.model_dump(), @@ -119,7 +111,6 @@ class FunctionsTable: def get_function_by_id(self, id: str) -> Optional[FunctionModel]: try: with get_db() as db: - function = db.get(Function, id) return FunctionModel.model_validate(function) except Exception: @@ -127,7 +118,6 @@ class FunctionsTable: def get_functions(self, active_only=False) -> list[FunctionModel]: with get_db() as db: - if active_only: return [ FunctionModel.model_validate(function) @@ -143,7 +133,6 @@ class FunctionsTable: self, type: str, active_only=False ) -> list[FunctionModel]: with get_db() as db: - if active_only: return [ FunctionModel.model_validate(function) @@ -159,7 +148,6 @@ class FunctionsTable: def get_global_filter_functions(self) -> list[FunctionModel]: with get_db() as db: - return [ FunctionModel.model_validate(function) for function in db.query(Function) @@ -178,7 +166,6 @@ class FunctionsTable: def get_function_valves_by_id(self, id: str) -> Optional[dict]: with get_db() as db: - try: function = db.get(Function, id) return function.valves if function.valves else {} @@ -190,7 +177,6 @@ class FunctionsTable: self, id: str, valves: dict ) -> Optional[FunctionValves]: with get_db() as db: - try: function = db.get(Function, id) function.valves = valves @@ -204,7 +190,6 @@ class FunctionsTable: def get_user_valves_by_id_and_user_id( self, id: str, user_id: str ) -> Optional[dict]: - try: user = Users.get_user_by_id(user_id) user_settings = user.settings.model_dump() if user.settings else {} @@ -223,7 +208,6 @@ class FunctionsTable: def update_user_valves_by_id_and_user_id( self, id: str, user_id: str, valves: dict ) -> Optional[dict]: - try: user = Users.get_user_by_id(user_id) user_settings = user.settings.model_dump() if user.settings else {} @@ -246,7 +230,6 @@ class FunctionsTable: def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: with get_db() as db: - try: db.query(Function).filter_by(id=id).update( { @@ -261,7 +244,6 @@ class FunctionsTable: def deactivate_all_functions(self) -> Optional[bool]: with get_db() as db: - try: db.query(Function).update( { diff --git a/backend/apps/webui/models/memories.py b/backend/apps/webui/models/memories.py index 41bb11ccf..9c8ac8746 100644 --- a/backend/apps/webui/models/memories.py +++ b/backend/apps/webui/models/memories.py @@ -1,12 +1,10 @@ -from pydantic import BaseModel, ConfigDict -from typing import Union, Optional - -from sqlalchemy import Column, String, BigInteger, Text - -from apps.webui.internal.db import Base, get_db - import time import uuid +from typing import Optional + +from apps.webui.internal.db import Base, get_db +from pydantic import BaseModel, ConfigDict +from sqlalchemy import BigInteger, Column, String, Text #################### # Memory DB Schema @@ -39,13 +37,11 @@ class MemoryModel(BaseModel): class MemoriesTable: - def insert_new_memory( self, user_id: str, content: str, ) -> Optional[MemoryModel]: - with get_db() as db: id = str(uuid.uuid4()) @@ -73,7 +69,6 @@ class MemoriesTable: content: str, ) -> Optional[MemoryModel]: with get_db() as db: - try: db.query(Memory).filter_by(id=id).update( {"content": content, "updated_at": int(time.time())} @@ -85,7 +80,6 @@ class MemoriesTable: def get_memories(self) -> list[MemoryModel]: with get_db() as db: - try: memories = db.query(Memory).all() return [MemoryModel.model_validate(memory) for memory in memories] @@ -94,7 +88,6 @@ class MemoriesTable: def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]: with get_db() as db: - try: memories = db.query(Memory).filter_by(user_id=user_id).all() return [MemoryModel.model_validate(memory) for memory in memories] @@ -103,7 +96,6 @@ class MemoriesTable: def get_memory_by_id(self, id: str) -> Optional[MemoryModel]: with get_db() as db: - try: memory = db.get(Memory, id) return MemoryModel.model_validate(memory) @@ -112,7 +104,6 @@ class MemoriesTable: def delete_memory_by_id(self, id: str) -> bool: with get_db() as db: - try: db.query(Memory).filter_by(id=id).delete() db.commit() @@ -124,7 +115,6 @@ class MemoriesTable: def delete_memories_by_user_id(self, user_id: str) -> bool: with get_db() as db: - try: db.query(Memory).filter_by(user_id=user_id).delete() db.commit() @@ -135,7 +125,6 @@ class MemoriesTable: def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: with get_db() as db: - try: db.query(Memory).filter_by(id=id, user_id=user_id).delete() db.commit() diff --git a/backend/apps/webui/models/models.py b/backend/apps/webui/models/models.py index 0a36da987..13b111544 100644 --- a/backend/apps/webui/models/models.py +++ b/backend/apps/webui/models/models.py @@ -1,14 +1,11 @@ import logging -from typing import Optional, List - -from pydantic import BaseModel, ConfigDict -from sqlalchemy import Column, BigInteger, Text +import time +from typing import Optional from apps.webui.internal.db import Base, JSONField, get_db - from env import SRC_LOG_LEVELS - -import time +from pydantic import BaseModel, ConfigDict +from sqlalchemy import BigInteger, Column, Text log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) diff --git a/backend/apps/webui/models/prompts.py b/backend/apps/webui/models/prompts.py index 942f64a43..677cdd7c3 100644 --- a/backend/apps/webui/models/prompts.py +++ b/backend/apps/webui/models/prompts.py @@ -1,12 +1,9 @@ -from pydantic import BaseModel, ConfigDict -from typing import Optional import time - -from sqlalchemy import String, Column, BigInteger, Text +from typing import Optional from apps.webui.internal.db import Base, get_db - -import json +from pydantic import BaseModel, ConfigDict +from sqlalchemy import BigInteger, Column, String, Text #################### # Prompts DB Schema @@ -45,7 +42,6 @@ class PromptForm(BaseModel): class PromptsTable: - def insert_new_prompt( self, user_id: str, form_data: PromptForm ) -> Optional[PromptModel]: @@ -61,7 +57,6 @@ class PromptsTable: try: with get_db() as db: - result = Prompt(**prompt.dict()) db.add(result) db.commit() @@ -70,13 +65,12 @@ class PromptsTable: return PromptModel.model_validate(result) else: return None - except Exception as e: + except Exception: return None def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: try: with get_db() as db: - prompt = db.query(Prompt).filter_by(command=command).first() return PromptModel.model_validate(prompt) except Exception: @@ -84,7 +78,6 @@ class PromptsTable: def get_prompts(self) -> list[PromptModel]: with get_db() as db: - return [ PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all() ] @@ -94,7 +87,6 @@ class PromptsTable: ) -> Optional[PromptModel]: try: with get_db() as db: - prompt = db.query(Prompt).filter_by(command=command).first() prompt.title = form_data.title prompt.content = form_data.content @@ -107,7 +99,6 @@ class PromptsTable: def delete_prompt_by_command(self, command: str) -> bool: try: with get_db() as db: - db.query(Prompt).filter_by(command=command).delete() db.commit() diff --git a/backend/apps/webui/models/tags.py b/backend/apps/webui/models/tags.py index 605cca2e7..61d11e7c2 100644 --- a/backend/apps/webui/models/tags.py +++ b/backend/apps/webui/models/tags.py @@ -1,16 +1,12 @@ -from pydantic import BaseModel, ConfigDict +import logging +import time +import uuid from typing import Optional -import json -import uuid -import time -import logging - -from sqlalchemy import String, Column, BigInteger, Text - from apps.webui.internal.db import Base, get_db - from env import SRC_LOG_LEVELS +from pydantic import BaseModel, ConfigDict +from sqlalchemy import BigInteger, Column, String, Text log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -77,10 +73,8 @@ class ChatTagsResponse(BaseModel): class TagTable: - def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: with get_db() as db: - id = str(uuid.uuid4()) tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) try: @@ -92,7 +86,7 @@ class TagTable: return TagModel.model_validate(result) else: return None - except Exception as e: + except Exception: return None def get_tag_by_name_and_user_id( @@ -102,7 +96,7 @@ class TagTable: with get_db() as db: tag = db.query(Tag).filter_by(name=name, user_id=user_id).first() return TagModel.model_validate(tag) - except Exception as e: + except Exception: return None def add_tag_to_chat( @@ -161,7 +155,6 @@ class TagTable: self, chat_id: str, user_id: str ) -> list[TagModel]: with get_db() as db: - tag_names = [ chat_id_tag.tag_name for chat_id_tag in ( @@ -186,7 +179,6 @@ class TagTable: self, tag_name: str, user_id: str ) -> list[ChatIdTagModel]: with get_db() as db: - return [ ChatIdTagModel.model_validate(chat_id_tag) for chat_id_tag in ( @@ -201,7 +193,6 @@ class TagTable: self, tag_name: str, user_id: str ) -> int: with get_db() as db: - return ( db.query(ChatIdTag) .filter_by(tag_name=tag_name, user_id=user_id) @@ -236,7 +227,6 @@ class TagTable: ) -> bool: try: with get_db() as db: - res = ( db.query(ChatIdTag) .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py index 2f4c532b8..bff3e79dc 100644 --- a/backend/apps/webui/models/tools.py +++ b/backend/apps/webui/models/tools.py @@ -1,17 +1,12 @@ -from pydantic import BaseModel, ConfigDict -from typing import Optional -import time import logging -from sqlalchemy import String, Column, BigInteger, Text +import time +from typing import Optional from apps.webui.internal.db import Base, JSONField, get_db from apps.webui.models.users import Users - -import json -import copy - - from env import SRC_LOG_LEVELS +from pydantic import BaseModel, ConfigDict +from sqlalchemy import BigInteger, Column, String, Text log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -79,13 +74,10 @@ class ToolValves(BaseModel): class ToolsTable: - def insert_new_tool( self, user_id: str, form_data: ToolForm, specs: list[dict] ) -> Optional[ToolModel]: - with get_db() as db: - tool = ToolModel( **{ **form_data.model_dump(), @@ -112,7 +104,6 @@ class ToolsTable: def get_tool_by_id(self, id: str) -> Optional[ToolModel]: try: with get_db() as db: - tool = db.get(Tool, id) return ToolModel.model_validate(tool) except Exception: @@ -125,7 +116,6 @@ class ToolsTable: def get_tool_valves_by_id(self, id: str) -> Optional[dict]: try: with get_db() as db: - tool = db.get(Tool, id) return tool.valves if tool.valves else {} except Exception as e: @@ -135,7 +125,6 @@ class ToolsTable: def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: try: with get_db() as db: - db.query(Tool).filter_by(id=id).update( {"valves": valves, "updated_at": int(time.time())} ) diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index b6e85e2ca..25d84b03e 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -1,11 +1,10 @@ -from pydantic import BaseModel, ConfigDict -from typing import Optional import time - -from sqlalchemy import String, Column, BigInteger, Text +from typing import Optional from apps.webui.internal.db import Base, JSONField, get_db from apps.webui.models.chats import Chats +from pydantic import BaseModel, ConfigDict +from sqlalchemy import BigInteger, Column, String, Text #################### # User DB Schema @@ -113,7 +112,7 @@ class UsersTable: with get_db() as db: user = db.query(User).filter_by(id=id).first() return UserModel.model_validate(user) - except Exception as e: + except Exception: return None def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: @@ -221,7 +220,7 @@ class UsersTable: user = db.query(User).filter_by(id=id).first() return UserModel.model_validate(user) # return UserModel(**user.dict()) - except Exception as e: + except Exception: return None def delete_user_by_id(self, id: str) -> bool: @@ -255,7 +254,7 @@ class UsersTable: with get_db() as db: user = db.query(User).filter_by(id=id).first() return user.api_key - except Exception as e: + except Exception: return None diff --git a/backend/apps/webui/routers/auths.py b/backend/apps/webui/routers/auths.py index 8909b1e05..b96030807 100644 --- a/backend/apps/webui/routers/auths.py +++ b/backend/apps/webui/routers/auths.py @@ -1,43 +1,33 @@ -import logging - -from fastapi import Request, UploadFile, File -from fastapi import Depends, HTTPException, status -from fastapi.responses import Response - -from fastapi import APIRouter -from pydantic import BaseModel import re import uuid -import csv from apps.webui.models.auths import ( - SigninForm, - SignupForm, AddUserForm, - UpdateProfileForm, - UpdatePasswordForm, - UserResponse, - SigninResponse, - Auths, ApiKey, + Auths, + SigninForm, + SigninResponse, + SignupForm, + UpdatePasswordForm, + UpdateProfileForm, + UserResponse, ) from apps.webui.models.users import Users - -from utils.utils import ( - get_password_hash, - get_current_user, - get_admin_user, - create_token, - create_api_key, -) -from utils.misc import parse_duration, validate_email_format -from utils.webhook import post_webhook +from config import WEBUI_AUTH from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES -from config import ( - WEBUI_AUTH, - WEBUI_AUTH_TRUSTED_EMAIL_HEADER, - WEBUI_AUTH_TRUSTED_NAME_HEADER, +from env import WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER +from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi.responses import Response +from pydantic import BaseModel +from utils.misc import parse_duration, validate_email_format +from utils.utils import ( + create_api_key, + create_token, + get_admin_user, + get_current_user, + get_password_hash, ) +from utils.webhook import post_webhook router = APIRouter() @@ -273,7 +263,6 @@ async def signup(request: Request, response: Response, form_data: SignupForm): @router.post("/add", response_model=SigninResponse) async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)): - if not validate_email_format(form_data.email.lower()): raise HTTPException( status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT @@ -283,7 +272,6 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)): raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) try: - print(form_data) hashed = get_password_hash(form_data.password) user = Auths.insert_new_auth( diff --git a/backend/apps/webui/routers/chats.py b/backend/apps/webui/routers/chats.py index 6621e7337..d0a5c6fc5 100644 --- a/backend/apps/webui/routers/chats.py +++ b/backend/apps/webui/routers/chats.py @@ -1,34 +1,15 @@ -from fastapi import Depends, Request, HTTPException, status -from datetime import datetime, timedelta -from typing import Union, Optional -from utils.utils import get_verified_user, get_admin_user -from fastapi import APIRouter -from pydantic import BaseModel import json import logging +from typing import Optional -from apps.webui.models.users import Users -from apps.webui.models.chats import ( - ChatModel, - ChatResponse, - ChatTitleForm, - ChatForm, - ChatTitleIdResponse, - Chats, -) - - -from apps.webui.models.tags import ( - TagModel, - ChatIdTagModel, - ChatIdTagForm, - ChatTagsResponse, - Tags, -) - +from apps.webui.models.chats import ChatForm, ChatResponse, Chats, ChatTitleIdResponse +from apps.webui.models.tags import ChatIdTagForm, ChatIdTagModel, TagModel, Tags +from config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT from constants import ERROR_MESSAGES - -from config import SRC_LOG_LEVELS, ENABLE_ADMIN_EXPORT, ENABLE_ADMIN_CHAT_ACCESS +from env import SRC_LOG_LEVELS +from fastapi import APIRouter, Depends, HTTPException, Request, status +from pydantic import BaseModel +from utils.utils import get_admin_user, get_verified_user log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -61,7 +42,6 @@ async def get_session_user_chat_list( @router.delete("/", response_model=bool) async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)): - if ( user.role == "user" and not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"] @@ -220,7 +200,6 @@ class TagNameForm(BaseModel): async def get_user_chat_list_by_tag_name( form_data: TagNameForm, user=Depends(get_verified_user) ): - chat_ids = [ chat_id_tag.chat_id for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id( @@ -299,7 +278,6 @@ async def update_chat_by_id( @router.delete("/{id}", response_model=bool) async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)): - if user.role == "admin": result = Chats.delete_chat_by_id(id) return result @@ -323,7 +301,6 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified async def clone_chat_by_id(id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: - chat_body = json.loads(chat.chat) updated_chat = { **chat_body, diff --git a/backend/apps/webui/routers/configs.py b/backend/apps/webui/routers/configs.py index 68c687374..5c891f32a 100644 --- a/backend/apps/webui/routers/configs.py +++ b/backend/apps/webui/routers/configs.py @@ -1,25 +1,7 @@ -from fastapi import Response, Request -from fastapi import Depends, FastAPI, HTTPException, status -from datetime import datetime, timedelta -from typing import Union - -from fastapi import APIRouter -from pydantic import BaseModel -import time -import uuid - from config import BannerModel - -from apps.webui.models.users import Users - -from utils.utils import ( - get_password_hash, - get_verified_user, - get_admin_user, - create_token, -) -from utils.misc import get_gravatar_url, validate_email_format -from constants import ERROR_MESSAGES +from fastapi import APIRouter, Depends, Request +from pydantic import BaseModel +from utils.utils import get_admin_user, get_verified_user router = APIRouter() diff --git a/backend/apps/webui/routers/documents.py b/backend/apps/webui/routers/documents.py index 3bb2aa15b..f4ffc100c 100644 --- a/backend/apps/webui/routers/documents.py +++ b/backend/apps/webui/routers/documents.py @@ -1,21 +1,16 @@ -from fastapi import Depends, FastAPI, HTTPException, status -from datetime import datetime, timedelta -from typing import Union, Optional - -from fastapi import APIRouter -from pydantic import BaseModel import json +from typing import Optional from apps.webui.models.documents import ( - Documents, DocumentForm, - DocumentUpdateForm, - DocumentModel, DocumentResponse, + Documents, + DocumentUpdateForm, ) - -from utils.utils import get_verified_user, get_admin_user from constants import ERROR_MESSAGES +from fastapi import APIRouter, Depends, HTTPException, status +from pydantic import BaseModel +from utils.utils import get_admin_user, get_verified_user router = APIRouter() diff --git a/backend/apps/webui/routers/files.py b/backend/apps/webui/routers/files.py index 48ca366d8..2e005d63d 100644 --- a/backend/apps/webui/routers/files.py +++ b/backend/apps/webui/routers/files.py @@ -1,42 +1,17 @@ -from fastapi import ( - Depends, - FastAPI, - HTTPException, - status, - Request, - UploadFile, - File, - Form, -) - - -from datetime import datetime, timedelta -from typing import Union, Optional -from pathlib import Path - -from fastapi import APIRouter -from fastapi.responses import StreamingResponse, JSONResponse, FileResponse - -from pydantic import BaseModel -import json - -from apps.webui.models.files import ( - Files, - FileForm, - FileModel, - FileModelResponse, -) -from utils.utils import get_verified_user, get_admin_user -from constants import ERROR_MESSAGES - -from importlib import util +import logging import os +import shutil import uuid -import os, shutil, logging, re - - -from config import SRC_LOG_LEVELS, UPLOAD_DIR +from pathlib import Path +from typing import Optional +from apps.webui.models.files import FileForm, FileModel, Files +from config import UPLOAD_DIR +from constants import ERROR_MESSAGES +from env import SRC_LOG_LEVELS +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status +from fastapi.responses import FileResponse +from utils.utils import get_admin_user, get_verified_user log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) diff --git a/backend/apps/webui/routers/functions.py b/backend/apps/webui/routers/functions.py index f40d28264..09c41513f 100644 --- a/backend/apps/webui/routers/functions.py +++ b/backend/apps/webui/routers/functions.py @@ -1,27 +1,18 @@ -from fastapi import Depends, FastAPI, HTTPException, status, Request -from datetime import datetime, timedelta -from typing import Union, Optional - -from fastapi import APIRouter -from pydantic import BaseModel -import json +import os +from pathlib import Path +from typing import Optional from apps.webui.models.functions import ( - Functions, FunctionForm, FunctionModel, FunctionResponse, + Functions, ) from apps.webui.utils import load_function_module_by_id -from utils.utils import get_verified_user, get_admin_user +from config import CACHE_DIR, FUNCTIONS_DIR from constants import ERROR_MESSAGES - -from importlib import util -import os -from pathlib import Path - -from config import DATA_DIR, CACHE_DIR, FUNCTIONS_DIR - +from fastapi import APIRouter, Depends, HTTPException, Request, status +from utils.utils import get_admin_user, get_verified_user router = APIRouter() @@ -304,7 +295,6 @@ async def update_function_valves_by_id( ): function = Functions.get_function_by_id(id) if function: - if id in request.app.state.FUNCTIONS: function_module = request.app.state.FUNCTIONS[id] else: diff --git a/backend/apps/webui/routers/memories.py b/backend/apps/webui/routers/memories.py index ae0a9efcb..5a1588178 100644 --- a/backend/apps/webui/routers/memories.py +++ b/backend/apps/webui/routers/memories.py @@ -1,18 +1,12 @@ -from fastapi import Response, Request -from fastapi import Depends, FastAPI, HTTPException, status -from datetime import datetime, timedelta -from typing import Union, Optional - -from fastapi import APIRouter -from pydantic import BaseModel import logging +from typing import Optional from apps.webui.models.memories import Memories, MemoryModel - +from config import CHROMA_CLIENT +from env import SRC_LOG_LEVELS +from fastapi import APIRouter, Depends, HTTPException, Request +from pydantic import BaseModel from utils.utils import get_verified_user -from constants import ERROR_MESSAGES - -from config import SRC_LOG_LEVELS, CHROMA_CLIENT log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) diff --git a/backend/apps/webui/routers/models.py b/backend/apps/webui/routers/models.py index 8faeed7a6..2c12dddd7 100644 --- a/backend/apps/webui/routers/models.py +++ b/backend/apps/webui/routers/models.py @@ -1,15 +1,9 @@ -from fastapi import Depends, FastAPI, HTTPException, status, Request -from datetime import datetime, timedelta -from typing import Union, Optional +from typing import Optional -from fastapi import APIRouter -from pydantic import BaseModel -import json - -from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse - -from utils.utils import get_verified_user, get_admin_user +from apps.webui.models.models import ModelForm, ModelModel, ModelResponse, Models from constants import ERROR_MESSAGES +from fastapi import APIRouter, Depends, HTTPException, Request, status +from utils.utils import get_admin_user, get_verified_user router = APIRouter() diff --git a/backend/apps/webui/routers/prompts.py b/backend/apps/webui/routers/prompts.py index 39d79362a..c65fef8c6 100644 --- a/backend/apps/webui/routers/prompts.py +++ b/backend/apps/webui/routers/prompts.py @@ -1,15 +1,9 @@ -from fastapi import Depends, FastAPI, HTTPException, status -from datetime import datetime, timedelta -from typing import Union, Optional +from typing import Optional -from fastapi import APIRouter -from pydantic import BaseModel -import json - -from apps.webui.models.prompts import Prompts, PromptForm, PromptModel - -from utils.utils import get_verified_user, get_admin_user +from apps.webui.models.prompts import PromptForm, PromptModel, Prompts from constants import ERROR_MESSAGES +from fastapi import APIRouter, Depends, HTTPException, status +from utils.utils import get_admin_user, get_verified_user router = APIRouter() diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py index d6da7ae92..b293f5b57 100644 --- a/backend/apps/webui/routers/tools.py +++ b/backend/apps/webui/routers/tools.py @@ -1,20 +1,14 @@ -from fastapi import Depends, HTTPException, status, Request -from typing import Optional - -from fastapi import APIRouter - -from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse -from apps.webui.utils import load_toolkit_module_by_id - -from utils.utils import get_admin_user, get_verified_user -from utils.tools import get_tools_specs -from constants import ERROR_MESSAGES - import os from pathlib import Path +from typing import Optional -from config import DATA_DIR, CACHE_DIR - +from apps.webui.models.tools import ToolForm, ToolModel, ToolResponse, Tools +from apps.webui.utils import load_toolkit_module_by_id +from config import CACHE_DIR, DATA_DIR +from constants import ERROR_MESSAGES +from fastapi import APIRouter, Depends, HTTPException, Request, status +from utils.tools import get_tools_specs +from utils.utils import get_admin_user, get_verified_user TOOLS_DIR = f"{DATA_DIR}/tools" os.makedirs(TOOLS_DIR, exist_ok=True) diff --git a/backend/apps/webui/routers/users.py b/backend/apps/webui/routers/users.py index 543757275..262c39434 100644 --- a/backend/apps/webui/routers/users.py +++ b/backend/apps/webui/routers/users.py @@ -1,33 +1,20 @@ -from fastapi import Response, Request -from fastapi import Depends, FastAPI, HTTPException, status -from datetime import datetime, timedelta -from typing import Union, Optional - -from fastapi import APIRouter -from pydantic import BaseModel -import time -import uuid import logging +from typing import Optional -from apps.webui.models.users import ( - UserModel, - UserUpdateForm, - UserRoleUpdateForm, - UserSettings, - Users, -) from apps.webui.models.auths import Auths from apps.webui.models.chats import Chats - -from utils.utils import ( - get_verified_user, - get_password_hash, - get_current_user, - get_admin_user, +from apps.webui.models.users import ( + UserModel, + UserRoleUpdateForm, + Users, + UserSettings, + UserUpdateForm, ) from constants import ERROR_MESSAGES - -from config import SRC_LOG_LEVELS +from env import SRC_LOG_LEVELS +from fastapi import APIRouter, Depends, HTTPException, Request, status +from pydantic import BaseModel +from utils.utils import get_admin_user, get_password_hash, get_verified_user log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -69,7 +56,6 @@ async def update_user_permissions( @router.post("/update/role", response_model=Optional[UserModel]) async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)): - if user.id != form_data.id and form_data.id != Users.get_first_user().id: return Users.update_user_role_by_id(form_data.id, form_data.role) @@ -173,7 +159,6 @@ class UserResponse(BaseModel): @router.get("/{user_id}", response_model=UserResponse) async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): - # Check if user_id is a shared chat # If it is, get the user_id from the chat if user_id.startswith("shared-"): diff --git a/backend/apps/webui/routers/utils.py b/backend/apps/webui/routers/utils.py index 8bf8267da..569a11d6a 100644 --- a/backend/apps/webui/routers/utils.py +++ b/backend/apps/webui/routers/utils.py @@ -1,23 +1,16 @@ -from pathlib import Path import site +from pathlib import Path -from fastapi import APIRouter, UploadFile, File, Response -from fastapi import Depends, HTTPException, status -from starlette.responses import StreamingResponse, FileResponse -from pydantic import BaseModel - - -from fpdf import FPDF -import markdown import black - - -from utils.utils import get_admin_user -from utils.misc import calculate_sha256, get_gravatar_url - -from config import OLLAMA_BASE_URLS, DATA_DIR, UPLOAD_DIR, ENABLE_ADMIN_EXPORT +import markdown +from config import DATA_DIR, ENABLE_ADMIN_EXPORT from constants import ERROR_MESSAGES - +from fastapi import APIRouter, Depends, HTTPException, Response, status +from fpdf import FPDF +from pydantic import BaseModel +from starlette.responses import FileResponse +from utils.misc import get_gravatar_url +from utils.utils import get_admin_user router = APIRouter() @@ -115,7 +108,7 @@ async def download_chat_as_pdf( return Response( content=bytes(pdf_bytes), media_type="application/pdf", - headers={"Content-Disposition": f"attachment;filename=chat.pdf"}, + headers={"Content-Disposition": "attachment;filename=chat.pdf"}, ) diff --git a/backend/apps/webui/utils.py b/backend/apps/webui/utils.py index a556b8e8c..d04931c23 100644 --- a/backend/apps/webui/utils.py +++ b/backend/apps/webui/utils.py @@ -1,13 +1,12 @@ -from importlib import util import os import re -import sys import subprocess +import sys +from importlib import util - -from apps.webui.models.tools import Tools from apps.webui.models.functions import Functions -from config import TOOLS_DIR, FUNCTIONS_DIR +from apps.webui.models.tools import Tools +from config import FUNCTIONS_DIR, TOOLS_DIR def extract_frontmatter(file_path): diff --git a/backend/config.py b/backend/config.py index a11b105b7..1fc673ad9 100644 --- a/backend/config.py +++ b/backend/config.py @@ -1,58 +1,30 @@ -from sqlalchemy import create_engine, Column, Integer, DateTime, JSON, func -from contextlib import contextmanager - - -import os -import sys +import json import logging -import importlib.metadata -import pkgutil -from urllib.parse import urlparse +import os +import shutil from datetime import datetime +from pathlib import Path +from typing import Generic, Optional, TypeVar +from urllib.parse import urlparse import chromadb -from chromadb import Settings -from typing import TypeVar, Generic -from pydantic import BaseModel -from typing import Optional - -from pathlib import Path -import json -import yaml - import requests -import shutil - - +import yaml from apps.webui.internal.db import Base, get_db - -from constants import ERROR_MESSAGES - +from chromadb import Settings from env import ( - ENV, - VERSION, - SAFE_MODE, - GLOBAL_LOG_LEVEL, - SRC_LOG_LEVELS, - BASE_DIR, - DATA_DIR, BACKEND_DIR, - FRONTEND_BUILD_DIR, - WEBUI_NAME, - WEBUI_URL, - WEBUI_FAVICON_URL, - WEBUI_BUILD_HASH, CONFIG_DATA, - DATABASE_URL, - CHANGELOG, + DATA_DIR, + ENV, + FRONTEND_BUILD_DIR, WEBUI_AUTH, - WEBUI_AUTH_TRUSTED_EMAIL_HEADER, - WEBUI_AUTH_TRUSTED_NAME_HEADER, - WEBUI_SECRET_KEY, - WEBUI_SESSION_COOKIE_SAME_SITE, - WEBUI_SESSION_COOKIE_SECURE, + WEBUI_FAVICON_URL, + WEBUI_NAME, log, ) +from pydantic import BaseModel +from sqlalchemy import JSON, Column, DateTime, Integer, func class EndpointFilter(logging.Filter): @@ -72,8 +44,8 @@ logging.getLogger("uvicorn.access").addFilter(EndpointFilter()) def run_migrations(): print("Running migrations") try: - from alembic.config import Config from alembic import command + from alembic.config import Config alembic_cfg = Config(BACKEND_DIR / "alembic.ini") command.upgrade(alembic_cfg, "head") diff --git a/backend/env.py b/backend/env.py index 689dc1b6d..7cd0727ad 100644 --- a/backend/env.py +++ b/backend/env.py @@ -1,19 +1,13 @@ -from pathlib import Path -import os -import logging -import sys -import json - - import importlib.metadata +import json +import logging +import os import pkgutil -from urllib.parse import urlparse -from datetime import datetime - +import sys +from pathlib import Path import markdown from bs4 import BeautifulSoup - from constants import ERROR_MESSAGES #################################### @@ -26,7 +20,7 @@ BASE_DIR = BACKEND_DIR.parent # the path containing the backend/ print(BASE_DIR) try: - from dotenv import load_dotenv, find_dotenv + from dotenv import find_dotenv, load_dotenv load_dotenv(find_dotenv(str(BASE_DIR / ".env"))) except ImportError: diff --git a/backend/main.py b/backend/main.py index 4b91cdc84..29ae3f18c 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,130 +1,124 @@ import base64 +import inspect +import json +import logging +import mimetypes +import os +import shutil +import sys +import time import uuid from contextlib import asynccontextmanager -from authlib.integrations.starlette_client import OAuth -from authlib.oidc.core import UserInfo -import json -import time -import os -import sys -import logging -import aiohttp -import requests -import mimetypes -import shutil -import inspect from typing import Optional -from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form -from fastapi.staticfiles import StaticFiles -from fastapi.responses import JSONResponse -from fastapi import HTTPException +import aiohttp +import requests +from apps.audio.main import app as audio_app +from apps.images.main import app as images_app +from apps.ollama.main import app as ollama_app +from apps.ollama.main import ( + generate_openai_chat_completion as generate_ollama_chat_completion, +) +from apps.ollama.main import get_all_models as get_ollama_models +from apps.openai.main import app as openai_app +from apps.openai.main import generate_chat_completion as generate_openai_chat_completion +from apps.openai.main import get_all_models as get_openai_models +from apps.rag.main import app as rag_app +from apps.rag.utils import get_rag_context, rag_template +from apps.socket.main import app as socket_app +from apps.socket.main import get_event_call, get_event_emitter +from apps.webui.internal.db import Session +from apps.webui.main import app as webui_app +from apps.webui.main import generate_function_chat_completion, get_pipe_models +from apps.webui.models.auths import Auths +from apps.webui.models.functions import Functions +from apps.webui.models.models import Models +from apps.webui.models.users import UserModel, Users +from apps.webui.utils import load_function_module_by_id +from authlib.integrations.starlette_client import OAuth +from authlib.oidc.core import UserInfo +from config import ( + CACHE_DIR, + CORS_ALLOW_ORIGIN, + DEFAULT_LOCALE, + ENABLE_ADMIN_CHAT_ACCESS, + ENABLE_ADMIN_EXPORT, + ENABLE_MODEL_FILTER, + ENABLE_OAUTH_SIGNUP, + ENABLE_OLLAMA_API, + ENABLE_OPENAI_API, + ENV, + FRONTEND_BUILD_DIR, + MODEL_FILTER_LIST, + OAUTH_MERGE_ACCOUNTS_BY_EMAIL, + OAUTH_PROVIDERS, + SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, + SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, + STATIC_DIR, + TASK_MODEL, + TASK_MODEL_EXTERNAL, + TITLE_GENERATION_PROMPT_TEMPLATE, + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + WEBHOOK_URL, + WEBUI_AUTH, + WEBUI_NAME, + AppConfig, + run_migrations, +) +from constants import ERROR_MESSAGES, TASKS, WEBHOOK_MESSAGES +from env import ( + CHANGELOG, + GLOBAL_LOG_LEVEL, + SAFE_MODE, + SRC_LOG_LEVELS, + VERSION, + WEBUI_BUILD_HASH, + WEBUI_SECRET_KEY, + WEBUI_SESSION_COOKIE_SAME_SITE, + WEBUI_SESSION_COOKIE_SECURE, + WEBUI_URL, +) +from fastapi import ( + Depends, + FastAPI, + File, + Form, + HTTPException, + Request, + UploadFile, + status, +) from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from fastapi.staticfiles import StaticFiles +from pydantic import BaseModel from sqlalchemy import text from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.sessions import SessionMiddleware -from starlette.responses import StreamingResponse, Response, RedirectResponse - - -from apps.socket.main import app as socket_app, get_event_emitter, get_event_call -from apps.ollama.main import ( - app as ollama_app, - get_all_models as get_ollama_models, - generate_openai_chat_completion as generate_ollama_chat_completion, +from starlette.responses import RedirectResponse, Response, StreamingResponse +from utils.misc import ( + add_or_update_system_message, + get_last_user_message, + parse_duration, + prepend_to_first_user_message_content, ) -from apps.openai.main import ( - app as openai_app, - get_all_models as get_openai_models, - generate_chat_completion as generate_openai_chat_completion, +from utils.task import ( + moa_response_generation_template, + search_query_generation_template, + title_generation_template, + tools_function_calling_generation_template, ) - -from apps.audio.main import app as audio_app -from apps.images.main import app as images_app -from apps.rag.main import app as rag_app -from apps.webui.main import ( - app as webui_app, - get_pipe_models, - generate_function_chat_completion, -) -from apps.webui.internal.db import Session - - -from pydantic import BaseModel - -from apps.webui.models.auths import Auths -from apps.webui.models.models import Models -from apps.webui.models.functions import Functions -from apps.webui.models.users import Users, UserModel - -from apps.webui.utils import load_function_module_by_id - +from utils.tools import get_tools from utils.utils import ( + create_token, + decode_token, get_admin_user, - get_verified_user, get_current_user, get_http_authorization_cred, get_password_hash, - create_token, - decode_token, + get_verified_user, ) -from utils.task import ( - title_generation_template, - search_query_generation_template, - tools_function_calling_generation_template, - moa_response_generation_template, -) - -from utils.tools import get_tools -from utils.misc import ( - get_last_user_message, - add_or_update_system_message, - prepend_to_first_user_message_content, - parse_duration, -) - -from apps.rag.utils import get_rag_context, rag_template - -from config import ( - run_migrations, - WEBUI_NAME, - WEBUI_URL, - WEBUI_AUTH, - ENV, - VERSION, - CHANGELOG, - FRONTEND_BUILD_DIR, - CACHE_DIR, - STATIC_DIR, - DEFAULT_LOCALE, - ENABLE_OPENAI_API, - ENABLE_OLLAMA_API, - ENABLE_MODEL_FILTER, - MODEL_FILTER_LIST, - GLOBAL_LOG_LEVEL, - SRC_LOG_LEVELS, - WEBHOOK_URL, - ENABLE_ADMIN_EXPORT, - WEBUI_BUILD_HASH, - TASK_MODEL, - TASK_MODEL_EXTERNAL, - TITLE_GENERATION_PROMPT_TEMPLATE, - SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, - SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, - TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - SAFE_MODE, - OAUTH_PROVIDERS, - ENABLE_OAUTH_SIGNUP, - OAUTH_MERGE_ACCOUNTS_BY_EMAIL, - WEBUI_SECRET_KEY, - WEBUI_SESSION_COOKIE_SAME_SITE, - WEBUI_SESSION_COOKIE_SECURE, - ENABLE_ADMIN_CHAT_ACCESS, - AppConfig, - CORS_ALLOW_ORIGIN, -) - -from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS from utils.webhook import post_webhook if SAFE_MODE: diff --git a/backend/migrations/env.py b/backend/migrations/env.py index b3b3407fa..11fd03e4e 100644 --- a/backend/migrations/env.py +++ b/backend/migrations/env.py @@ -1,24 +1,9 @@ -import os from logging.config import fileConfig -from sqlalchemy import engine_from_config -from sqlalchemy import pool - from alembic import context - from apps.webui.models.auths import Auth -from apps.webui.models.chats import Chat -from apps.webui.models.documents import Document -from apps.webui.models.memories import Memory -from apps.webui.models.models import Model -from apps.webui.models.prompts import Prompt -from apps.webui.models.tags import Tag, ChatIdTag -from apps.webui.models.tools import Tool -from apps.webui.models.users import User -from apps.webui.models.files import File -from apps.webui.models.functions import Function - from env import DATABASE_URL +from sqlalchemy import engine_from_config, pool # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/backend/migrations/versions/7e5b5dc7342b_init.py b/backend/migrations/versions/7e5b5dc7342b_init.py index b82627f5b..a511c5247 100644 --- a/backend/migrations/versions/7e5b5dc7342b_init.py +++ b/backend/migrations/versions/7e5b5dc7342b_init.py @@ -1,16 +1,16 @@ """init Revision ID: 7e5b5dc7342b -Revises: +Revises: Create Date: 2024-06-24 13:15:33.808998 """ from typing import Sequence, Union -from alembic import op -import sqlalchemy as sa import apps.webui.internal.db +import sqlalchemy as sa +from alembic import op from migrations.util import get_existing_tables # revision identifiers, used by Alembic. diff --git a/backend/migrations/versions/ca81bd47c050_add_config_table.py b/backend/migrations/versions/ca81bd47c050_add_config_table.py index b9f708240..1540aa6a7 100644 --- a/backend/migrations/versions/ca81bd47c050_add_config_table.py +++ b/backend/migrations/versions/ca81bd47c050_add_config_table.py @@ -8,10 +8,8 @@ Create Date: 2024-08-25 15:26:35.241684 from typing import Sequence, Union -from alembic import op import sqlalchemy as sa -import apps.webui.internal.db - +from alembic import op # revision identifiers, used by Alembic. revision: str = "ca81bd47c050" diff --git a/backend/test/apps/webui/routers/test_auths.py b/backend/test/apps/webui/routers/test_auths.py index 3a8695a69..9f6890da1 100644 --- a/backend/test/apps/webui/routers/test_auths.py +++ b/backend/test/apps/webui/routers/test_auths.py @@ -1,5 +1,3 @@ -import pytest - from test.util.abstract_integration_test import AbstractPostgresTest from test.util.mock_user import mock_webui_user @@ -9,8 +7,8 @@ class TestAuths(AbstractPostgresTest): def setup_class(cls): super().setup_class() - from apps.webui.models.users import Users from apps.webui.models.auths import Auths + from apps.webui.models.users import Users cls.users = Users cls.auths = Auths diff --git a/backend/test/apps/webui/routers/test_chats.py b/backend/test/apps/webui/routers/test_chats.py index f4661b625..62244a978 100644 --- a/backend/test/apps/webui/routers/test_chats.py +++ b/backend/test/apps/webui/routers/test_chats.py @@ -5,7 +5,6 @@ from test.util.mock_user import mock_webui_user class TestChats(AbstractPostgresTest): - BASE_PATH = "/api/v1/chats" def setup_class(cls): @@ -13,8 +12,7 @@ class TestChats(AbstractPostgresTest): def setup_method(self): super().setup_method() - from apps.webui.models.chats import ChatForm - from apps.webui.models.chats import Chats + from apps.webui.models.chats import ChatForm, Chats self.chats = Chats self.chats.insert_new_chat( diff --git a/backend/test/apps/webui/routers/test_documents.py b/backend/test/apps/webui/routers/test_documents.py index 14ca339fd..7f601e344 100644 --- a/backend/test/apps/webui/routers/test_documents.py +++ b/backend/test/apps/webui/routers/test_documents.py @@ -3,7 +3,6 @@ from test.util.mock_user import mock_webui_user class TestDocuments(AbstractPostgresTest): - BASE_PATH = "/api/v1/documents" def setup_class(cls): diff --git a/backend/test/apps/webui/routers/test_models.py b/backend/test/apps/webui/routers/test_models.py index 410c4516a..09329d716 100644 --- a/backend/test/apps/webui/routers/test_models.py +++ b/backend/test/apps/webui/routers/test_models.py @@ -3,7 +3,6 @@ from test.util.mock_user import mock_webui_user class TestModels(AbstractPostgresTest): - BASE_PATH = "/api/v1/models" def setup_class(cls): diff --git a/backend/test/apps/webui/routers/test_prompts.py b/backend/test/apps/webui/routers/test_prompts.py index 9f47be992..d91bf77dc 100644 --- a/backend/test/apps/webui/routers/test_prompts.py +++ b/backend/test/apps/webui/routers/test_prompts.py @@ -3,7 +3,6 @@ from test.util.mock_user import mock_webui_user class TestPrompts(AbstractPostgresTest): - BASE_PATH = "/api/v1/prompts" def test_prompts(self): diff --git a/backend/test/apps/webui/routers/test_users.py b/backend/test/apps/webui/routers/test_users.py index 9736b4d32..f00cb2f9f 100644 --- a/backend/test/apps/webui/routers/test_users.py +++ b/backend/test/apps/webui/routers/test_users.py @@ -21,7 +21,6 @@ def _assert_user(data, id, **kwargs): class TestUsers(AbstractPostgresTest): - BASE_PATH = "/api/v1/users" def setup_class(cls): diff --git a/backend/utils/misc.py b/backend/utils/misc.py index df35732c0..6e5a96fcb 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -1,10 +1,10 @@ -from pathlib import Path import hashlib import re -from datetime import timedelta -from typing import Optional, Callable -import uuid import time +import uuid +from datetime import timedelta +from pathlib import Path +from typing import Callable, Optional from utils.task import prompt_template diff --git a/backend/utils/schemas.py b/backend/utils/schemas.py index cb029ade3..958e57318 100644 --- a/backend/utils/schemas.py +++ b/backend/utils/schemas.py @@ -1,7 +1,7 @@ from ast import literal_eval +from typing import Any, Literal, Optional, Type from pydantic import BaseModel, Field, create_model -from typing import Any, Optional, Type, Literal def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]: diff --git a/backend/utils/task.py b/backend/utils/task.py index ea9254c4f..cf3d8a10c 100644 --- a/backend/utils/task.py +++ b/backend/utils/task.py @@ -1,6 +1,5 @@ -import re import math - +import re from datetime import datetime from typing import Optional diff --git a/backend/utils/tools.py b/backend/utils/tools.py index 1a2fea32b..0f619ca8e 100644 --- a/backend/utils/tools.py +++ b/backend/utils/tools.py @@ -5,7 +5,6 @@ from typing import Awaitable, Callable, get_type_hints from apps.webui.models.tools import Tools from apps.webui.models.users import UserModel from apps.webui.utils import load_toolkit_module_by_id - from utils.schemas import json_schema_to_model log = logging.getLogger(__name__) diff --git a/backend/utils/utils.py b/backend/utils/utils.py index 4c15ea237..0e768eb7b 100644 --- a/backend/utils/utils.py +++ b/backend/utils/utils.py @@ -1,16 +1,15 @@ -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials -from fastapi import HTTPException, status, Depends, Request - -from apps.webui.models.users import Users - -from typing import Union, Optional -from constants import ERROR_MESSAGES -from passlib.context import CryptContext -from datetime import datetime, timedelta, UTC -import jwt -import uuid import logging +import uuid +from datetime import UTC, datetime, timedelta +from typing import Optional, Union + +import jwt +from apps.webui.models.users import Users +from constants import ERROR_MESSAGES from env import WEBUI_SECRET_KEY +from fastapi import Depends, HTTPException, Request, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from passlib.context import CryptContext logging.getLogger("passlib").setLevel(logging.ERROR) diff --git a/backend/utils/webhook.py b/backend/utils/webhook.py index b6692e53a..e903fdb2f 100644 --- a/backend/utils/webhook.py +++ b/backend/utils/webhook.py @@ -1,8 +1,9 @@ import json -import requests import logging -from config import SRC_LOG_LEVELS, VERSION, WEBUI_FAVICON_URL, WEBUI_NAME +import requests +from config import WEBUI_FAVICON_URL, WEBUI_NAME +from env import SRC_LOG_LEVELS, VERSION log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["WEBHOOK"])