mirror of
https://github.com/open-webui/open-webui
synced 2025-05-24 14:54:33 +00:00
sort and fix backend imports
This commit is contained in:
parent
08efabc696
commit
c386d0b1a5
@ -7,46 +7,33 @@ from functools import lru_cache
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from fastapi import (
|
from config import (
|
||||||
FastAPI,
|
AUDIO_STT_ENGINE,
|
||||||
Request,
|
AUDIO_STT_MODEL,
|
||||||
Depends,
|
AUDIO_STT_OPENAI_API_BASE_URL,
|
||||||
HTTPException,
|
AUDIO_STT_OPENAI_API_KEY,
|
||||||
status,
|
AUDIO_TTS_API_KEY,
|
||||||
UploadFile,
|
AUDIO_TTS_ENGINE,
|
||||||
File,
|
AUDIO_TTS_MODEL,
|
||||||
|
AUDIO_TTS_OPENAI_API_BASE_URL,
|
||||||
|
AUDIO_TTS_OPENAI_API_KEY,
|
||||||
|
AUDIO_TTS_SPLIT_ON,
|
||||||
|
AUDIO_TTS_VOICE,
|
||||||
|
CACHE_DIR,
|
||||||
|
CORS_ALLOW_ORIGIN,
|
||||||
|
DEVICE_TYPE,
|
||||||
|
WHISPER_MODEL,
|
||||||
|
WHISPER_MODEL_AUTO_UPDATE,
|
||||||
|
WHISPER_MODEL_DIR,
|
||||||
|
AppConfig,
|
||||||
)
|
)
|
||||||
|
from constants import ERROR_MESSAGES
|
||||||
|
from env import SRC_LOG_LEVELS
|
||||||
|
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from utils.utils import get_admin_user, get_current_user, get_verified_user
|
||||||
from config import (
|
|
||||||
SRC_LOG_LEVELS,
|
|
||||||
CACHE_DIR,
|
|
||||||
WHISPER_MODEL,
|
|
||||||
WHISPER_MODEL_DIR,
|
|
||||||
WHISPER_MODEL_AUTO_UPDATE,
|
|
||||||
DEVICE_TYPE,
|
|
||||||
AUDIO_STT_OPENAI_API_BASE_URL,
|
|
||||||
AUDIO_STT_OPENAI_API_KEY,
|
|
||||||
AUDIO_TTS_OPENAI_API_BASE_URL,
|
|
||||||
AUDIO_TTS_OPENAI_API_KEY,
|
|
||||||
AUDIO_TTS_API_KEY,
|
|
||||||
AUDIO_STT_ENGINE,
|
|
||||||
AUDIO_STT_MODEL,
|
|
||||||
AUDIO_TTS_ENGINE,
|
|
||||||
AUDIO_TTS_MODEL,
|
|
||||||
AUDIO_TTS_VOICE,
|
|
||||||
AUDIO_TTS_SPLIT_ON,
|
|
||||||
AppConfig,
|
|
||||||
CORS_ALLOW_ORIGIN,
|
|
||||||
)
|
|
||||||
from constants import ERROR_MESSAGES
|
|
||||||
from utils.utils import (
|
|
||||||
get_current_user,
|
|
||||||
get_verified_user,
|
|
||||||
get_admin_user,
|
|
||||||
)
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
|
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
|
||||||
@ -211,7 +198,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
|||||||
body = json.loads(body)
|
body = json.loads(body)
|
||||||
body["model"] = app.state.config.TTS_MODEL
|
body["model"] = app.state.config.TTS_MODEL
|
||||||
body = json.dumps(body).encode("utf-8")
|
body = json.dumps(body).encode("utf-8")
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
r = None
|
r = None
|
||||||
@ -488,7 +475,7 @@ def get_available_voices() -> dict:
|
|||||||
elif app.state.config.TTS_ENGINE == "elevenlabs":
|
elif app.state.config.TTS_ENGINE == "elevenlabs":
|
||||||
try:
|
try:
|
||||||
ret = get_elevenlabs_voices()
|
ret = get_elevenlabs_voices()
|
||||||
except Exception as e:
|
except Exception:
|
||||||
# Avoided @lru_cache with exception
|
# Avoided @lru_cache with exception
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -1,52 +1,42 @@
|
|||||||
from fastapi import (
|
import asyncio
|
||||||
FastAPI,
|
|
||||||
Request,
|
|
||||||
Depends,
|
|
||||||
HTTPException,
|
|
||||||
)
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from typing import Optional
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from pathlib import Path
|
|
||||||
import mimetypes
|
|
||||||
import uuid
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import mimetypes
|
||||||
import re
|
import re
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from utils.utils import (
|
|
||||||
get_verified_user,
|
|
||||||
get_admin_user,
|
|
||||||
)
|
|
||||||
|
|
||||||
from apps.images.utils.comfyui import (
|
from apps.images.utils.comfyui import (
|
||||||
ComfyUIWorkflow,
|
|
||||||
ComfyUIGenerateImageForm,
|
ComfyUIGenerateImageForm,
|
||||||
|
ComfyUIWorkflow,
|
||||||
comfyui_generate_image,
|
comfyui_generate_image,
|
||||||
)
|
)
|
||||||
|
|
||||||
from constants import ERROR_MESSAGES
|
|
||||||
from config import (
|
from config import (
|
||||||
SRC_LOG_LEVELS,
|
|
||||||
CACHE_DIR,
|
|
||||||
IMAGE_GENERATION_ENGINE,
|
|
||||||
ENABLE_IMAGE_GENERATION,
|
|
||||||
AUTOMATIC1111_BASE_URL,
|
|
||||||
AUTOMATIC1111_API_AUTH,
|
AUTOMATIC1111_API_AUTH,
|
||||||
|
AUTOMATIC1111_BASE_URL,
|
||||||
|
CACHE_DIR,
|
||||||
COMFYUI_BASE_URL,
|
COMFYUI_BASE_URL,
|
||||||
COMFYUI_WORKFLOW,
|
COMFYUI_WORKFLOW,
|
||||||
COMFYUI_WORKFLOW_NODES,
|
COMFYUI_WORKFLOW_NODES,
|
||||||
IMAGES_OPENAI_API_BASE_URL,
|
CORS_ALLOW_ORIGIN,
|
||||||
IMAGES_OPENAI_API_KEY,
|
ENABLE_IMAGE_GENERATION,
|
||||||
|
IMAGE_GENERATION_ENGINE,
|
||||||
IMAGE_GENERATION_MODEL,
|
IMAGE_GENERATION_MODEL,
|
||||||
IMAGE_SIZE,
|
IMAGE_SIZE,
|
||||||
IMAGE_STEPS,
|
IMAGE_STEPS,
|
||||||
CORS_ALLOW_ORIGIN,
|
IMAGES_OPENAI_API_BASE_URL,
|
||||||
|
IMAGES_OPENAI_API_KEY,
|
||||||
AppConfig,
|
AppConfig,
|
||||||
)
|
)
|
||||||
|
from constants import ERROR_MESSAGES
|
||||||
|
from env import SRC_LOG_LEVELS
|
||||||
|
from fastapi import Depends, FastAPI, HTTPException, Request
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from utils.utils import get_admin_user, get_verified_user
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
|
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
|
||||||
@ -186,7 +176,7 @@ async def verify_url(user=Depends(get_admin_user)):
|
|||||||
)
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception:
|
||||||
app.state.config.ENABLED = False
|
app.state.config.ENABLED = False
|
||||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
||||||
elif app.state.config.ENGINE == "comfyui":
|
elif app.state.config.ENGINE == "comfyui":
|
||||||
@ -194,7 +184,7 @@ async def verify_url(user=Depends(get_admin_user)):
|
|||||||
r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
|
r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception:
|
||||||
app.state.config.ENABLED = False
|
app.state.config.ENABLED = False
|
||||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
||||||
else:
|
else:
|
||||||
@ -397,7 +387,6 @@ def save_url_image(url):
|
|||||||
r = requests.get(url)
|
r = requests.get(url)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
if r.headers["content-type"].split("/")[0] == "image":
|
if r.headers["content-type"].split("/")[0] == "image":
|
||||||
|
|
||||||
mime_type = r.headers["content-type"]
|
mime_type = r.headers["content-type"]
|
||||||
image_format = mimetypes.guess_extension(mime_type)
|
image_format = mimetypes.guess_extension(mime_type)
|
||||||
|
|
||||||
@ -412,7 +401,7 @@ def save_url_image(url):
|
|||||||
image_file.write(chunk)
|
image_file.write(chunk)
|
||||||
return image_filename
|
return image_filename
|
||||||
else:
|
else:
|
||||||
log.error(f"Url does not point to an image.")
|
log.error("Url does not point to an image.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -430,7 +419,6 @@ async def image_generations(
|
|||||||
r = None
|
r = None
|
||||||
try:
|
try:
|
||||||
if app.state.config.ENGINE == "openai":
|
if app.state.config.ENGINE == "openai":
|
||||||
|
|
||||||
headers = {}
|
headers = {}
|
||||||
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
|
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
|
||||||
headers["Content-Type"] = "application/json"
|
headers["Content-Type"] = "application/json"
|
||||||
|
@ -1,20 +1,18 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
|
|
||||||
import json
|
import json
|
||||||
import urllib.request
|
|
||||||
import urllib.parse
|
|
||||||
import random
|
|
||||||
import logging
|
import logging
|
||||||
|
import random
|
||||||
|
import urllib.parse
|
||||||
|
import urllib.request
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from config import SRC_LOG_LEVELS
|
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
|
||||||
|
from env import SRC_LOG_LEVELS
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["COMFYUI"])
|
log.setLevel(SRC_LOG_LEVELS["COMFYUI"])
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
default_headers = {"User-Agent": "Mozilla/5.0"}
|
default_headers = {"User-Agent": "Mozilla/5.0"}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,54 +1,40 @@
|
|||||||
from fastapi import (
|
|
||||||
FastAPI,
|
|
||||||
Request,
|
|
||||||
HTTPException,
|
|
||||||
Depends,
|
|
||||||
UploadFile,
|
|
||||||
File,
|
|
||||||
)
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
|
||||||
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import random
|
|
||||||
import requests
|
|
||||||
import json
|
|
||||||
import aiohttp
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
from urllib.parse import urlparse
|
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from starlette.background import BackgroundTask
|
import aiohttp
|
||||||
|
import requests
|
||||||
from apps.webui.models.models import Models
|
from apps.webui.models.models import Models
|
||||||
from constants import ERROR_MESSAGES
|
|
||||||
from utils.utils import (
|
|
||||||
get_verified_user,
|
|
||||||
get_admin_user,
|
|
||||||
)
|
|
||||||
|
|
||||||
from config import (
|
from config import (
|
||||||
SRC_LOG_LEVELS,
|
|
||||||
OLLAMA_BASE_URLS,
|
|
||||||
ENABLE_OLLAMA_API,
|
|
||||||
AIOHTTP_CLIENT_TIMEOUT,
|
AIOHTTP_CLIENT_TIMEOUT,
|
||||||
|
CORS_ALLOW_ORIGIN,
|
||||||
ENABLE_MODEL_FILTER,
|
ENABLE_MODEL_FILTER,
|
||||||
|
ENABLE_OLLAMA_API,
|
||||||
MODEL_FILTER_LIST,
|
MODEL_FILTER_LIST,
|
||||||
|
OLLAMA_BASE_URLS,
|
||||||
UPLOAD_DIR,
|
UPLOAD_DIR,
|
||||||
AppConfig,
|
AppConfig,
|
||||||
CORS_ALLOW_ORIGIN,
|
|
||||||
)
|
)
|
||||||
|
from constants import ERROR_MESSAGES
|
||||||
|
from env import SRC_LOG_LEVELS
|
||||||
|
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from starlette.background import BackgroundTask
|
||||||
from utils.misc import (
|
from utils.misc import (
|
||||||
calculate_sha256,
|
|
||||||
apply_model_params_to_body_ollama,
|
apply_model_params_to_body_ollama,
|
||||||
apply_model_params_to_body_openai,
|
apply_model_params_to_body_openai,
|
||||||
apply_model_system_prompt_to_body,
|
apply_model_system_prompt_to_body,
|
||||||
|
calculate_sha256,
|
||||||
)
|
)
|
||||||
|
from utils.utils import get_admin_user, get_verified_user
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
|
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
|
||||||
|
@ -1,44 +1,36 @@
|
|||||||
from fastapi import FastAPI, Request, HTTPException, Depends
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from fastapi.responses import StreamingResponse, FileResponse
|
|
||||||
|
|
||||||
import requests
|
|
||||||
import aiohttp
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal, Optional, overload
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import requests
|
||||||
|
from apps.webui.models.models import Models
|
||||||
|
from config import (
|
||||||
|
AIOHTTP_CLIENT_TIMEOUT,
|
||||||
|
CACHE_DIR,
|
||||||
|
CORS_ALLOW_ORIGIN,
|
||||||
|
ENABLE_MODEL_FILTER,
|
||||||
|
ENABLE_OPENAI_API,
|
||||||
|
MODEL_FILTER_LIST,
|
||||||
|
OPENAI_API_BASE_URLS,
|
||||||
|
OPENAI_API_KEYS,
|
||||||
|
AppConfig,
|
||||||
|
)
|
||||||
|
from constants import ERROR_MESSAGES
|
||||||
|
from env import SRC_LOG_LEVELS
|
||||||
|
from fastapi import Depends, FastAPI, HTTPException, Request
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import FileResponse, StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from starlette.background import BackgroundTask
|
from starlette.background import BackgroundTask
|
||||||
|
|
||||||
from apps.webui.models.models import Models
|
|
||||||
from constants import ERROR_MESSAGES
|
|
||||||
from utils.utils import (
|
|
||||||
get_verified_user,
|
|
||||||
get_admin_user,
|
|
||||||
)
|
|
||||||
from utils.misc import (
|
from utils.misc import (
|
||||||
apply_model_params_to_body_openai,
|
apply_model_params_to_body_openai,
|
||||||
apply_model_system_prompt_to_body,
|
apply_model_system_prompt_to_body,
|
||||||
)
|
)
|
||||||
|
from utils.utils import get_admin_user, get_verified_user
|
||||||
from config import (
|
|
||||||
SRC_LOG_LEVELS,
|
|
||||||
ENABLE_OPENAI_API,
|
|
||||||
AIOHTTP_CLIENT_TIMEOUT,
|
|
||||||
OPENAI_API_BASE_URLS,
|
|
||||||
OPENAI_API_KEYS,
|
|
||||||
CACHE_DIR,
|
|
||||||
ENABLE_MODEL_FILTER,
|
|
||||||
MODEL_FILTER_LIST,
|
|
||||||
AppConfig,
|
|
||||||
CORS_ALLOW_ORIGIN,
|
|
||||||
)
|
|
||||||
from typing import Optional, Literal, overload
|
|
||||||
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["OPENAI"])
|
log.setLevel(SRC_LOG_LEVELS["OPENAI"])
|
||||||
|
@ -1,143 +1,118 @@
|
|||||||
from fastapi import (
|
|
||||||
FastAPI,
|
|
||||||
Depends,
|
|
||||||
HTTPException,
|
|
||||||
status,
|
|
||||||
UploadFile,
|
|
||||||
File,
|
|
||||||
Form,
|
|
||||||
)
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
import requests
|
|
||||||
import os, shutil, logging, re
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Union, Sequence, Iterator, Any
|
|
||||||
|
|
||||||
from chromadb.utils.batch_utils import create_batches
|
|
||||||
from langchain_core.documents import Document
|
|
||||||
|
|
||||||
from langchain_community.document_loaders import (
|
|
||||||
WebBaseLoader,
|
|
||||||
TextLoader,
|
|
||||||
PyPDFLoader,
|
|
||||||
CSVLoader,
|
|
||||||
BSHTMLLoader,
|
|
||||||
Docx2txtLoader,
|
|
||||||
UnstructuredEPubLoader,
|
|
||||||
UnstructuredWordDocumentLoader,
|
|
||||||
UnstructuredMarkdownLoader,
|
|
||||||
UnstructuredXMLLoader,
|
|
||||||
UnstructuredRSTLoader,
|
|
||||||
UnstructuredExcelLoader,
|
|
||||||
UnstructuredPowerPointLoader,
|
|
||||||
YoutubeLoader,
|
|
||||||
OutlookMessageLoader,
|
|
||||||
)
|
|
||||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
||||||
|
|
||||||
import validators
|
|
||||||
import urllib.parse
|
|
||||||
import socket
|
|
||||||
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing import Optional
|
|
||||||
import mimetypes
|
|
||||||
import uuid
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
import mimetypes
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import socket
|
||||||
|
import urllib.parse
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterator, Optional, Sequence, Union
|
||||||
|
|
||||||
from apps.webui.models.documents import (
|
import requests
|
||||||
Documents,
|
import validators
|
||||||
DocumentForm,
|
|
||||||
DocumentResponse,
|
|
||||||
)
|
|
||||||
from apps.webui.models.files import (
|
|
||||||
Files,
|
|
||||||
)
|
|
||||||
|
|
||||||
from apps.rag.utils import (
|
|
||||||
get_model_path,
|
|
||||||
get_embedding_function,
|
|
||||||
query_doc,
|
|
||||||
query_doc_with_hybrid_search,
|
|
||||||
query_collection,
|
|
||||||
query_collection_with_hybrid_search,
|
|
||||||
)
|
|
||||||
|
|
||||||
from apps.rag.search.brave import search_brave
|
from apps.rag.search.brave import search_brave
|
||||||
|
from apps.rag.search.duckduckgo import search_duckduckgo
|
||||||
from apps.rag.search.google_pse import search_google_pse
|
from apps.rag.search.google_pse import search_google_pse
|
||||||
|
from apps.rag.search.jina_search import search_jina
|
||||||
from apps.rag.search.main import SearchResult
|
from apps.rag.search.main import SearchResult
|
||||||
|
from apps.rag.search.searchapi import search_searchapi
|
||||||
from apps.rag.search.searxng import search_searxng
|
from apps.rag.search.searxng import search_searxng
|
||||||
from apps.rag.search.serper import search_serper
|
from apps.rag.search.serper import search_serper
|
||||||
from apps.rag.search.serpstack import search_serpstack
|
|
||||||
from apps.rag.search.serply import search_serply
|
from apps.rag.search.serply import search_serply
|
||||||
from apps.rag.search.duckduckgo import search_duckduckgo
|
from apps.rag.search.serpstack import search_serpstack
|
||||||
from apps.rag.search.tavily import search_tavily
|
from apps.rag.search.tavily import search_tavily
|
||||||
from apps.rag.search.jina_search import search_jina
|
from apps.rag.utils import (
|
||||||
from apps.rag.search.searchapi import search_searchapi
|
get_embedding_function,
|
||||||
|
get_model_path,
|
||||||
from utils.misc import (
|
query_collection,
|
||||||
calculate_sha256,
|
query_collection_with_hybrid_search,
|
||||||
calculate_sha256_string,
|
query_doc,
|
||||||
sanitize_filename,
|
query_doc_with_hybrid_search,
|
||||||
extract_folders_after_data_docs,
|
|
||||||
)
|
)
|
||||||
from utils.utils import get_verified_user, get_admin_user
|
from apps.webui.models.documents import DocumentForm, Documents
|
||||||
|
from apps.webui.models.files import Files
|
||||||
|
from chromadb.utils.batch_utils import create_batches
|
||||||
from config import (
|
from config import (
|
||||||
AppConfig,
|
BRAVE_SEARCH_API_KEY,
|
||||||
ENV,
|
CHROMA_CLIENT,
|
||||||
SRC_LOG_LEVELS,
|
CHUNK_OVERLAP,
|
||||||
UPLOAD_DIR,
|
CHUNK_SIZE,
|
||||||
DOCS_DIR,
|
|
||||||
CONTENT_EXTRACTION_ENGINE,
|
CONTENT_EXTRACTION_ENGINE,
|
||||||
TIKA_SERVER_URL,
|
CORS_ALLOW_ORIGIN,
|
||||||
RAG_TOP_K,
|
DEVICE_TYPE,
|
||||||
RAG_RELEVANCE_THRESHOLD,
|
DOCS_DIR,
|
||||||
RAG_FILE_MAX_SIZE,
|
ENABLE_RAG_HYBRID_SEARCH,
|
||||||
RAG_FILE_MAX_COUNT,
|
ENABLE_RAG_LOCAL_WEB_FETCH,
|
||||||
|
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||||
|
ENABLE_RAG_WEB_SEARCH,
|
||||||
|
ENV,
|
||||||
|
GOOGLE_PSE_API_KEY,
|
||||||
|
GOOGLE_PSE_ENGINE_ID,
|
||||||
|
PDF_EXTRACT_IMAGES,
|
||||||
RAG_EMBEDDING_ENGINE,
|
RAG_EMBEDDING_ENGINE,
|
||||||
RAG_EMBEDDING_MODEL,
|
RAG_EMBEDDING_MODEL,
|
||||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||||
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||||
ENABLE_RAG_HYBRID_SEARCH,
|
RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||||
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
RAG_FILE_MAX_COUNT,
|
||||||
RAG_RERANKING_MODEL,
|
RAG_FILE_MAX_SIZE,
|
||||||
PDF_EXTRACT_IMAGES,
|
|
||||||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
|
||||||
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
|
||||||
RAG_OPENAI_API_BASE_URL,
|
RAG_OPENAI_API_BASE_URL,
|
||||||
RAG_OPENAI_API_KEY,
|
RAG_OPENAI_API_KEY,
|
||||||
DEVICE_TYPE,
|
RAG_RELEVANCE_THRESHOLD,
|
||||||
CHROMA_CLIENT,
|
RAG_RERANKING_MODEL,
|
||||||
CHUNK_SIZE,
|
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||||
CHUNK_OVERLAP,
|
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
||||||
RAG_TEMPLATE,
|
RAG_TEMPLATE,
|
||||||
ENABLE_RAG_LOCAL_WEB_FETCH,
|
RAG_TOP_K,
|
||||||
YOUTUBE_LOADER_LANGUAGE,
|
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||||
ENABLE_RAG_WEB_SEARCH,
|
|
||||||
RAG_WEB_SEARCH_ENGINE,
|
|
||||||
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||||
SEARXNG_QUERY_URL,
|
RAG_WEB_SEARCH_ENGINE,
|
||||||
GOOGLE_PSE_API_KEY,
|
RAG_WEB_SEARCH_RESULT_COUNT,
|
||||||
GOOGLE_PSE_ENGINE_ID,
|
|
||||||
BRAVE_SEARCH_API_KEY,
|
|
||||||
SERPSTACK_API_KEY,
|
|
||||||
SERPSTACK_HTTPS,
|
|
||||||
SERPER_API_KEY,
|
|
||||||
SERPLY_API_KEY,
|
|
||||||
TAVILY_API_KEY,
|
|
||||||
SEARCHAPI_API_KEY,
|
SEARCHAPI_API_KEY,
|
||||||
SEARCHAPI_ENGINE,
|
SEARCHAPI_ENGINE,
|
||||||
RAG_WEB_SEARCH_RESULT_COUNT,
|
SEARXNG_QUERY_URL,
|
||||||
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
SERPER_API_KEY,
|
||||||
RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
SERPLY_API_KEY,
|
||||||
CORS_ALLOW_ORIGIN,
|
SERPSTACK_API_KEY,
|
||||||
|
SERPSTACK_HTTPS,
|
||||||
|
TAVILY_API_KEY,
|
||||||
|
TIKA_SERVER_URL,
|
||||||
|
UPLOAD_DIR,
|
||||||
|
YOUTUBE_LOADER_LANGUAGE,
|
||||||
|
AppConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
from env import SRC_LOG_LEVELS
|
||||||
|
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||||
|
from langchain_community.document_loaders import (
|
||||||
|
BSHTMLLoader,
|
||||||
|
CSVLoader,
|
||||||
|
Docx2txtLoader,
|
||||||
|
OutlookMessageLoader,
|
||||||
|
PyPDFLoader,
|
||||||
|
TextLoader,
|
||||||
|
UnstructuredEPubLoader,
|
||||||
|
UnstructuredExcelLoader,
|
||||||
|
UnstructuredMarkdownLoader,
|
||||||
|
UnstructuredPowerPointLoader,
|
||||||
|
UnstructuredRSTLoader,
|
||||||
|
UnstructuredXMLLoader,
|
||||||
|
WebBaseLoader,
|
||||||
|
YoutubeLoader,
|
||||||
|
)
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from utils.misc import (
|
||||||
|
calculate_sha256,
|
||||||
|
calculate_sha256_string,
|
||||||
|
extract_folders_after_data_docs,
|
||||||
|
sanitize_filename,
|
||||||
|
)
|
||||||
|
from utils.utils import get_admin_user, get_verified_user
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
@ -539,9 +514,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
|
|||||||
app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
|
app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
|
||||||
app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
|
app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
|
||||||
app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key
|
app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key
|
||||||
app.state.config.SEARCHAPI_ENGINE = (
|
app.state.config.SEARCHAPI_ENGINE = form_data.web.search.searchapi_engine
|
||||||
form_data.web.search.searchapi_engine
|
|
||||||
)
|
|
||||||
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
|
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
|
||||||
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
|
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
|
||||||
form_data.web.search.concurrent_requests
|
form_data.web.search.concurrent_requests
|
||||||
@ -981,7 +954,6 @@ def store_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
|
|||||||
def store_data_in_vector_db(
|
def store_data_in_vector_db(
|
||||||
data, collection_name, metadata: Optional[dict] = None, overwrite: bool = False
|
data, collection_name, metadata: Optional[dict] = None, overwrite: bool = False
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
|
||||||
text_splitter = RecursiveCharacterTextSplitter(
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
chunk_size=app.state.config.CHUNK_SIZE,
|
chunk_size=app.state.config.CHUNK_SIZE,
|
||||||
chunk_overlap=app.state.config.CHUNK_OVERLAP,
|
chunk_overlap=app.state.config.CHUNK_OVERLAP,
|
||||||
@ -1342,7 +1314,6 @@ def store_text(
|
|||||||
form_data: TextRAGForm,
|
form_data: TextRAGForm,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
):
|
):
|
||||||
|
|
||||||
collection_name = form_data.collection_name
|
collection_name = form_data.collection_name
|
||||||
if collection_name is None:
|
if collection_name is None:
|
||||||
collection_name = calculate_sha256_string(form_data.content)
|
collection_name = calculate_sha256_string(form_data.content)
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import requests
|
|
||||||
|
|
||||||
|
import requests
|
||||||
from apps.rag.search.main import SearchResult, get_filtered_results
|
from apps.rag.search.main import SearchResult, get_filtered_results
|
||||||
from config import SRC_LOG_LEVELS
|
from env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from apps.rag.search.main import SearchResult, get_filtered_results
|
from apps.rag.search.main import SearchResult, get_filtered_results
|
||||||
from duckduckgo_search import DDGS
|
from duckduckgo_search import DDGS
|
||||||
from config import SRC_LOG_LEVELS
|
from env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import requests
|
|
||||||
|
|
||||||
|
import requests
|
||||||
from apps.rag.search.main import SearchResult, get_filtered_results
|
from apps.rag.search.main import SearchResult, get_filtered_results
|
||||||
from config import SRC_LOG_LEVELS
|
from env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
import requests
|
|
||||||
from yarl import URL
|
|
||||||
|
|
||||||
|
import requests
|
||||||
from apps.rag.search.main import SearchResult
|
from apps.rag.search.main import SearchResult
|
||||||
from config import SRC_LOG_LEVELS
|
from env import SRC_LOG_LEVELS
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
import requests
|
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
from apps.rag.search.main import SearchResult, get_filtered_results
|
from apps.rag.search.main import SearchResult, get_filtered_results
|
||||||
from config import SRC_LOG_LEVELS
|
from env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import requests
|
|
||||||
|
|
||||||
|
import requests
|
||||||
from apps.rag.search.main import SearchResult, get_filtered_results
|
from apps.rag.search.main import SearchResult, get_filtered_results
|
||||||
from config import SRC_LOG_LEVELS
|
from env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import requests
|
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
import requests
|
||||||
from apps.rag.search.main import SearchResult, get_filtered_results
|
from apps.rag.search.main import SearchResult, get_filtered_results
|
||||||
from config import SRC_LOG_LEVELS
|
from env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import requests
|
|
||||||
|
|
||||||
|
import requests
|
||||||
from apps.rag.search.main import SearchResult, get_filtered_results
|
from apps.rag.search.main import SearchResult, get_filtered_results
|
||||||
from config import SRC_LOG_LEVELS
|
from env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from apps.rag.search.main import SearchResult
|
from apps.rag.search.main import SearchResult
|
||||||
from config import SRC_LOG_LEVELS
|
from env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
@ -1,27 +1,16 @@
|
|||||||
import os
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
from apps.ollama.main import GenerateEmbeddingsForm, generate_ollama_embeddings
|
||||||
from typing import Union
|
from config import CHROMA_CLIENT
|
||||||
|
from env import SRC_LOG_LEVELS
|
||||||
from apps.ollama.main import (
|
|
||||||
generate_ollama_embeddings,
|
|
||||||
GenerateEmbeddingsForm,
|
|
||||||
)
|
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
|
||||||
from langchain_core.documents import Document
|
|
||||||
from langchain_community.retrievers import BM25Retriever
|
from langchain_community.retrievers import BM25Retriever
|
||||||
from langchain.retrievers import (
|
from langchain_core.documents import Document
|
||||||
ContextualCompressionRetriever,
|
from utils.misc import get_last_user_message
|
||||||
EnsembleRetriever,
|
|
||||||
)
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from utils.misc import get_last_user_message, add_or_update_system_message
|
|
||||||
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
@ -261,7 +250,9 @@ def get_rag_context(
|
|||||||
collection_names = (
|
collection_names = (
|
||||||
file["collection_names"]
|
file["collection_names"]
|
||||||
if file["type"] == "collection"
|
if file["type"] == "collection"
|
||||||
else [file["collection_name"]] if file["collection_name"] else []
|
else [file["collection_name"]]
|
||||||
|
if file["collection_name"]
|
||||||
|
else []
|
||||||
)
|
)
|
||||||
|
|
||||||
collection_names = set(collection_names).difference(extracted_collections)
|
collection_names = set(collection_names).difference(extracted_collections)
|
||||||
@ -401,8 +392,8 @@ def generate_openai_batch_embeddings(
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.retrievers import BaseRetriever
|
|
||||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||||
|
from langchain_core.retrievers import BaseRetriever
|
||||||
|
|
||||||
|
|
||||||
class ChromaRetriever(BaseRetriever):
|
class ChromaRetriever(BaseRetriever):
|
||||||
@ -439,11 +430,10 @@ class ChromaRetriever(BaseRetriever):
|
|||||||
|
|
||||||
|
|
||||||
import operator
|
import operator
|
||||||
|
|
||||||
from typing import Optional, Sequence
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
|
||||||
from langchain_core.callbacks import Callbacks
|
from langchain_core.callbacks import Callbacks
|
||||||
|
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||||
from langchain_core.pydantic_v1 import Extra
|
from langchain_core.pydantic_v1 import Extra
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import socketio
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
import socketio
|
||||||
from apps.webui.models.users import Users
|
from apps.webui.models.users import Users
|
||||||
from utils.utils import decode_token
|
from utils.utils import decode_token
|
||||||
|
|
||||||
|
@ -1,21 +1,16 @@
|
|||||||
import os
|
|
||||||
import logging
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
from typing import Optional, Any
|
|
||||||
from typing_extensions import Self
|
|
||||||
|
|
||||||
from sqlalchemy import create_engine, types, Dialect
|
|
||||||
from sqlalchemy.sql.type_api import _T
|
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
|
||||||
|
|
||||||
|
|
||||||
from peewee_migrate import Router
|
|
||||||
from apps.webui.internal.wrappers import register_connection
|
from apps.webui.internal.wrappers import register_connection
|
||||||
from env import SRC_LOG_LEVELS, BACKEND_DIR, DATABASE_URL
|
from env import BACKEND_DIR, DATABASE_URL, SRC_LOG_LEVELS
|
||||||
|
from peewee_migrate import Router
|
||||||
|
from sqlalchemy import Dialect, create_engine, types
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||||
|
from sqlalchemy.sql.type_api import _T
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["DB"])
|
log.setLevel(SRC_LOG_LEVELS["DB"])
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
from contextvars import ContextVar
|
|
||||||
from peewee import *
|
|
||||||
from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from playhouse.db_url import connect, parse
|
from contextvars import ContextVar
|
||||||
from playhouse.shortcuts import ReconnectMixin
|
|
||||||
|
|
||||||
from env import SRC_LOG_LEVELS
|
from env import SRC_LOG_LEVELS
|
||||||
|
from peewee import *
|
||||||
|
from peewee import InterfaceError as PeeWeeInterfaceError
|
||||||
|
from peewee import PostgresqlDatabase
|
||||||
|
from playhouse.db_url import connect, parse
|
||||||
|
from playhouse.shortcuts import ReconnectMixin
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["DB"])
|
log.setLevel(SRC_LOG_LEVELS["DB"])
|
||||||
|
@ -1,65 +1,59 @@
|
|||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from apps.webui.routers import (
|
|
||||||
auths,
|
|
||||||
users,
|
|
||||||
chats,
|
|
||||||
documents,
|
|
||||||
tools,
|
|
||||||
models,
|
|
||||||
prompts,
|
|
||||||
configs,
|
|
||||||
memories,
|
|
||||||
utils,
|
|
||||||
files,
|
|
||||||
functions,
|
|
||||||
)
|
|
||||||
from apps.webui.models.functions import Functions
|
|
||||||
from apps.webui.models.models import Models
|
|
||||||
from apps.webui.utils import load_function_module_by_id
|
|
||||||
|
|
||||||
from utils.misc import (
|
|
||||||
openai_chat_chunk_message_template,
|
|
||||||
openai_chat_completion_message_template,
|
|
||||||
apply_model_params_to_body_openai,
|
|
||||||
apply_model_system_prompt_to_body,
|
|
||||||
)
|
|
||||||
|
|
||||||
from utils.tools import get_tools
|
|
||||||
|
|
||||||
from config import (
|
|
||||||
SHOW_ADMIN_DETAILS,
|
|
||||||
ADMIN_EMAIL,
|
|
||||||
WEBUI_AUTH,
|
|
||||||
DEFAULT_MODELS,
|
|
||||||
DEFAULT_PROMPT_SUGGESTIONS,
|
|
||||||
DEFAULT_USER_ROLE,
|
|
||||||
ENABLE_SIGNUP,
|
|
||||||
ENABLE_LOGIN_FORM,
|
|
||||||
USER_PERMISSIONS,
|
|
||||||
WEBHOOK_URL,
|
|
||||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
|
||||||
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
|
||||||
JWT_EXPIRES_IN,
|
|
||||||
WEBUI_BANNERS,
|
|
||||||
ENABLE_COMMUNITY_SHARING,
|
|
||||||
ENABLE_MESSAGE_RATING,
|
|
||||||
AppConfig,
|
|
||||||
OAUTH_USERNAME_CLAIM,
|
|
||||||
OAUTH_PICTURE_CLAIM,
|
|
||||||
OAUTH_EMAIL_CLAIM,
|
|
||||||
CORS_ALLOW_ORIGIN,
|
|
||||||
)
|
|
||||||
|
|
||||||
from apps.socket.main import get_event_call, get_event_emitter
|
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from typing import AsyncGenerator, Generator, Iterator
|
||||||
|
|
||||||
from typing import Iterator, Generator, AsyncGenerator
|
from apps.socket.main import get_event_call, get_event_emitter
|
||||||
|
from apps.webui.models.functions import Functions
|
||||||
|
from apps.webui.models.models import Models
|
||||||
|
from apps.webui.routers import (
|
||||||
|
auths,
|
||||||
|
chats,
|
||||||
|
configs,
|
||||||
|
documents,
|
||||||
|
files,
|
||||||
|
functions,
|
||||||
|
memories,
|
||||||
|
models,
|
||||||
|
prompts,
|
||||||
|
tools,
|
||||||
|
users,
|
||||||
|
utils,
|
||||||
|
)
|
||||||
|
from apps.webui.utils import load_function_module_by_id
|
||||||
|
from config import (
|
||||||
|
ADMIN_EMAIL,
|
||||||
|
CORS_ALLOW_ORIGIN,
|
||||||
|
DEFAULT_MODELS,
|
||||||
|
DEFAULT_PROMPT_SUGGESTIONS,
|
||||||
|
DEFAULT_USER_ROLE,
|
||||||
|
ENABLE_COMMUNITY_SHARING,
|
||||||
|
ENABLE_LOGIN_FORM,
|
||||||
|
ENABLE_MESSAGE_RATING,
|
||||||
|
ENABLE_SIGNUP,
|
||||||
|
JWT_EXPIRES_IN,
|
||||||
|
OAUTH_EMAIL_CLAIM,
|
||||||
|
OAUTH_PICTURE_CLAIM,
|
||||||
|
OAUTH_USERNAME_CLAIM,
|
||||||
|
SHOW_ADMIN_DETAILS,
|
||||||
|
USER_PERMISSIONS,
|
||||||
|
WEBHOOK_URL,
|
||||||
|
WEBUI_AUTH,
|
||||||
|
WEBUI_BANNERS,
|
||||||
|
AppConfig,
|
||||||
|
)
|
||||||
|
from env import WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from utils.misc import (
|
||||||
|
apply_model_params_to_body_openai,
|
||||||
|
apply_model_system_prompt_to_body,
|
||||||
|
openai_chat_chunk_message_template,
|
||||||
|
openai_chat_completion_message_template,
|
||||||
|
)
|
||||||
|
from utils.tools import get_tools
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
@ -1,15 +1,13 @@
|
|||||||
from pydantic import BaseModel
|
|
||||||
from typing import Optional
|
|
||||||
import uuid
|
|
||||||
import logging
|
import logging
|
||||||
from sqlalchemy import String, Column, Boolean, Text
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from utils.utils import verify_password
|
|
||||||
|
|
||||||
from apps.webui.models.users import UserModel, Users
|
|
||||||
from apps.webui.internal.db import Base, get_db
|
from apps.webui.internal.db import Base, get_db
|
||||||
|
from apps.webui.models.users import UserModel, Users
|
||||||
from env import SRC_LOG_LEVELS
|
from env import SRC_LOG_LEVELS
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import Boolean, Column, String, Text
|
||||||
|
from utils.utils import verify_password
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
@ -92,7 +90,6 @@ class AddUserForm(SignupForm):
|
|||||||
|
|
||||||
|
|
||||||
class AuthsTable:
|
class AuthsTable:
|
||||||
|
|
||||||
def insert_new_auth(
|
def insert_new_auth(
|
||||||
self,
|
self,
|
||||||
email: str,
|
email: str,
|
||||||
@ -103,7 +100,6 @@ class AuthsTable:
|
|||||||
oauth_sub: Optional[str] = None,
|
oauth_sub: Optional[str] = None,
|
||||||
) -> Optional[UserModel]:
|
) -> Optional[UserModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
log.info("insert_new_auth")
|
log.info("insert_new_auth")
|
||||||
|
|
||||||
id = str(uuid.uuid4())
|
id = str(uuid.uuid4())
|
||||||
@ -130,7 +126,6 @@ class AuthsTable:
|
|||||||
log.info(f"authenticate_user: {email}")
|
log.info(f"authenticate_user: {email}")
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
auth = db.query(Auth).filter_by(email=email, active=True).first()
|
auth = db.query(Auth).filter_by(email=email, active=True).first()
|
||||||
if auth:
|
if auth:
|
||||||
if verify_password(password, auth.password):
|
if verify_password(password, auth.password):
|
||||||
@ -189,7 +184,6 @@ class AuthsTable:
|
|||||||
def delete_auth_by_id(self, id: str) -> bool:
|
def delete_auth_by_id(self, id: str) -> bool:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
# Delete User
|
# Delete User
|
||||||
result = Users.delete_user_by_id(id)
|
result = Users.delete_user_by_id(id)
|
||||||
|
|
||||||
|
@ -1,14 +1,11 @@
|
|||||||
from pydantic import BaseModel, ConfigDict
|
|
||||||
from typing import Union, Optional
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import uuid
|
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
from sqlalchemy import Column, String, BigInteger, Boolean, Text
|
from typing import Optional
|
||||||
|
|
||||||
from apps.webui.internal.db import Base, get_db
|
from apps.webui.internal.db import Base, get_db
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from sqlalchemy import BigInteger, Boolean, Column, String, Text
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# Chat DB Schema
|
# Chat DB Schema
|
||||||
@ -77,10 +74,8 @@ class ChatTitleIdResponse(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ChatTable:
|
class ChatTable:
|
||||||
|
|
||||||
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
|
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
id = str(uuid.uuid4())
|
id = str(uuid.uuid4())
|
||||||
chat = ChatModel(
|
chat = ChatModel(
|
||||||
**{
|
**{
|
||||||
@ -106,7 +101,6 @@ class ChatTable:
|
|||||||
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
|
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
chat_obj = db.get(Chat, id)
|
chat_obj = db.get(Chat, id)
|
||||||
chat_obj.chat = json.dumps(chat)
|
chat_obj.chat = json.dumps(chat)
|
||||||
chat_obj.title = chat["title"] if "title" in chat else "New Chat"
|
chat_obj.title = chat["title"] if "title" in chat else "New Chat"
|
||||||
@ -115,12 +109,11 @@ class ChatTable:
|
|||||||
db.refresh(chat_obj)
|
db.refresh(chat_obj)
|
||||||
|
|
||||||
return ChatModel.model_validate(chat_obj)
|
return ChatModel.model_validate(chat_obj)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
|
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
# Get the existing chat to share
|
# Get the existing chat to share
|
||||||
chat = db.get(Chat, chat_id)
|
chat = db.get(Chat, chat_id)
|
||||||
# Check if the chat is already shared
|
# Check if the chat is already shared
|
||||||
@ -154,7 +147,6 @@ class ChatTable:
|
|||||||
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
|
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
print("update_shared_chat_by_id")
|
print("update_shared_chat_by_id")
|
||||||
chat = db.get(Chat, chat_id)
|
chat = db.get(Chat, chat_id)
|
||||||
print(chat)
|
print(chat)
|
||||||
@ -170,7 +162,6 @@ class ChatTable:
|
|||||||
def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
|
def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
|
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
@ -183,7 +174,6 @@ class ChatTable:
|
|||||||
) -> Optional[ChatModel]:
|
) -> Optional[ChatModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
chat = db.get(Chat, id)
|
chat = db.get(Chat, id)
|
||||||
chat.share_id = share_id
|
chat.share_id = share_id
|
||||||
db.commit()
|
db.commit()
|
||||||
@ -195,7 +185,6 @@ class ChatTable:
|
|||||||
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
|
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
chat = db.get(Chat, id)
|
chat = db.get(Chat, id)
|
||||||
chat.archived = not chat.archived
|
chat.archived = not chat.archived
|
||||||
db.commit()
|
db.commit()
|
||||||
@ -217,7 +206,6 @@ class ChatTable:
|
|||||||
self, user_id: str, skip: int = 0, limit: int = 50
|
self, user_id: str, skip: int = 0, limit: int = 50
|
||||||
) -> list[ChatModel]:
|
) -> list[ChatModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
all_chats = (
|
all_chats = (
|
||||||
db.query(Chat)
|
db.query(Chat)
|
||||||
.filter_by(user_id=user_id, archived=True)
|
.filter_by(user_id=user_id, archived=True)
|
||||||
@ -297,7 +285,6 @@ class ChatTable:
|
|||||||
def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
|
def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
chat = db.get(Chat, id)
|
chat = db.get(Chat, id)
|
||||||
return ChatModel.model_validate(chat)
|
return ChatModel.model_validate(chat)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -306,20 +293,18 @@ class ChatTable:
|
|||||||
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
|
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
chat = db.query(Chat).filter_by(share_id=id).first()
|
chat = db.query(Chat).filter_by(share_id=id).first()
|
||||||
|
|
||||||
if chat:
|
if chat:
|
||||||
return self.get_chat_by_id(id)
|
return self.get_chat_by_id(id)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
|
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
|
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
|
||||||
return ChatModel.model_validate(chat)
|
return ChatModel.model_validate(chat)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -327,7 +312,6 @@ class ChatTable:
|
|||||||
|
|
||||||
def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]:
|
def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
all_chats = (
|
all_chats = (
|
||||||
db.query(Chat)
|
db.query(Chat)
|
||||||
# .limit(limit).offset(skip)
|
# .limit(limit).offset(skip)
|
||||||
@ -337,7 +321,6 @@ class ChatTable:
|
|||||||
|
|
||||||
def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
|
def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
all_chats = (
|
all_chats = (
|
||||||
db.query(Chat)
|
db.query(Chat)
|
||||||
.filter_by(user_id=user_id)
|
.filter_by(user_id=user_id)
|
||||||
@ -347,7 +330,6 @@ class ChatTable:
|
|||||||
|
|
||||||
def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
|
def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
all_chats = (
|
all_chats = (
|
||||||
db.query(Chat)
|
db.query(Chat)
|
||||||
.filter_by(user_id=user_id, archived=True)
|
.filter_by(user_id=user_id, archived=True)
|
||||||
@ -358,7 +340,6 @@ class ChatTable:
|
|||||||
def delete_chat_by_id(self, id: str) -> bool:
|
def delete_chat_by_id(self, id: str) -> bool:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
db.query(Chat).filter_by(id=id).delete()
|
db.query(Chat).filter_by(id=id).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
@ -369,7 +350,6 @@ class ChatTable:
|
|||||||
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
|
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
db.query(Chat).filter_by(id=id, user_id=user_id).delete()
|
db.query(Chat).filter_by(id=id, user_id=user_id).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
@ -379,9 +359,7 @@ class ChatTable:
|
|||||||
|
|
||||||
def delete_chats_by_user_id(self, user_id: str) -> bool:
|
def delete_chats_by_user_id(self, user_id: str) -> bool:
|
||||||
try:
|
try:
|
||||||
|
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
self.delete_shared_chats_by_user_id(user_id)
|
self.delete_shared_chats_by_user_id(user_id)
|
||||||
|
|
||||||
db.query(Chat).filter_by(user_id=user_id).delete()
|
db.query(Chat).filter_by(user_id=user_id).delete()
|
||||||
@ -393,9 +371,7 @@ class ChatTable:
|
|||||||
|
|
||||||
def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
|
def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
|
||||||
try:
|
try:
|
||||||
|
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
|
chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
|
||||||
shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
|
shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
|
||||||
|
|
||||||
|
@ -1,15 +1,12 @@
|
|||||||
from pydantic import BaseModel, ConfigDict
|
import json
|
||||||
from typing import Optional
|
|
||||||
import time
|
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from sqlalchemy import String, Column, BigInteger, Text
|
from typing import Optional
|
||||||
|
|
||||||
from apps.webui.internal.db import Base, get_db
|
from apps.webui.internal.db import Base, get_db
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
from env import SRC_LOG_LEVELS
|
from env import SRC_LOG_LEVELS
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from sqlalchemy import BigInteger, Column, String, Text
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
@ -70,12 +67,10 @@ class DocumentForm(DocumentUpdateForm):
|
|||||||
|
|
||||||
|
|
||||||
class DocumentsTable:
|
class DocumentsTable:
|
||||||
|
|
||||||
def insert_new_doc(
|
def insert_new_doc(
|
||||||
self, user_id: str, form_data: DocumentForm
|
self, user_id: str, form_data: DocumentForm
|
||||||
) -> Optional[DocumentModel]:
|
) -> Optional[DocumentModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
document = DocumentModel(
|
document = DocumentModel(
|
||||||
**{
|
**{
|
||||||
**form_data.model_dump(),
|
**form_data.model_dump(),
|
||||||
@ -99,7 +94,6 @@ class DocumentsTable:
|
|||||||
def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
|
def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
document = db.query(Document).filter_by(name=name).first()
|
document = db.query(Document).filter_by(name=name).first()
|
||||||
return DocumentModel.model_validate(document) if document else None
|
return DocumentModel.model_validate(document) if document else None
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -107,7 +101,6 @@ class DocumentsTable:
|
|||||||
|
|
||||||
def get_docs(self) -> list[DocumentModel]:
|
def get_docs(self) -> list[DocumentModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
return [
|
return [
|
||||||
DocumentModel.model_validate(doc) for doc in db.query(Document).all()
|
DocumentModel.model_validate(doc) for doc in db.query(Document).all()
|
||||||
]
|
]
|
||||||
@ -117,7 +110,6 @@ class DocumentsTable:
|
|||||||
) -> Optional[DocumentModel]:
|
) -> Optional[DocumentModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
db.query(Document).filter_by(name=name).update(
|
db.query(Document).filter_by(name=name).update(
|
||||||
{
|
{
|
||||||
"title": form_data.title,
|
"title": form_data.title,
|
||||||
@ -140,7 +132,6 @@ class DocumentsTable:
|
|||||||
doc_content = {**doc_content, **updated}
|
doc_content = {**doc_content, **updated}
|
||||||
|
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
db.query(Document).filter_by(name=name).update(
|
db.query(Document).filter_by(name=name).update(
|
||||||
{
|
{
|
||||||
"content": json.dumps(doc_content),
|
"content": json.dumps(doc_content),
|
||||||
@ -156,7 +147,6 @@ class DocumentsTable:
|
|||||||
def delete_doc_by_name(self, name: str) -> bool:
|
def delete_doc_by_name(self, name: str) -> bool:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
db.query(Document).filter_by(name=name).delete()
|
db.query(Document).filter_by(name=name).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
return True
|
return True
|
||||||
|
@ -1,15 +1,11 @@
|
|||||||
from pydantic import BaseModel, ConfigDict
|
|
||||||
from typing import Union, Optional
|
|
||||||
import time
|
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy import Column, String, BigInteger, Text
|
from apps.webui.internal.db import Base, JSONField, get_db
|
||||||
|
|
||||||
from apps.webui.internal.db import JSONField, Base, get_db
|
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
from env import SRC_LOG_LEVELS
|
from env import SRC_LOG_LEVELS
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from sqlalchemy import BigInteger, Column, String, Text
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
@ -59,10 +55,8 @@ class FileForm(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class FilesTable:
|
class FilesTable:
|
||||||
|
|
||||||
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
|
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
file = FileModel(
|
file = FileModel(
|
||||||
**{
|
**{
|
||||||
**form_data.model_dump(),
|
**form_data.model_dump(),
|
||||||
@ -86,7 +80,6 @@ class FilesTable:
|
|||||||
|
|
||||||
def get_file_by_id(self, id: str) -> Optional[FileModel]:
|
def get_file_by_id(self, id: str) -> Optional[FileModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
file = db.get(File, id)
|
file = db.get(File, id)
|
||||||
return FileModel.model_validate(file)
|
return FileModel.model_validate(file)
|
||||||
@ -95,7 +88,6 @@ class FilesTable:
|
|||||||
|
|
||||||
def get_files(self) -> list[FileModel]:
|
def get_files(self) -> list[FileModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
return [FileModel.model_validate(file) for file in db.query(File).all()]
|
return [FileModel.model_validate(file) for file in db.query(File).all()]
|
||||||
|
|
||||||
def get_files_by_user_id(self, user_id: str) -> list[FileModel]:
|
def get_files_by_user_id(self, user_id: str) -> list[FileModel]:
|
||||||
@ -106,9 +98,7 @@ class FilesTable:
|
|||||||
]
|
]
|
||||||
|
|
||||||
def delete_file_by_id(self, id: str) -> bool:
|
def delete_file_by_id(self, id: str) -> bool:
|
||||||
|
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db.query(File).filter_by(id=id).delete()
|
db.query(File).filter_by(id=id).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
@ -118,9 +108,7 @@ class FilesTable:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def delete_all_files(self) -> bool:
|
def delete_all_files(self) -> bool:
|
||||||
|
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db.query(File).delete()
|
db.query(File).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
|
@ -1,18 +1,12 @@
|
|||||||
from pydantic import BaseModel, ConfigDict
|
|
||||||
from typing import Union, Optional
|
|
||||||
import time
|
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy import Column, String, Text, BigInteger, Boolean
|
from apps.webui.internal.db import Base, JSONField, get_db
|
||||||
|
|
||||||
from apps.webui.internal.db import JSONField, Base, get_db
|
|
||||||
from apps.webui.models.users import Users
|
from apps.webui.models.users import Users
|
||||||
|
|
||||||
import json
|
|
||||||
import copy
|
|
||||||
|
|
||||||
|
|
||||||
from env import SRC_LOG_LEVELS
|
from env import SRC_LOG_LEVELS
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from sqlalchemy import BigInteger, Boolean, Column, String, Text
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
@ -87,11 +81,9 @@ class FunctionValves(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class FunctionsTable:
|
class FunctionsTable:
|
||||||
|
|
||||||
def insert_new_function(
|
def insert_new_function(
|
||||||
self, user_id: str, type: str, form_data: FunctionForm
|
self, user_id: str, type: str, form_data: FunctionForm
|
||||||
) -> Optional[FunctionModel]:
|
) -> Optional[FunctionModel]:
|
||||||
|
|
||||||
function = FunctionModel(
|
function = FunctionModel(
|
||||||
**{
|
**{
|
||||||
**form_data.model_dump(),
|
**form_data.model_dump(),
|
||||||
@ -119,7 +111,6 @@ class FunctionsTable:
|
|||||||
def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
|
def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
function = db.get(Function, id)
|
function = db.get(Function, id)
|
||||||
return FunctionModel.model_validate(function)
|
return FunctionModel.model_validate(function)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -127,7 +118,6 @@ class FunctionsTable:
|
|||||||
|
|
||||||
def get_functions(self, active_only=False) -> list[FunctionModel]:
|
def get_functions(self, active_only=False) -> list[FunctionModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
if active_only:
|
if active_only:
|
||||||
return [
|
return [
|
||||||
FunctionModel.model_validate(function)
|
FunctionModel.model_validate(function)
|
||||||
@ -143,7 +133,6 @@ class FunctionsTable:
|
|||||||
self, type: str, active_only=False
|
self, type: str, active_only=False
|
||||||
) -> list[FunctionModel]:
|
) -> list[FunctionModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
if active_only:
|
if active_only:
|
||||||
return [
|
return [
|
||||||
FunctionModel.model_validate(function)
|
FunctionModel.model_validate(function)
|
||||||
@ -159,7 +148,6 @@ class FunctionsTable:
|
|||||||
|
|
||||||
def get_global_filter_functions(self) -> list[FunctionModel]:
|
def get_global_filter_functions(self) -> list[FunctionModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
return [
|
return [
|
||||||
FunctionModel.model_validate(function)
|
FunctionModel.model_validate(function)
|
||||||
for function in db.query(Function)
|
for function in db.query(Function)
|
||||||
@ -178,7 +166,6 @@ class FunctionsTable:
|
|||||||
|
|
||||||
def get_function_valves_by_id(self, id: str) -> Optional[dict]:
|
def get_function_valves_by_id(self, id: str) -> Optional[dict]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
function = db.get(Function, id)
|
function = db.get(Function, id)
|
||||||
return function.valves if function.valves else {}
|
return function.valves if function.valves else {}
|
||||||
@ -190,7 +177,6 @@ class FunctionsTable:
|
|||||||
self, id: str, valves: dict
|
self, id: str, valves: dict
|
||||||
) -> Optional[FunctionValves]:
|
) -> Optional[FunctionValves]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
function = db.get(Function, id)
|
function = db.get(Function, id)
|
||||||
function.valves = valves
|
function.valves = valves
|
||||||
@ -204,7 +190,6 @@ class FunctionsTable:
|
|||||||
def get_user_valves_by_id_and_user_id(
|
def get_user_valves_by_id_and_user_id(
|
||||||
self, id: str, user_id: str
|
self, id: str, user_id: str
|
||||||
) -> Optional[dict]:
|
) -> Optional[dict]:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user = Users.get_user_by_id(user_id)
|
user = Users.get_user_by_id(user_id)
|
||||||
user_settings = user.settings.model_dump() if user.settings else {}
|
user_settings = user.settings.model_dump() if user.settings else {}
|
||||||
@ -223,7 +208,6 @@ class FunctionsTable:
|
|||||||
def update_user_valves_by_id_and_user_id(
|
def update_user_valves_by_id_and_user_id(
|
||||||
self, id: str, user_id: str, valves: dict
|
self, id: str, user_id: str, valves: dict
|
||||||
) -> Optional[dict]:
|
) -> Optional[dict]:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user = Users.get_user_by_id(user_id)
|
user = Users.get_user_by_id(user_id)
|
||||||
user_settings = user.settings.model_dump() if user.settings else {}
|
user_settings = user.settings.model_dump() if user.settings else {}
|
||||||
@ -246,7 +230,6 @@ class FunctionsTable:
|
|||||||
|
|
||||||
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
|
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db.query(Function).filter_by(id=id).update(
|
db.query(Function).filter_by(id=id).update(
|
||||||
{
|
{
|
||||||
@ -261,7 +244,6 @@ class FunctionsTable:
|
|||||||
|
|
||||||
def deactivate_all_functions(self) -> Optional[bool]:
|
def deactivate_all_functions(self) -> Optional[bool]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db.query(Function).update(
|
db.query(Function).update(
|
||||||
{
|
{
|
||||||
|
@ -1,12 +1,10 @@
|
|||||||
from pydantic import BaseModel, ConfigDict
|
|
||||||
from typing import Union, Optional
|
|
||||||
|
|
||||||
from sqlalchemy import Column, String, BigInteger, Text
|
|
||||||
|
|
||||||
from apps.webui.internal.db import Base, get_db
|
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from apps.webui.internal.db import Base, get_db
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from sqlalchemy import BigInteger, Column, String, Text
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# Memory DB Schema
|
# Memory DB Schema
|
||||||
@ -39,13 +37,11 @@ class MemoryModel(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class MemoriesTable:
|
class MemoriesTable:
|
||||||
|
|
||||||
def insert_new_memory(
|
def insert_new_memory(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
content: str,
|
content: str,
|
||||||
) -> Optional[MemoryModel]:
|
) -> Optional[MemoryModel]:
|
||||||
|
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
id = str(uuid.uuid4())
|
id = str(uuid.uuid4())
|
||||||
|
|
||||||
@ -73,7 +69,6 @@ class MemoriesTable:
|
|||||||
content: str,
|
content: str,
|
||||||
) -> Optional[MemoryModel]:
|
) -> Optional[MemoryModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db.query(Memory).filter_by(id=id).update(
|
db.query(Memory).filter_by(id=id).update(
|
||||||
{"content": content, "updated_at": int(time.time())}
|
{"content": content, "updated_at": int(time.time())}
|
||||||
@ -85,7 +80,6 @@ class MemoriesTable:
|
|||||||
|
|
||||||
def get_memories(self) -> list[MemoryModel]:
|
def get_memories(self) -> list[MemoryModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
memories = db.query(Memory).all()
|
memories = db.query(Memory).all()
|
||||||
return [MemoryModel.model_validate(memory) for memory in memories]
|
return [MemoryModel.model_validate(memory) for memory in memories]
|
||||||
@ -94,7 +88,6 @@ class MemoriesTable:
|
|||||||
|
|
||||||
def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]:
|
def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
memories = db.query(Memory).filter_by(user_id=user_id).all()
|
memories = db.query(Memory).filter_by(user_id=user_id).all()
|
||||||
return [MemoryModel.model_validate(memory) for memory in memories]
|
return [MemoryModel.model_validate(memory) for memory in memories]
|
||||||
@ -103,7 +96,6 @@ class MemoriesTable:
|
|||||||
|
|
||||||
def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
|
def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
memory = db.get(Memory, id)
|
memory = db.get(Memory, id)
|
||||||
return MemoryModel.model_validate(memory)
|
return MemoryModel.model_validate(memory)
|
||||||
@ -112,7 +104,6 @@ class MemoriesTable:
|
|||||||
|
|
||||||
def delete_memory_by_id(self, id: str) -> bool:
|
def delete_memory_by_id(self, id: str) -> bool:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db.query(Memory).filter_by(id=id).delete()
|
db.query(Memory).filter_by(id=id).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
@ -124,7 +115,6 @@ class MemoriesTable:
|
|||||||
|
|
||||||
def delete_memories_by_user_id(self, user_id: str) -> bool:
|
def delete_memories_by_user_id(self, user_id: str) -> bool:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db.query(Memory).filter_by(user_id=user_id).delete()
|
db.query(Memory).filter_by(user_id=user_id).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
@ -135,7 +125,6 @@ class MemoriesTable:
|
|||||||
|
|
||||||
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
|
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db.query(Memory).filter_by(id=id, user_id=user_id).delete()
|
db.query(Memory).filter_by(id=id, user_id=user_id).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
|
@ -1,14 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional, List
|
import time
|
||||||
|
from typing import Optional
|
||||||
from pydantic import BaseModel, ConfigDict
|
|
||||||
from sqlalchemy import Column, BigInteger, Text
|
|
||||||
|
|
||||||
from apps.webui.internal.db import Base, JSONField, get_db
|
from apps.webui.internal.db import Base, JSONField, get_db
|
||||||
|
|
||||||
from env import SRC_LOG_LEVELS
|
from env import SRC_LOG_LEVELS
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
import time
|
from sqlalchemy import BigInteger, Column, Text
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
|
@ -1,12 +1,9 @@
|
|||||||
from pydantic import BaseModel, ConfigDict
|
|
||||||
from typing import Optional
|
|
||||||
import time
|
import time
|
||||||
|
from typing import Optional
|
||||||
from sqlalchemy import String, Column, BigInteger, Text
|
|
||||||
|
|
||||||
from apps.webui.internal.db import Base, get_db
|
from apps.webui.internal.db import Base, get_db
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
import json
|
from sqlalchemy import BigInteger, Column, String, Text
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# Prompts DB Schema
|
# Prompts DB Schema
|
||||||
@ -45,7 +42,6 @@ class PromptForm(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class PromptsTable:
|
class PromptsTable:
|
||||||
|
|
||||||
def insert_new_prompt(
|
def insert_new_prompt(
|
||||||
self, user_id: str, form_data: PromptForm
|
self, user_id: str, form_data: PromptForm
|
||||||
) -> Optional[PromptModel]:
|
) -> Optional[PromptModel]:
|
||||||
@ -61,7 +57,6 @@ class PromptsTable:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
result = Prompt(**prompt.dict())
|
result = Prompt(**prompt.dict())
|
||||||
db.add(result)
|
db.add(result)
|
||||||
db.commit()
|
db.commit()
|
||||||
@ -70,13 +65,12 @@ class PromptsTable:
|
|||||||
return PromptModel.model_validate(result)
|
return PromptModel.model_validate(result)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
|
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
prompt = db.query(Prompt).filter_by(command=command).first()
|
prompt = db.query(Prompt).filter_by(command=command).first()
|
||||||
return PromptModel.model_validate(prompt)
|
return PromptModel.model_validate(prompt)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -84,7 +78,6 @@ class PromptsTable:
|
|||||||
|
|
||||||
def get_prompts(self) -> list[PromptModel]:
|
def get_prompts(self) -> list[PromptModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
return [
|
return [
|
||||||
PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
|
PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
|
||||||
]
|
]
|
||||||
@ -94,7 +87,6 @@ class PromptsTable:
|
|||||||
) -> Optional[PromptModel]:
|
) -> Optional[PromptModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
prompt = db.query(Prompt).filter_by(command=command).first()
|
prompt = db.query(Prompt).filter_by(command=command).first()
|
||||||
prompt.title = form_data.title
|
prompt.title = form_data.title
|
||||||
prompt.content = form_data.content
|
prompt.content = form_data.content
|
||||||
@ -107,7 +99,6 @@ class PromptsTable:
|
|||||||
def delete_prompt_by_command(self, command: str) -> bool:
|
def delete_prompt_by_command(self, command: str) -> bool:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
db.query(Prompt).filter_by(command=command).delete()
|
db.query(Prompt).filter_by(command=command).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
@ -1,16 +1,12 @@
|
|||||||
from pydantic import BaseModel, ConfigDict
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import json
|
|
||||||
import uuid
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from sqlalchemy import String, Column, BigInteger, Text
|
|
||||||
|
|
||||||
from apps.webui.internal.db import Base, get_db
|
from apps.webui.internal.db import Base, get_db
|
||||||
|
|
||||||
from env import SRC_LOG_LEVELS
|
from env import SRC_LOG_LEVELS
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from sqlalchemy import BigInteger, Column, String, Text
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
@ -77,10 +73,8 @@ class ChatTagsResponse(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class TagTable:
|
class TagTable:
|
||||||
|
|
||||||
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
|
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
id = str(uuid.uuid4())
|
id = str(uuid.uuid4())
|
||||||
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
|
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
|
||||||
try:
|
try:
|
||||||
@ -92,7 +86,7 @@ class TagTable:
|
|||||||
return TagModel.model_validate(result)
|
return TagModel.model_validate(result)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_tag_by_name_and_user_id(
|
def get_tag_by_name_and_user_id(
|
||||||
@ -102,7 +96,7 @@ class TagTable:
|
|||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
tag = db.query(Tag).filter_by(name=name, user_id=user_id).first()
|
tag = db.query(Tag).filter_by(name=name, user_id=user_id).first()
|
||||||
return TagModel.model_validate(tag)
|
return TagModel.model_validate(tag)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def add_tag_to_chat(
|
def add_tag_to_chat(
|
||||||
@ -161,7 +155,6 @@ class TagTable:
|
|||||||
self, chat_id: str, user_id: str
|
self, chat_id: str, user_id: str
|
||||||
) -> list[TagModel]:
|
) -> list[TagModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
tag_names = [
|
tag_names = [
|
||||||
chat_id_tag.tag_name
|
chat_id_tag.tag_name
|
||||||
for chat_id_tag in (
|
for chat_id_tag in (
|
||||||
@ -186,7 +179,6 @@ class TagTable:
|
|||||||
self, tag_name: str, user_id: str
|
self, tag_name: str, user_id: str
|
||||||
) -> list[ChatIdTagModel]:
|
) -> list[ChatIdTagModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ChatIdTagModel.model_validate(chat_id_tag)
|
ChatIdTagModel.model_validate(chat_id_tag)
|
||||||
for chat_id_tag in (
|
for chat_id_tag in (
|
||||||
@ -201,7 +193,6 @@ class TagTable:
|
|||||||
self, tag_name: str, user_id: str
|
self, tag_name: str, user_id: str
|
||||||
) -> int:
|
) -> int:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
return (
|
return (
|
||||||
db.query(ChatIdTag)
|
db.query(ChatIdTag)
|
||||||
.filter_by(tag_name=tag_name, user_id=user_id)
|
.filter_by(tag_name=tag_name, user_id=user_id)
|
||||||
@ -236,7 +227,6 @@ class TagTable:
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
res = (
|
res = (
|
||||||
db.query(ChatIdTag)
|
db.query(ChatIdTag)
|
||||||
.filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
|
.filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
|
||||||
|
@ -1,17 +1,12 @@
|
|||||||
from pydantic import BaseModel, ConfigDict
|
|
||||||
from typing import Optional
|
|
||||||
import time
|
|
||||||
import logging
|
import logging
|
||||||
from sqlalchemy import String, Column, BigInteger, Text
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from apps.webui.internal.db import Base, JSONField, get_db
|
from apps.webui.internal.db import Base, JSONField, get_db
|
||||||
from apps.webui.models.users import Users
|
from apps.webui.models.users import Users
|
||||||
|
|
||||||
import json
|
|
||||||
import copy
|
|
||||||
|
|
||||||
|
|
||||||
from env import SRC_LOG_LEVELS
|
from env import SRC_LOG_LEVELS
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from sqlalchemy import BigInteger, Column, String, Text
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
@ -79,13 +74,10 @@ class ToolValves(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ToolsTable:
|
class ToolsTable:
|
||||||
|
|
||||||
def insert_new_tool(
|
def insert_new_tool(
|
||||||
self, user_id: str, form_data: ToolForm, specs: list[dict]
|
self, user_id: str, form_data: ToolForm, specs: list[dict]
|
||||||
) -> Optional[ToolModel]:
|
) -> Optional[ToolModel]:
|
||||||
|
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
tool = ToolModel(
|
tool = ToolModel(
|
||||||
**{
|
**{
|
||||||
**form_data.model_dump(),
|
**form_data.model_dump(),
|
||||||
@ -112,7 +104,6 @@ class ToolsTable:
|
|||||||
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
|
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
tool = db.get(Tool, id)
|
tool = db.get(Tool, id)
|
||||||
return ToolModel.model_validate(tool)
|
return ToolModel.model_validate(tool)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -125,7 +116,6 @@ class ToolsTable:
|
|||||||
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
|
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
tool = db.get(Tool, id)
|
tool = db.get(Tool, id)
|
||||||
return tool.valves if tool.valves else {}
|
return tool.valves if tool.valves else {}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -135,7 +125,6 @@ class ToolsTable:
|
|||||||
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
|
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
db.query(Tool).filter_by(id=id).update(
|
db.query(Tool).filter_by(id=id).update(
|
||||||
{"valves": valves, "updated_at": int(time.time())}
|
{"valves": valves, "updated_at": int(time.time())}
|
||||||
)
|
)
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
from pydantic import BaseModel, ConfigDict
|
|
||||||
from typing import Optional
|
|
||||||
import time
|
import time
|
||||||
|
from typing import Optional
|
||||||
from sqlalchemy import String, Column, BigInteger, Text
|
|
||||||
|
|
||||||
from apps.webui.internal.db import Base, JSONField, get_db
|
from apps.webui.internal.db import Base, JSONField, get_db
|
||||||
from apps.webui.models.chats import Chats
|
from apps.webui.models.chats import Chats
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from sqlalchemy import BigInteger, Column, String, Text
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# User DB Schema
|
# User DB Schema
|
||||||
@ -113,7 +112,7 @@ class UsersTable:
|
|||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
user = db.query(User).filter_by(id=id).first()
|
user = db.query(User).filter_by(id=id).first()
|
||||||
return UserModel.model_validate(user)
|
return UserModel.model_validate(user)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
|
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
|
||||||
@ -221,7 +220,7 @@ class UsersTable:
|
|||||||
user = db.query(User).filter_by(id=id).first()
|
user = db.query(User).filter_by(id=id).first()
|
||||||
return UserModel.model_validate(user)
|
return UserModel.model_validate(user)
|
||||||
# return UserModel(**user.dict())
|
# return UserModel(**user.dict())
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def delete_user_by_id(self, id: str) -> bool:
|
def delete_user_by_id(self, id: str) -> bool:
|
||||||
@ -255,7 +254,7 @@ class UsersTable:
|
|||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
user = db.query(User).filter_by(id=id).first()
|
user = db.query(User).filter_by(id=id).first()
|
||||||
return user.api_key
|
return user.api_key
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,43 +1,33 @@
|
|||||||
import logging
|
|
||||||
|
|
||||||
from fastapi import Request, UploadFile, File
|
|
||||||
from fastapi import Depends, HTTPException, status
|
|
||||||
from fastapi.responses import Response
|
|
||||||
|
|
||||||
from fastapi import APIRouter
|
|
||||||
from pydantic import BaseModel
|
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
import csv
|
|
||||||
|
|
||||||
from apps.webui.models.auths import (
|
from apps.webui.models.auths import (
|
||||||
SigninForm,
|
|
||||||
SignupForm,
|
|
||||||
AddUserForm,
|
AddUserForm,
|
||||||
UpdateProfileForm,
|
|
||||||
UpdatePasswordForm,
|
|
||||||
UserResponse,
|
|
||||||
SigninResponse,
|
|
||||||
Auths,
|
|
||||||
ApiKey,
|
ApiKey,
|
||||||
|
Auths,
|
||||||
|
SigninForm,
|
||||||
|
SigninResponse,
|
||||||
|
SignupForm,
|
||||||
|
UpdatePasswordForm,
|
||||||
|
UpdateProfileForm,
|
||||||
|
UserResponse,
|
||||||
)
|
)
|
||||||
from apps.webui.models.users import Users
|
from apps.webui.models.users import Users
|
||||||
|
from config import WEBUI_AUTH
|
||||||
from utils.utils import (
|
|
||||||
get_password_hash,
|
|
||||||
get_current_user,
|
|
||||||
get_admin_user,
|
|
||||||
create_token,
|
|
||||||
create_api_key,
|
|
||||||
)
|
|
||||||
from utils.misc import parse_duration, validate_email_format
|
|
||||||
from utils.webhook import post_webhook
|
|
||||||
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||||
from config import (
|
from env import WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER
|
||||||
WEBUI_AUTH,
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
from fastapi.responses import Response
|
||||||
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
from pydantic import BaseModel
|
||||||
|
from utils.misc import parse_duration, validate_email_format
|
||||||
|
from utils.utils import (
|
||||||
|
create_api_key,
|
||||||
|
create_token,
|
||||||
|
get_admin_user,
|
||||||
|
get_current_user,
|
||||||
|
get_password_hash,
|
||||||
)
|
)
|
||||||
|
from utils.webhook import post_webhook
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -273,7 +263,6 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|||||||
|
|
||||||
@router.post("/add", response_model=SigninResponse)
|
@router.post("/add", response_model=SigninResponse)
|
||||||
async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
|
async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
|
||||||
|
|
||||||
if not validate_email_format(form_data.email.lower()):
|
if not validate_email_format(form_data.email.lower()):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
||||||
@ -283,7 +272,6 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
|
|||||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
print(form_data)
|
print(form_data)
|
||||||
hashed = get_password_hash(form_data.password)
|
hashed = get_password_hash(form_data.password)
|
||||||
user = Auths.insert_new_auth(
|
user = Auths.insert_new_auth(
|
||||||
|
@ -1,34 +1,15 @@
|
|||||||
from fastapi import Depends, Request, HTTPException, status
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Union, Optional
|
|
||||||
from utils.utils import get_verified_user, get_admin_user
|
|
||||||
from fastapi import APIRouter
|
|
||||||
from pydantic import BaseModel
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from apps.webui.models.users import Users
|
from apps.webui.models.chats import ChatForm, ChatResponse, Chats, ChatTitleIdResponse
|
||||||
from apps.webui.models.chats import (
|
from apps.webui.models.tags import ChatIdTagForm, ChatIdTagModel, TagModel, Tags
|
||||||
ChatModel,
|
from config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
|
||||||
ChatResponse,
|
|
||||||
ChatTitleForm,
|
|
||||||
ChatForm,
|
|
||||||
ChatTitleIdResponse,
|
|
||||||
Chats,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
from apps.webui.models.tags import (
|
|
||||||
TagModel,
|
|
||||||
ChatIdTagModel,
|
|
||||||
ChatIdTagForm,
|
|
||||||
ChatTagsResponse,
|
|
||||||
Tags,
|
|
||||||
)
|
|
||||||
|
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
from env import SRC_LOG_LEVELS
|
||||||
from config import SRC_LOG_LEVELS, ENABLE_ADMIN_EXPORT, ENABLE_ADMIN_CHAT_ACCESS
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from utils.utils import get_admin_user, get_verified_user
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
@ -61,7 +42,6 @@ async def get_session_user_chat_list(
|
|||||||
|
|
||||||
@router.delete("/", response_model=bool)
|
@router.delete("/", response_model=bool)
|
||||||
async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
|
async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
if (
|
if (
|
||||||
user.role == "user"
|
user.role == "user"
|
||||||
and not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]
|
and not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]
|
||||||
@ -220,7 +200,6 @@ class TagNameForm(BaseModel):
|
|||||||
async def get_user_chat_list_by_tag_name(
|
async def get_user_chat_list_by_tag_name(
|
||||||
form_data: TagNameForm, user=Depends(get_verified_user)
|
form_data: TagNameForm, user=Depends(get_verified_user)
|
||||||
):
|
):
|
||||||
|
|
||||||
chat_ids = [
|
chat_ids = [
|
||||||
chat_id_tag.chat_id
|
chat_id_tag.chat_id
|
||||||
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
|
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
|
||||||
@ -299,7 +278,6 @@ async def update_chat_by_id(
|
|||||||
|
|
||||||
@router.delete("/{id}", response_model=bool)
|
@router.delete("/{id}", response_model=bool)
|
||||||
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
|
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
if user.role == "admin":
|
if user.role == "admin":
|
||||||
result = Chats.delete_chat_by_id(id)
|
result = Chats.delete_chat_by_id(id)
|
||||||
return result
|
return result
|
||||||
@ -323,7 +301,6 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
|
|||||||
async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
|
async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||||
if chat:
|
if chat:
|
||||||
|
|
||||||
chat_body = json.loads(chat.chat)
|
chat_body = json.loads(chat.chat)
|
||||||
updated_chat = {
|
updated_chat = {
|
||||||
**chat_body,
|
**chat_body,
|
||||||
|
@ -1,25 +1,7 @@
|
|||||||
from fastapi import Response, Request
|
|
||||||
from fastapi import Depends, FastAPI, HTTPException, status
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from fastapi import APIRouter
|
|
||||||
from pydantic import BaseModel
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from config import BannerModel
|
from config import BannerModel
|
||||||
|
from fastapi import APIRouter, Depends, Request
|
||||||
from apps.webui.models.users import Users
|
from pydantic import BaseModel
|
||||||
|
from utils.utils import get_admin_user, get_verified_user
|
||||||
from utils.utils import (
|
|
||||||
get_password_hash,
|
|
||||||
get_verified_user,
|
|
||||||
get_admin_user,
|
|
||||||
create_token,
|
|
||||||
)
|
|
||||||
from utils.misc import get_gravatar_url, validate_email_format
|
|
||||||
from constants import ERROR_MESSAGES
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
@ -1,21 +1,16 @@
|
|||||||
from fastapi import Depends, FastAPI, HTTPException, status
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Union, Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter
|
|
||||||
from pydantic import BaseModel
|
|
||||||
import json
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from apps.webui.models.documents import (
|
from apps.webui.models.documents import (
|
||||||
Documents,
|
|
||||||
DocumentForm,
|
DocumentForm,
|
||||||
DocumentUpdateForm,
|
|
||||||
DocumentModel,
|
|
||||||
DocumentResponse,
|
DocumentResponse,
|
||||||
|
Documents,
|
||||||
|
DocumentUpdateForm,
|
||||||
)
|
)
|
||||||
|
|
||||||
from utils.utils import get_verified_user, get_admin_user
|
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from utils.utils import get_admin_user, get_verified_user
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
@ -1,42 +1,17 @@
|
|||||||
from fastapi import (
|
import logging
|
||||||
Depends,
|
|
||||||
FastAPI,
|
|
||||||
HTTPException,
|
|
||||||
status,
|
|
||||||
Request,
|
|
||||||
UploadFile,
|
|
||||||
File,
|
|
||||||
Form,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Union, Optional
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from fastapi import APIRouter
|
|
||||||
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
import json
|
|
||||||
|
|
||||||
from apps.webui.models.files import (
|
|
||||||
Files,
|
|
||||||
FileForm,
|
|
||||||
FileModel,
|
|
||||||
FileModelResponse,
|
|
||||||
)
|
|
||||||
from utils.utils import get_verified_user, get_admin_user
|
|
||||||
from constants import ERROR_MESSAGES
|
|
||||||
|
|
||||||
from importlib import util
|
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
import uuid
|
import uuid
|
||||||
import os, shutil, logging, re
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from config import SRC_LOG_LEVELS, UPLOAD_DIR
|
|
||||||
|
|
||||||
|
from apps.webui.models.files import FileForm, FileModel, Files
|
||||||
|
from config import UPLOAD_DIR
|
||||||
|
from constants import ERROR_MESSAGES
|
||||||
|
from env import SRC_LOG_LEVELS
|
||||||
|
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
from utils.utils import get_admin_user, get_verified_user
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
|
@ -1,27 +1,18 @@
|
|||||||
from fastapi import Depends, FastAPI, HTTPException, status, Request
|
import os
|
||||||
from datetime import datetime, timedelta
|
from pathlib import Path
|
||||||
from typing import Union, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter
|
|
||||||
from pydantic import BaseModel
|
|
||||||
import json
|
|
||||||
|
|
||||||
from apps.webui.models.functions import (
|
from apps.webui.models.functions import (
|
||||||
Functions,
|
|
||||||
FunctionForm,
|
FunctionForm,
|
||||||
FunctionModel,
|
FunctionModel,
|
||||||
FunctionResponse,
|
FunctionResponse,
|
||||||
|
Functions,
|
||||||
)
|
)
|
||||||
from apps.webui.utils import load_function_module_by_id
|
from apps.webui.utils import load_function_module_by_id
|
||||||
from utils.utils import get_verified_user, get_admin_user
|
from config import CACHE_DIR, FUNCTIONS_DIR
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
from importlib import util
|
from utils.utils import get_admin_user, get_verified_user
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from config import DATA_DIR, CACHE_DIR, FUNCTIONS_DIR
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -304,7 +295,6 @@ async def update_function_valves_by_id(
|
|||||||
):
|
):
|
||||||
function = Functions.get_function_by_id(id)
|
function = Functions.get_function_by_id(id)
|
||||||
if function:
|
if function:
|
||||||
|
|
||||||
if id in request.app.state.FUNCTIONS:
|
if id in request.app.state.FUNCTIONS:
|
||||||
function_module = request.app.state.FUNCTIONS[id]
|
function_module = request.app.state.FUNCTIONS[id]
|
||||||
else:
|
else:
|
||||||
|
@ -1,18 +1,12 @@
|
|||||||
from fastapi import Response, Request
|
|
||||||
from fastapi import Depends, FastAPI, HTTPException, status
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Union, Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter
|
|
||||||
from pydantic import BaseModel
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from apps.webui.models.memories import Memories, MemoryModel
|
from apps.webui.models.memories import Memories, MemoryModel
|
||||||
|
from config import CHROMA_CLIENT
|
||||||
|
from env import SRC_LOG_LEVELS
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
from pydantic import BaseModel
|
||||||
from utils.utils import get_verified_user
|
from utils.utils import get_verified_user
|
||||||
from constants import ERROR_MESSAGES
|
|
||||||
|
|
||||||
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
|
@ -1,15 +1,9 @@
|
|||||||
from fastapi import Depends, FastAPI, HTTPException, status, Request
|
from typing import Optional
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Union, Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from apps.webui.models.models import ModelForm, ModelModel, ModelResponse, Models
|
||||||
from pydantic import BaseModel
|
|
||||||
import json
|
|
||||||
|
|
||||||
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
|
|
||||||
|
|
||||||
from utils.utils import get_verified_user, get_admin_user
|
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
from utils.utils import get_admin_user, get_verified_user
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
@ -1,15 +1,9 @@
|
|||||||
from fastapi import Depends, FastAPI, HTTPException, status
|
from typing import Optional
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Union, Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from apps.webui.models.prompts import PromptForm, PromptModel, Prompts
|
||||||
from pydantic import BaseModel
|
|
||||||
import json
|
|
||||||
|
|
||||||
from apps.webui.models.prompts import Prompts, PromptForm, PromptModel
|
|
||||||
|
|
||||||
from utils.utils import get_verified_user, get_admin_user
|
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from utils.utils import get_admin_user, get_verified_user
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
@ -1,20 +1,14 @@
|
|||||||
from fastapi import Depends, HTTPException, status, Request
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter
|
|
||||||
|
|
||||||
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
|
|
||||||
from apps.webui.utils import load_toolkit_module_by_id
|
|
||||||
|
|
||||||
from utils.utils import get_admin_user, get_verified_user
|
|
||||||
from utils.tools import get_tools_specs
|
|
||||||
from constants import ERROR_MESSAGES
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from config import DATA_DIR, CACHE_DIR
|
from apps.webui.models.tools import ToolForm, ToolModel, ToolResponse, Tools
|
||||||
|
from apps.webui.utils import load_toolkit_module_by_id
|
||||||
|
from config import CACHE_DIR, DATA_DIR
|
||||||
|
from constants import ERROR_MESSAGES
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
from utils.tools import get_tools_specs
|
||||||
|
from utils.utils import get_admin_user, get_verified_user
|
||||||
|
|
||||||
TOOLS_DIR = f"{DATA_DIR}/tools"
|
TOOLS_DIR = f"{DATA_DIR}/tools"
|
||||||
os.makedirs(TOOLS_DIR, exist_ok=True)
|
os.makedirs(TOOLS_DIR, exist_ok=True)
|
||||||
|
@ -1,33 +1,20 @@
|
|||||||
from fastapi import Response, Request
|
|
||||||
from fastapi import Depends, FastAPI, HTTPException, status
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Union, Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter
|
|
||||||
from pydantic import BaseModel
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from apps.webui.models.users import (
|
|
||||||
UserModel,
|
|
||||||
UserUpdateForm,
|
|
||||||
UserRoleUpdateForm,
|
|
||||||
UserSettings,
|
|
||||||
Users,
|
|
||||||
)
|
|
||||||
from apps.webui.models.auths import Auths
|
from apps.webui.models.auths import Auths
|
||||||
from apps.webui.models.chats import Chats
|
from apps.webui.models.chats import Chats
|
||||||
|
from apps.webui.models.users import (
|
||||||
from utils.utils import (
|
UserModel,
|
||||||
get_verified_user,
|
UserRoleUpdateForm,
|
||||||
get_password_hash,
|
Users,
|
||||||
get_current_user,
|
UserSettings,
|
||||||
get_admin_user,
|
UserUpdateForm,
|
||||||
)
|
)
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
from env import SRC_LOG_LEVELS
|
||||||
from config import SRC_LOG_LEVELS
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from utils.utils import get_admin_user, get_password_hash, get_verified_user
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
@ -69,7 +56,6 @@ async def update_user_permissions(
|
|||||||
|
|
||||||
@router.post("/update/role", response_model=Optional[UserModel])
|
@router.post("/update/role", response_model=Optional[UserModel])
|
||||||
async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
|
async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
|
||||||
|
|
||||||
if user.id != form_data.id and form_data.id != Users.get_first_user().id:
|
if user.id != form_data.id and form_data.id != Users.get_first_user().id:
|
||||||
return Users.update_user_role_by_id(form_data.id, form_data.role)
|
return Users.update_user_role_by_id(form_data.id, form_data.role)
|
||||||
|
|
||||||
@ -173,7 +159,6 @@ class UserResponse(BaseModel):
|
|||||||
|
|
||||||
@router.get("/{user_id}", response_model=UserResponse)
|
@router.get("/{user_id}", response_model=UserResponse)
|
||||||
async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
|
async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
# Check if user_id is a shared chat
|
# Check if user_id is a shared chat
|
||||||
# If it is, get the user_id from the chat
|
# If it is, get the user_id from the chat
|
||||||
if user_id.startswith("shared-"):
|
if user_id.startswith("shared-"):
|
||||||
|
@ -1,23 +1,16 @@
|
|||||||
from pathlib import Path
|
|
||||||
import site
|
import site
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from fastapi import APIRouter, UploadFile, File, Response
|
|
||||||
from fastapi import Depends, HTTPException, status
|
|
||||||
from starlette.responses import StreamingResponse, FileResponse
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
from fpdf import FPDF
|
|
||||||
import markdown
|
|
||||||
import black
|
import black
|
||||||
|
import markdown
|
||||||
|
from config import DATA_DIR, ENABLE_ADMIN_EXPORT
|
||||||
from utils.utils import get_admin_user
|
|
||||||
from utils.misc import calculate_sha256, get_gravatar_url
|
|
||||||
|
|
||||||
from config import OLLAMA_BASE_URLS, DATA_DIR, UPLOAD_DIR, ENABLE_ADMIN_EXPORT
|
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||||
|
from fpdf import FPDF
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from starlette.responses import FileResponse
|
||||||
|
from utils.misc import get_gravatar_url
|
||||||
|
from utils.utils import get_admin_user
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -115,7 +108,7 @@ async def download_chat_as_pdf(
|
|||||||
return Response(
|
return Response(
|
||||||
content=bytes(pdf_bytes),
|
content=bytes(pdf_bytes),
|
||||||
media_type="application/pdf",
|
media_type="application/pdf",
|
||||||
headers={"Content-Disposition": f"attachment;filename=chat.pdf"},
|
headers={"Content-Disposition": "attachment;filename=chat.pdf"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,13 +1,12 @@
|
|||||||
from importlib import util
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from importlib import util
|
||||||
|
|
||||||
|
|
||||||
from apps.webui.models.tools import Tools
|
|
||||||
from apps.webui.models.functions import Functions
|
from apps.webui.models.functions import Functions
|
||||||
from config import TOOLS_DIR, FUNCTIONS_DIR
|
from apps.webui.models.tools import Tools
|
||||||
|
from config import FUNCTIONS_DIR, TOOLS_DIR
|
||||||
|
|
||||||
|
|
||||||
def extract_frontmatter(file_path):
|
def extract_frontmatter(file_path):
|
||||||
|
@ -1,58 +1,30 @@
|
|||||||
from sqlalchemy import create_engine, Column, Integer, DateTime, JSON, func
|
import json
|
||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import logging
|
import logging
|
||||||
import importlib.metadata
|
import os
|
||||||
import pkgutil
|
import shutil
|
||||||
from urllib.parse import urlparse
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Generic, Optional, TypeVar
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
from chromadb import Settings
|
|
||||||
from typing import TypeVar, Generic
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
import json
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import shutil
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
from apps.webui.internal.db import Base, get_db
|
from apps.webui.internal.db import Base, get_db
|
||||||
|
from chromadb import Settings
|
||||||
from constants import ERROR_MESSAGES
|
|
||||||
|
|
||||||
from env import (
|
from env import (
|
||||||
ENV,
|
|
||||||
VERSION,
|
|
||||||
SAFE_MODE,
|
|
||||||
GLOBAL_LOG_LEVEL,
|
|
||||||
SRC_LOG_LEVELS,
|
|
||||||
BASE_DIR,
|
|
||||||
DATA_DIR,
|
|
||||||
BACKEND_DIR,
|
BACKEND_DIR,
|
||||||
FRONTEND_BUILD_DIR,
|
|
||||||
WEBUI_NAME,
|
|
||||||
WEBUI_URL,
|
|
||||||
WEBUI_FAVICON_URL,
|
|
||||||
WEBUI_BUILD_HASH,
|
|
||||||
CONFIG_DATA,
|
CONFIG_DATA,
|
||||||
DATABASE_URL,
|
DATA_DIR,
|
||||||
CHANGELOG,
|
ENV,
|
||||||
|
FRONTEND_BUILD_DIR,
|
||||||
WEBUI_AUTH,
|
WEBUI_AUTH,
|
||||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
WEBUI_FAVICON_URL,
|
||||||
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
WEBUI_NAME,
|
||||||
WEBUI_SECRET_KEY,
|
|
||||||
WEBUI_SESSION_COOKIE_SAME_SITE,
|
|
||||||
WEBUI_SESSION_COOKIE_SECURE,
|
|
||||||
log,
|
log,
|
||||||
)
|
)
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import JSON, Column, DateTime, Integer, func
|
||||||
|
|
||||||
|
|
||||||
class EndpointFilter(logging.Filter):
|
class EndpointFilter(logging.Filter):
|
||||||
@ -72,8 +44,8 @@ logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
|
|||||||
def run_migrations():
|
def run_migrations():
|
||||||
print("Running migrations")
|
print("Running migrations")
|
||||||
try:
|
try:
|
||||||
from alembic.config import Config
|
|
||||||
from alembic import command
|
from alembic import command
|
||||||
|
from alembic.config import Config
|
||||||
|
|
||||||
alembic_cfg = Config(BACKEND_DIR / "alembic.ini")
|
alembic_cfg = Config(BACKEND_DIR / "alembic.ini")
|
||||||
command.upgrade(alembic_cfg, "head")
|
command.upgrade(alembic_cfg, "head")
|
||||||
|
@ -1,19 +1,13 @@
|
|||||||
from pathlib import Path
|
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
import importlib.metadata
|
import importlib.metadata
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
import pkgutil
|
import pkgutil
|
||||||
from urllib.parse import urlparse
|
import sys
|
||||||
from datetime import datetime
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
import markdown
|
import markdown
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
@ -26,7 +20,7 @@ BASE_DIR = BACKEND_DIR.parent # the path containing the backend/
|
|||||||
print(BASE_DIR)
|
print(BASE_DIR)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from dotenv import load_dotenv, find_dotenv
|
from dotenv import find_dotenv, load_dotenv
|
||||||
|
|
||||||
load_dotenv(find_dotenv(str(BASE_DIR / ".env")))
|
load_dotenv(find_dotenv(str(BASE_DIR / ".env")))
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
212
backend/main.py
212
backend/main.py
@ -1,130 +1,124 @@
|
|||||||
import base64
|
import base64
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import mimetypes
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from authlib.integrations.starlette_client import OAuth
|
|
||||||
from authlib.oidc.core import UserInfo
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import logging
|
|
||||||
import aiohttp
|
|
||||||
import requests
|
|
||||||
import mimetypes
|
|
||||||
import shutil
|
|
||||||
import inspect
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
|
import aiohttp
|
||||||
from fastapi.staticfiles import StaticFiles
|
import requests
|
||||||
from fastapi.responses import JSONResponse
|
from apps.audio.main import app as audio_app
|
||||||
from fastapi import HTTPException
|
from apps.images.main import app as images_app
|
||||||
|
from apps.ollama.main import app as ollama_app
|
||||||
|
from apps.ollama.main import (
|
||||||
|
generate_openai_chat_completion as generate_ollama_chat_completion,
|
||||||
|
)
|
||||||
|
from apps.ollama.main import get_all_models as get_ollama_models
|
||||||
|
from apps.openai.main import app as openai_app
|
||||||
|
from apps.openai.main import generate_chat_completion as generate_openai_chat_completion
|
||||||
|
from apps.openai.main import get_all_models as get_openai_models
|
||||||
|
from apps.rag.main import app as rag_app
|
||||||
|
from apps.rag.utils import get_rag_context, rag_template
|
||||||
|
from apps.socket.main import app as socket_app
|
||||||
|
from apps.socket.main import get_event_call, get_event_emitter
|
||||||
|
from apps.webui.internal.db import Session
|
||||||
|
from apps.webui.main import app as webui_app
|
||||||
|
from apps.webui.main import generate_function_chat_completion, get_pipe_models
|
||||||
|
from apps.webui.models.auths import Auths
|
||||||
|
from apps.webui.models.functions import Functions
|
||||||
|
from apps.webui.models.models import Models
|
||||||
|
from apps.webui.models.users import UserModel, Users
|
||||||
|
from apps.webui.utils import load_function_module_by_id
|
||||||
|
from authlib.integrations.starlette_client import OAuth
|
||||||
|
from authlib.oidc.core import UserInfo
|
||||||
|
from config import (
|
||||||
|
CACHE_DIR,
|
||||||
|
CORS_ALLOW_ORIGIN,
|
||||||
|
DEFAULT_LOCALE,
|
||||||
|
ENABLE_ADMIN_CHAT_ACCESS,
|
||||||
|
ENABLE_ADMIN_EXPORT,
|
||||||
|
ENABLE_MODEL_FILTER,
|
||||||
|
ENABLE_OAUTH_SIGNUP,
|
||||||
|
ENABLE_OLLAMA_API,
|
||||||
|
ENABLE_OPENAI_API,
|
||||||
|
ENV,
|
||||||
|
FRONTEND_BUILD_DIR,
|
||||||
|
MODEL_FILTER_LIST,
|
||||||
|
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
|
||||||
|
OAUTH_PROVIDERS,
|
||||||
|
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||||
|
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
|
||||||
|
STATIC_DIR,
|
||||||
|
TASK_MODEL,
|
||||||
|
TASK_MODEL_EXTERNAL,
|
||||||
|
TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||||
|
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||||
|
WEBHOOK_URL,
|
||||||
|
WEBUI_AUTH,
|
||||||
|
WEBUI_NAME,
|
||||||
|
AppConfig,
|
||||||
|
run_migrations,
|
||||||
|
)
|
||||||
|
from constants import ERROR_MESSAGES, TASKS, WEBHOOK_MESSAGES
|
||||||
|
from env import (
|
||||||
|
CHANGELOG,
|
||||||
|
GLOBAL_LOG_LEVEL,
|
||||||
|
SAFE_MODE,
|
||||||
|
SRC_LOG_LEVELS,
|
||||||
|
VERSION,
|
||||||
|
WEBUI_BUILD_HASH,
|
||||||
|
WEBUI_SECRET_KEY,
|
||||||
|
WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||||
|
WEBUI_SESSION_COOKIE_SECURE,
|
||||||
|
WEBUI_URL,
|
||||||
|
)
|
||||||
|
from fastapi import (
|
||||||
|
Depends,
|
||||||
|
FastAPI,
|
||||||
|
File,
|
||||||
|
Form,
|
||||||
|
HTTPException,
|
||||||
|
Request,
|
||||||
|
UploadFile,
|
||||||
|
status,
|
||||||
|
)
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from starlette.middleware.sessions import SessionMiddleware
|
from starlette.middleware.sessions import SessionMiddleware
|
||||||
from starlette.responses import StreamingResponse, Response, RedirectResponse
|
from starlette.responses import RedirectResponse, Response, StreamingResponse
|
||||||
|
from utils.misc import (
|
||||||
|
add_or_update_system_message,
|
||||||
from apps.socket.main import app as socket_app, get_event_emitter, get_event_call
|
get_last_user_message,
|
||||||
from apps.ollama.main import (
|
parse_duration,
|
||||||
app as ollama_app,
|
prepend_to_first_user_message_content,
|
||||||
get_all_models as get_ollama_models,
|
|
||||||
generate_openai_chat_completion as generate_ollama_chat_completion,
|
|
||||||
)
|
)
|
||||||
from apps.openai.main import (
|
from utils.task import (
|
||||||
app as openai_app,
|
moa_response_generation_template,
|
||||||
get_all_models as get_openai_models,
|
search_query_generation_template,
|
||||||
generate_chat_completion as generate_openai_chat_completion,
|
title_generation_template,
|
||||||
|
tools_function_calling_generation_template,
|
||||||
)
|
)
|
||||||
|
from utils.tools import get_tools
|
||||||
from apps.audio.main import app as audio_app
|
|
||||||
from apps.images.main import app as images_app
|
|
||||||
from apps.rag.main import app as rag_app
|
|
||||||
from apps.webui.main import (
|
|
||||||
app as webui_app,
|
|
||||||
get_pipe_models,
|
|
||||||
generate_function_chat_completion,
|
|
||||||
)
|
|
||||||
from apps.webui.internal.db import Session
|
|
||||||
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from apps.webui.models.auths import Auths
|
|
||||||
from apps.webui.models.models import Models
|
|
||||||
from apps.webui.models.functions import Functions
|
|
||||||
from apps.webui.models.users import Users, UserModel
|
|
||||||
|
|
||||||
from apps.webui.utils import load_function_module_by_id
|
|
||||||
|
|
||||||
from utils.utils import (
|
from utils.utils import (
|
||||||
|
create_token,
|
||||||
|
decode_token,
|
||||||
get_admin_user,
|
get_admin_user,
|
||||||
get_verified_user,
|
|
||||||
get_current_user,
|
get_current_user,
|
||||||
get_http_authorization_cred,
|
get_http_authorization_cred,
|
||||||
get_password_hash,
|
get_password_hash,
|
||||||
create_token,
|
get_verified_user,
|
||||||
decode_token,
|
|
||||||
)
|
)
|
||||||
from utils.task import (
|
|
||||||
title_generation_template,
|
|
||||||
search_query_generation_template,
|
|
||||||
tools_function_calling_generation_template,
|
|
||||||
moa_response_generation_template,
|
|
||||||
)
|
|
||||||
|
|
||||||
from utils.tools import get_tools
|
|
||||||
from utils.misc import (
|
|
||||||
get_last_user_message,
|
|
||||||
add_or_update_system_message,
|
|
||||||
prepend_to_first_user_message_content,
|
|
||||||
parse_duration,
|
|
||||||
)
|
|
||||||
|
|
||||||
from apps.rag.utils import get_rag_context, rag_template
|
|
||||||
|
|
||||||
from config import (
|
|
||||||
run_migrations,
|
|
||||||
WEBUI_NAME,
|
|
||||||
WEBUI_URL,
|
|
||||||
WEBUI_AUTH,
|
|
||||||
ENV,
|
|
||||||
VERSION,
|
|
||||||
CHANGELOG,
|
|
||||||
FRONTEND_BUILD_DIR,
|
|
||||||
CACHE_DIR,
|
|
||||||
STATIC_DIR,
|
|
||||||
DEFAULT_LOCALE,
|
|
||||||
ENABLE_OPENAI_API,
|
|
||||||
ENABLE_OLLAMA_API,
|
|
||||||
ENABLE_MODEL_FILTER,
|
|
||||||
MODEL_FILTER_LIST,
|
|
||||||
GLOBAL_LOG_LEVEL,
|
|
||||||
SRC_LOG_LEVELS,
|
|
||||||
WEBHOOK_URL,
|
|
||||||
ENABLE_ADMIN_EXPORT,
|
|
||||||
WEBUI_BUILD_HASH,
|
|
||||||
TASK_MODEL,
|
|
||||||
TASK_MODEL_EXTERNAL,
|
|
||||||
TITLE_GENERATION_PROMPT_TEMPLATE,
|
|
||||||
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
||||||
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
|
|
||||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
||||||
SAFE_MODE,
|
|
||||||
OAUTH_PROVIDERS,
|
|
||||||
ENABLE_OAUTH_SIGNUP,
|
|
||||||
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
|
|
||||||
WEBUI_SECRET_KEY,
|
|
||||||
WEBUI_SESSION_COOKIE_SAME_SITE,
|
|
||||||
WEBUI_SESSION_COOKIE_SECURE,
|
|
||||||
ENABLE_ADMIN_CHAT_ACCESS,
|
|
||||||
AppConfig,
|
|
||||||
CORS_ALLOW_ORIGIN,
|
|
||||||
)
|
|
||||||
|
|
||||||
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
|
|
||||||
from utils.webhook import post_webhook
|
from utils.webhook import post_webhook
|
||||||
|
|
||||||
if SAFE_MODE:
|
if SAFE_MODE:
|
||||||
|
@ -1,24 +1,9 @@
|
|||||||
import os
|
|
||||||
from logging.config import fileConfig
|
from logging.config import fileConfig
|
||||||
|
|
||||||
from sqlalchemy import engine_from_config
|
|
||||||
from sqlalchemy import pool
|
|
||||||
|
|
||||||
from alembic import context
|
from alembic import context
|
||||||
|
|
||||||
from apps.webui.models.auths import Auth
|
from apps.webui.models.auths import Auth
|
||||||
from apps.webui.models.chats import Chat
|
|
||||||
from apps.webui.models.documents import Document
|
|
||||||
from apps.webui.models.memories import Memory
|
|
||||||
from apps.webui.models.models import Model
|
|
||||||
from apps.webui.models.prompts import Prompt
|
|
||||||
from apps.webui.models.tags import Tag, ChatIdTag
|
|
||||||
from apps.webui.models.tools import Tool
|
|
||||||
from apps.webui.models.users import User
|
|
||||||
from apps.webui.models.files import File
|
|
||||||
from apps.webui.models.functions import Function
|
|
||||||
|
|
||||||
from env import DATABASE_URL
|
from env import DATABASE_URL
|
||||||
|
from sqlalchemy import engine_from_config, pool
|
||||||
|
|
||||||
# this is the Alembic Config object, which provides
|
# this is the Alembic Config object, which provides
|
||||||
# access to the values within the .ini file in use.
|
# access to the values within the .ini file in use.
|
||||||
|
@ -1,16 +1,16 @@
|
|||||||
"""init
|
"""init
|
||||||
|
|
||||||
Revision ID: 7e5b5dc7342b
|
Revision ID: 7e5b5dc7342b
|
||||||
Revises:
|
Revises:
|
||||||
Create Date: 2024-06-24 13:15:33.808998
|
Create Date: 2024-06-24 13:15:33.808998
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
import apps.webui.internal.db
|
import apps.webui.internal.db
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
from migrations.util import get_existing_tables
|
from migrations.util import get_existing_tables
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
|
@ -8,10 +8,8 @@ Create Date: 2024-08-25 15:26:35.241684
|
|||||||
|
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
import apps.webui.internal.db
|
from alembic import op
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "ca81bd47c050"
|
revision: str = "ca81bd47c050"
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
import pytest
|
|
||||||
|
|
||||||
from test.util.abstract_integration_test import AbstractPostgresTest
|
from test.util.abstract_integration_test import AbstractPostgresTest
|
||||||
from test.util.mock_user import mock_webui_user
|
from test.util.mock_user import mock_webui_user
|
||||||
|
|
||||||
@ -9,8 +7,8 @@ class TestAuths(AbstractPostgresTest):
|
|||||||
|
|
||||||
def setup_class(cls):
|
def setup_class(cls):
|
||||||
super().setup_class()
|
super().setup_class()
|
||||||
from apps.webui.models.users import Users
|
|
||||||
from apps.webui.models.auths import Auths
|
from apps.webui.models.auths import Auths
|
||||||
|
from apps.webui.models.users import Users
|
||||||
|
|
||||||
cls.users = Users
|
cls.users = Users
|
||||||
cls.auths = Auths
|
cls.auths = Auths
|
||||||
|
@ -5,7 +5,6 @@ from test.util.mock_user import mock_webui_user
|
|||||||
|
|
||||||
|
|
||||||
class TestChats(AbstractPostgresTest):
|
class TestChats(AbstractPostgresTest):
|
||||||
|
|
||||||
BASE_PATH = "/api/v1/chats"
|
BASE_PATH = "/api/v1/chats"
|
||||||
|
|
||||||
def setup_class(cls):
|
def setup_class(cls):
|
||||||
@ -13,8 +12,7 @@ class TestChats(AbstractPostgresTest):
|
|||||||
|
|
||||||
def setup_method(self):
|
def setup_method(self):
|
||||||
super().setup_method()
|
super().setup_method()
|
||||||
from apps.webui.models.chats import ChatForm
|
from apps.webui.models.chats import ChatForm, Chats
|
||||||
from apps.webui.models.chats import Chats
|
|
||||||
|
|
||||||
self.chats = Chats
|
self.chats = Chats
|
||||||
self.chats.insert_new_chat(
|
self.chats.insert_new_chat(
|
||||||
|
@ -3,7 +3,6 @@ from test.util.mock_user import mock_webui_user
|
|||||||
|
|
||||||
|
|
||||||
class TestDocuments(AbstractPostgresTest):
|
class TestDocuments(AbstractPostgresTest):
|
||||||
|
|
||||||
BASE_PATH = "/api/v1/documents"
|
BASE_PATH = "/api/v1/documents"
|
||||||
|
|
||||||
def setup_class(cls):
|
def setup_class(cls):
|
||||||
|
@ -3,7 +3,6 @@ from test.util.mock_user import mock_webui_user
|
|||||||
|
|
||||||
|
|
||||||
class TestModels(AbstractPostgresTest):
|
class TestModels(AbstractPostgresTest):
|
||||||
|
|
||||||
BASE_PATH = "/api/v1/models"
|
BASE_PATH = "/api/v1/models"
|
||||||
|
|
||||||
def setup_class(cls):
|
def setup_class(cls):
|
||||||
|
@ -3,7 +3,6 @@ from test.util.mock_user import mock_webui_user
|
|||||||
|
|
||||||
|
|
||||||
class TestPrompts(AbstractPostgresTest):
|
class TestPrompts(AbstractPostgresTest):
|
||||||
|
|
||||||
BASE_PATH = "/api/v1/prompts"
|
BASE_PATH = "/api/v1/prompts"
|
||||||
|
|
||||||
def test_prompts(self):
|
def test_prompts(self):
|
||||||
|
@ -21,7 +21,6 @@ def _assert_user(data, id, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
class TestUsers(AbstractPostgresTest):
|
class TestUsers(AbstractPostgresTest):
|
||||||
|
|
||||||
BASE_PATH = "/api/v1/users"
|
BASE_PATH = "/api/v1/users"
|
||||||
|
|
||||||
def setup_class(cls):
|
def setup_class(cls):
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
from pathlib import Path
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import re
|
import re
|
||||||
from datetime import timedelta
|
|
||||||
from typing import Optional, Callable
|
|
||||||
import uuid
|
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
|
from datetime import timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
from utils.task import prompt_template
|
from utils.task import prompt_template
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from ast import literal_eval
|
from ast import literal_eval
|
||||||
|
from typing import Any, Literal, Optional, Type
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, create_model
|
from pydantic import BaseModel, Field, create_model
|
||||||
from typing import Any, Optional, Type, Literal
|
|
||||||
|
|
||||||
|
|
||||||
def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]:
|
def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]:
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import re
|
|
||||||
import math
|
import math
|
||||||
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
@ -5,7 +5,6 @@ from typing import Awaitable, Callable, get_type_hints
|
|||||||
from apps.webui.models.tools import Tools
|
from apps.webui.models.tools import Tools
|
||||||
from apps.webui.models.users import UserModel
|
from apps.webui.models.users import UserModel
|
||||||
from apps.webui.utils import load_toolkit_module_by_id
|
from apps.webui.utils import load_toolkit_module_by_id
|
||||||
|
|
||||||
from utils.schemas import json_schema_to_model
|
from utils.schemas import json_schema_to_model
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
@ -1,16 +1,15 @@
|
|||||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
||||||
from fastapi import HTTPException, status, Depends, Request
|
|
||||||
|
|
||||||
from apps.webui.models.users import Users
|
|
||||||
|
|
||||||
from typing import Union, Optional
|
|
||||||
from constants import ERROR_MESSAGES
|
|
||||||
from passlib.context import CryptContext
|
|
||||||
from datetime import datetime, timedelta, UTC
|
|
||||||
import jwt
|
|
||||||
import uuid
|
|
||||||
import logging
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
from apps.webui.models.users import Users
|
||||||
|
from constants import ERROR_MESSAGES
|
||||||
from env import WEBUI_SECRET_KEY
|
from env import WEBUI_SECRET_KEY
|
||||||
|
from fastapi import Depends, HTTPException, Request, status
|
||||||
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
|
from passlib.context import CryptContext
|
||||||
|
|
||||||
logging.getLogger("passlib").setLevel(logging.ERROR)
|
logging.getLogger("passlib").setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import json
|
import json
|
||||||
import requests
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from config import SRC_LOG_LEVELS, VERSION, WEBUI_FAVICON_URL, WEBUI_NAME
|
import requests
|
||||||
|
from config import WEBUI_FAVICON_URL, WEBUI_NAME
|
||||||
|
from env import SRC_LOG_LEVELS, VERSION
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["WEBHOOK"])
|
log.setLevel(SRC_LOG_LEVELS["WEBHOOK"])
|
||||||
|
Loading…
Reference in New Issue
Block a user