feat: model update

This commit is contained in:
Timothy J. Baek
2024-05-24 18:26:36 -07:00
parent 0a48114bd2
commit 708d755eda
8 changed files with 396 additions and 407 deletions

View File

@@ -40,6 +40,9 @@ app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.MODELS = {}
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER

View File

@@ -33,6 +33,8 @@ class ModelParams(BaseModel):
# ModelMeta is a model for the data stored in the meta field of the Model table
# It isn't currently used in the backend, but it's here as a reference
class ModelMeta(BaseModel):
profile_image_url: Optional[str] = "/favicon.png"
description: Optional[str] = None
"""
User-facing description of the model.
@@ -84,6 +86,7 @@ class Model(pw.Model):
class ModelModel(BaseModel):
id: str
user_id: str
base_model_id: Optional[str] = None
name: str
@@ -123,18 +126,26 @@ class ModelsTable:
self.db = db
self.db.create_tables([Model])
def insert_new_model(self, model: ModelForm, user_id: str) -> Optional[ModelModel]:
def insert_new_model(
self, form_data: ModelForm, user_id: str
) -> Optional[ModelModel]:
model = ModelModel(
**{
**form_data.model_dump(),
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
try:
model = Model.create(
**{
**model.model_dump(),
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
return ModelModel(**model_to_dict(model))
except:
result = Model.create(**model.model_dump())
if result:
return model
else:
return None
except Exception as e:
print(e)
return None
def get_all_models(self) -> List[ModelModel]:

View File

@@ -1,4 +1,4 @@
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi import Depends, FastAPI, HTTPException, status, Request
from datetime import datetime, timedelta
from typing import List, Union, Optional
@@ -65,17 +65,28 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/update", response_model=Optional[ModelModel])
async def update_model_by_id(
id: str, form_data: ModelForm, user=Depends(get_admin_user)
request: Request, id: str, form_data: ModelForm, user=Depends(get_admin_user)
):
model = Models.get_model_by_id(id)
if model:
model = Models.update_model_by_id(id, form_data)
return model
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if form_data.id in request.app.state.MODELS:
model = Models.insert_new_model(form_data, user.id)
print(model)
if model:
return model
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
############################

View File

@@ -122,6 +122,9 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.MODELS = {}
origins = ["*"]
@@ -238,6 +241,11 @@ app.add_middleware(
@app.middleware("http")
async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0:
await get_all_models()
else:
pass
start_time = int(time.time())
response = await call_next(request)
process_time = int(time.time()) - start_time
@@ -269,8 +277,7 @@ app.mount("/api/v1", webui_app)
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
@app.get("/api/models")
async def get_models(user=Depends(get_verified_user)):
async def get_all_models():
openai_models = []
ollama_models = []
@@ -282,8 +289,6 @@ async def get_models(user=Depends(get_verified_user)):
if app.state.config.ENABLE_OLLAMA_API:
ollama_models = await get_ollama_models()
print(ollama_models)
ollama_models = [
{
"id": model["model"],
@@ -296,9 +301,6 @@ async def get_models(user=Depends(get_verified_user)):
for model in ollama_models["models"]
]
print("openai", openai_models)
print("ollama", ollama_models)
models = openai_models + ollama_models
custom_models = Models.get_all_models()
@@ -330,6 +332,16 @@ async def get_models(user=Depends(get_verified_user)):
}
)
app.state.MODELS = {model["id"]: model for model in models}
webui_app.state.MODELS = app.state.MODELS
return models
@app.get("/api/models")
async def get_models(user=Depends(get_verified_user)):
models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user":
models = list(