This commit is contained in:
Timothy J. Baek 2024-06-27 11:29:59 -07:00
parent 3c7f45ced4
commit 3f5f410453
7 changed files with 47 additions and 47 deletions

View File

@ -16,7 +16,7 @@ from faster_whisper import WhisperModel
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from utils.utils import ( from utils.utils import (
get_current_user, get_verified_user,
get_admin_user, get_admin_user,
) )
@ -258,7 +258,7 @@ async def update_image_size(
@app.get("/models") @app.get("/models")
def get_models(user=Depends(get_current_user)): def get_models(user=Depends(get_verified_user)):
try: try:
if app.state.config.ENGINE == "openai": if app.state.config.ENGINE == "openai":
return [ return [
@ -347,7 +347,7 @@ def set_model_handler(model: str):
@app.post("/models/default/update") @app.post("/models/default/update")
def update_default_model( def update_default_model(
form_data: UpdateModelForm, form_data: UpdateModelForm,
user=Depends(get_current_user), user=Depends(get_verified_user),
): ):
return set_model_handler(form_data.model) return set_model_handler(form_data.model)
@ -424,7 +424,7 @@ def save_url_image(url):
@app.post("/generations") @app.post("/generations")
def generate_image( def generate_image(
form_data: GenerateImageForm, form_data: GenerateImageForm,
user=Depends(get_current_user), user=Depends(get_verified_user),
): ):
width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x"))) width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))

View File

@ -16,7 +16,7 @@ from apps.webui.models.users import Users
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from utils.utils import ( from utils.utils import (
decode_token, decode_token,
get_current_user, get_verified_user,
get_verified_user, get_verified_user,
get_admin_user, get_admin_user,
) )
@ -296,7 +296,7 @@ async def get_all_models(raw: bool = False):
@app.get("/models") @app.get("/models")
@app.get("/models/{url_idx}") @app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)): async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
if url_idx == None: if url_idx == None:
models = await get_all_models() models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER: if app.state.config.ENABLE_MODEL_FILTER:

View File

@ -85,7 +85,7 @@ from utils.misc import (
sanitize_filename, sanitize_filename,
extract_folders_after_data_docs, extract_folders_after_data_docs,
) )
from utils.utils import get_current_user, get_admin_user from utils.utils import get_verified_user, get_admin_user
from config import ( from config import (
AppConfig, AppConfig,
@ -529,7 +529,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
@app.get("/template") @app.get("/template")
async def get_rag_template(user=Depends(get_current_user)): async def get_rag_template(user=Depends(get_verified_user)):
return { return {
"status": True, "status": True,
"template": app.state.config.RAG_TEMPLATE, "template": app.state.config.RAG_TEMPLATE,
@ -586,7 +586,7 @@ class QueryDocForm(BaseModel):
@app.post("/query/doc") @app.post("/query/doc")
def query_doc_handler( def query_doc_handler(
form_data: QueryDocForm, form_data: QueryDocForm,
user=Depends(get_current_user), user=Depends(get_verified_user),
): ):
try: try:
if app.state.config.ENABLE_RAG_HYBRID_SEARCH: if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
@ -626,7 +626,7 @@ class QueryCollectionsForm(BaseModel):
@app.post("/query/collection") @app.post("/query/collection")
def query_collection_handler( def query_collection_handler(
form_data: QueryCollectionsForm, form_data: QueryCollectionsForm,
user=Depends(get_current_user), user=Depends(get_verified_user),
): ):
try: try:
if app.state.config.ENABLE_RAG_HYBRID_SEARCH: if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
@ -657,7 +657,7 @@ def query_collection_handler(
@app.post("/youtube") @app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)): def store_youtube_video(form_data: UrlForm, user=Depends(get_verified_user)):
try: try:
loader = YoutubeLoader.from_youtube_url( loader = YoutubeLoader.from_youtube_url(
form_data.url, form_data.url,
@ -686,7 +686,7 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
@app.post("/web") @app.post("/web")
def store_web(form_data: UrlForm, user=Depends(get_current_user)): def store_web(form_data: UrlForm, user=Depends(get_verified_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try: try:
loader = get_web_loader( loader = get_web_loader(
@ -864,7 +864,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
@app.post("/web/search") @app.post("/web/search")
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)): def store_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
try: try:
logging.info( logging.info(
f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}" f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
@ -1084,7 +1084,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
def store_doc( def store_doc(
collection_name: Optional[str] = Form(None), collection_name: Optional[str] = Form(None),
file: UploadFile = File(...), file: UploadFile = File(...),
user=Depends(get_current_user), user=Depends(get_verified_user),
): ):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
@ -1145,7 +1145,7 @@ class ProcessDocForm(BaseModel):
@app.post("/process/doc") @app.post("/process/doc")
def process_doc( def process_doc(
form_data: ProcessDocForm, form_data: ProcessDocForm,
user=Depends(get_current_user), user=Depends(get_verified_user),
): ):
try: try:
file = Files.get_file_by_id(form_data.file_id) file = Files.get_file_by_id(form_data.file_id)
@ -1200,7 +1200,7 @@ class TextRAGForm(BaseModel):
@app.post("/text") @app.post("/text")
def store_text( def store_text(
form_data: TextRAGForm, form_data: TextRAGForm,
user=Depends(get_current_user), user=Depends(get_verified_user),
): ):
collection_name = form_data.collection_name collection_name = form_data.collection_name

View File

@ -1,7 +1,7 @@
from fastapi import Depends, Request, HTTPException, status from fastapi import Depends, Request, HTTPException, status
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Union, Optional from typing import List, Union, Optional
from utils.utils import get_current_user, get_admin_user from utils.utils import get_verified_user, get_admin_user
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
@ -43,7 +43,7 @@ router = APIRouter()
@router.get("/", response_model=List[ChatTitleIdResponse]) @router.get("/", response_model=List[ChatTitleIdResponse])
@router.get("/list", response_model=List[ChatTitleIdResponse]) @router.get("/list", response_model=List[ChatTitleIdResponse])
async def get_session_user_chat_list( async def get_session_user_chat_list(
user=Depends(get_current_user), skip: int = 0, limit: int = 50 user=Depends(get_verified_user), skip: int = 0, limit: int = 50
): ):
return Chats.get_chat_list_by_user_id(user.id, skip, limit) return Chats.get_chat_list_by_user_id(user.id, skip, limit)
@ -54,7 +54,7 @@ async def get_session_user_chat_list(
@router.delete("/", response_model=bool) @router.delete("/", response_model=bool)
async def delete_all_user_chats(request: Request, user=Depends(get_current_user)): async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
if ( if (
user.role == "user" user.role == "user"
@ -89,7 +89,7 @@ async def get_user_chat_list_by_user_id(
@router.post("/new", response_model=Optional[ChatResponse]) @router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
try: try:
chat = Chats.insert_new_chat(user.id, form_data) chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
@ -106,7 +106,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
@router.get("/all", response_model=List[ChatResponse]) @router.get("/all", response_model=List[ChatResponse])
async def get_user_chats(user=Depends(get_current_user)): async def get_user_chats(user=Depends(get_verified_user)):
return [ return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_chats_by_user_id(user.id) for chat in Chats.get_chats_by_user_id(user.id)
@ -119,7 +119,7 @@ async def get_user_chats(user=Depends(get_current_user)):
@router.get("/all/archived", response_model=List[ChatResponse]) @router.get("/all/archived", response_model=List[ChatResponse])
async def get_user_chats(user=Depends(get_current_user)): async def get_user_chats(user=Depends(get_verified_user)):
return [ return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_archived_chats_by_user_id(user.id) for chat in Chats.get_archived_chats_by_user_id(user.id)
@ -151,7 +151,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
@router.get("/archived", response_model=List[ChatTitleIdResponse]) @router.get("/archived", response_model=List[ChatTitleIdResponse])
async def get_archived_session_user_chat_list( async def get_archived_session_user_chat_list(
user=Depends(get_current_user), skip: int = 0, limit: int = 50 user=Depends(get_verified_user), skip: int = 0, limit: int = 50
): ):
return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit) return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
@ -162,7 +162,7 @@ async def get_archived_session_user_chat_list(
@router.post("/archive/all", response_model=bool) @router.post("/archive/all", response_model=bool)
async def archive_all_chats(user=Depends(get_current_user)): async def archive_all_chats(user=Depends(get_verified_user)):
return Chats.archive_all_chats_by_user_id(user.id) return Chats.archive_all_chats_by_user_id(user.id)
@ -172,7 +172,7 @@ async def archive_all_chats(user=Depends(get_current_user)):
@router.get("/share/{share_id}", response_model=Optional[ChatResponse]) @router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)): async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
if user.role == "pending": if user.role == "pending":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@ -204,7 +204,7 @@ class TagNameForm(BaseModel):
@router.post("/tags", response_model=List[ChatTitleIdResponse]) @router.post("/tags", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name( async def get_user_chat_list_by_tag_name(
form_data: TagNameForm, user=Depends(get_current_user) form_data: TagNameForm, user=Depends(get_verified_user)
): ):
print(form_data) print(form_data)
@ -229,7 +229,7 @@ async def get_user_chat_list_by_tag_name(
@router.get("/tags/all", response_model=List[TagModel]) @router.get("/tags/all", response_model=List[TagModel])
async def get_all_tags(user=Depends(get_current_user)): async def get_all_tags(user=Depends(get_verified_user)):
try: try:
tags = Tags.get_tags_by_user_id(user.id) tags = Tags.get_tags_by_user_id(user.id)
return tags return tags
@ -246,7 +246,7 @@ async def get_all_tags(user=Depends(get_current_user)):
@router.get("/{id}", response_model=Optional[ChatResponse]) @router.get("/{id}", response_model=Optional[ChatResponse])
async def get_chat_by_id(id: str, user=Depends(get_current_user)): async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
@ -264,7 +264,7 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)):
@router.post("/{id}", response_model=Optional[ChatResponse]) @router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id( async def update_chat_by_id(
id: str, form_data: ChatForm, user=Depends(get_current_user) id: str, form_data: ChatForm, user=Depends(get_verified_user)
): ):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
@ -285,7 +285,7 @@ async def update_chat_by_id(
@router.delete("/{id}", response_model=bool) @router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)): async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
if user.role == "admin": if user.role == "admin":
result = Chats.delete_chat_by_id(id) result = Chats.delete_chat_by_id(id)
@ -307,7 +307,7 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
@router.get("/{id}/clone", response_model=Optional[ChatResponse]) @router.get("/{id}/clone", response_model=Optional[ChatResponse])
async def clone_chat_by_id(id: str, user=Depends(get_current_user)): async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
@ -333,7 +333,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
@router.get("/{id}/archive", response_model=Optional[ChatResponse]) @router.get("/{id}/archive", response_model=Optional[ChatResponse])
async def archive_chat_by_id(id: str, user=Depends(get_current_user)): async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
chat = Chats.toggle_chat_archive_by_id(id) chat = Chats.toggle_chat_archive_by_id(id)
@ -350,7 +350,7 @@ async def archive_chat_by_id(id: str, user=Depends(get_current_user)):
@router.post("/{id}/share", response_model=Optional[ChatResponse]) @router.post("/{id}/share", response_model=Optional[ChatResponse])
async def share_chat_by_id(id: str, user=Depends(get_current_user)): async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
if chat.share_id: if chat.share_id:
@ -382,7 +382,7 @@ async def share_chat_by_id(id: str, user=Depends(get_current_user)):
@router.delete("/{id}/share", response_model=Optional[bool]) @router.delete("/{id}/share", response_model=Optional[bool])
async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)): async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
if not chat.share_id: if not chat.share_id:
@ -405,7 +405,7 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
@router.get("/{id}/tags", response_model=List[TagModel]) @router.get("/{id}/tags", response_model=List[TagModel])
async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)): async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
if tags != None: if tags != None:
@ -423,7 +423,7 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel]) @router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
async def add_chat_tag_by_id( async def add_chat_tag_by_id(
id: str, form_data: ChatIdTagForm, user=Depends(get_current_user) id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
): ):
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
@ -450,7 +450,7 @@ async def add_chat_tag_by_id(
@router.delete("/{id}/tags", response_model=Optional[bool]) @router.delete("/{id}/tags", response_model=Optional[bool])
async def delete_chat_tag_by_id( async def delete_chat_tag_by_id(
id: str, form_data: ChatIdTagForm, user=Depends(get_current_user) id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
): ):
result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id( result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
form_data.tag_name, id, user.id form_data.tag_name, id, user.id
@ -470,7 +470,7 @@ async def delete_chat_tag_by_id(
@router.delete("/{id}/tags/all", response_model=Optional[bool]) @router.delete("/{id}/tags/all", response_model=Optional[bool])
async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)): async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id) result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
if result: if result:

View File

@ -14,7 +14,7 @@ from apps.webui.models.users import Users
from utils.utils import ( from utils.utils import (
get_password_hash, get_password_hash,
get_current_user, get_verified_user,
get_admin_user, get_admin_user,
create_token, create_token,
) )
@ -84,6 +84,6 @@ async def set_banners(
@router.get("/banners", response_model=List[BannerModel]) @router.get("/banners", response_model=List[BannerModel])
async def get_banners( async def get_banners(
request: Request, request: Request,
user=Depends(get_current_user), user=Depends(get_verified_user),
): ):
return request.app.state.config.BANNERS return request.app.state.config.BANNERS

View File

@ -14,7 +14,7 @@ from apps.webui.models.documents import (
DocumentResponse, DocumentResponse,
) )
from utils.utils import get_current_user, get_admin_user from utils.utils import get_verified_user, get_admin_user
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
@ -25,7 +25,7 @@ router = APIRouter()
@router.get("/", response_model=List[DocumentResponse]) @router.get("/", response_model=List[DocumentResponse])
async def get_documents(user=Depends(get_current_user)): async def get_documents(user=Depends(get_verified_user)):
docs = [ docs = [
DocumentResponse( DocumentResponse(
**{ **{
@ -74,7 +74,7 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
@router.get("/doc", response_model=Optional[DocumentResponse]) @router.get("/doc", response_model=Optional[DocumentResponse])
async def get_doc_by_name(name: str, user=Depends(get_current_user)): async def get_doc_by_name(name: str, user=Depends(get_verified_user)):
doc = Documents.get_doc_by_name(name) doc = Documents.get_doc_by_name(name)
if doc: if doc:
@ -106,7 +106,7 @@ class TagDocumentForm(BaseModel):
@router.post("/doc/tags", response_model=Optional[DocumentResponse]) @router.post("/doc/tags", response_model=Optional[DocumentResponse])
async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)): async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_verified_user)):
doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags}) doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
if doc: if doc:

View File

@ -8,7 +8,7 @@ import json
from apps.webui.models.prompts import Prompts, PromptForm, PromptModel from apps.webui.models.prompts import Prompts, PromptForm, PromptModel
from utils.utils import get_current_user, get_admin_user from utils.utils import get_verified_user, get_admin_user
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
@ -19,7 +19,7 @@ router = APIRouter()
@router.get("/", response_model=List[PromptModel]) @router.get("/", response_model=List[PromptModel])
async def get_prompts(user=Depends(get_current_user)): async def get_prompts(user=Depends(get_verified_user)):
return Prompts.get_prompts() return Prompts.get_prompts()
@ -52,7 +52,7 @@ async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user))
@router.get("/command/{command}", response_model=Optional[PromptModel]) @router.get("/command/{command}", response_model=Optional[PromptModel])
async def get_prompt_by_command(command: str, user=Depends(get_current_user)): async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
prompt = Prompts.get_prompt_by_command(f"/{command}") prompt = Prompts.get_prompt_by_command(f"/{command}")
if prompt: if prompt: