mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
feat: create model
This commit is contained in:
@@ -39,6 +39,8 @@ from utils.utils import (
|
||||
get_admin_user,
|
||||
)
|
||||
|
||||
from utils.models import get_model_id_from_custom_model_id
|
||||
|
||||
|
||||
from config import (
|
||||
SRC_LOG_LEVELS,
|
||||
@@ -873,10 +875,10 @@ async def generate_chat_completion(
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
model_id = get_model_id_from_custom_model_id(form_data.model)
|
||||
model = model_id
|
||||
|
||||
if url_idx == None:
|
||||
model = form_data.model
|
||||
|
||||
if ":" not in model:
|
||||
model = f"{model}:latest"
|
||||
|
||||
@@ -893,6 +895,13 @@ async def generate_chat_completion(
|
||||
|
||||
r = None
|
||||
|
||||
# payload = {
|
||||
# **form_data.model_dump_json(exclude_none=True).encode(),
|
||||
# "model": model,
|
||||
# "messages": form_data.messages,
|
||||
|
||||
# }
|
||||
|
||||
log.debug(
|
||||
"form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
|
||||
form_data.model_dump_json(exclude_none=True).encode()
|
||||
|
||||
@@ -166,7 +166,9 @@ class ModelsTable:
|
||||
|
||||
model = Model.get(Model.id == id)
|
||||
return ModelModel(**model_to_dict(model))
|
||||
except:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
return None
|
||||
|
||||
def delete_model_by_id(self, id: str) -> bool:
|
||||
|
||||
@@ -28,16 +28,24 @@ async def get_models(user=Depends(get_verified_user)):
|
||||
|
||||
|
||||
@router.post("/add", response_model=Optional[ModelModel])
|
||||
async def add_new_model(form_data: ModelForm, user=Depends(get_admin_user)):
|
||||
model = Models.insert_new_model(form_data, user.id)
|
||||
|
||||
if model:
|
||||
return model
|
||||
else:
|
||||
async def add_new_model(
|
||||
request: Request, form_data: ModelForm, user=Depends(get_admin_user)
|
||||
):
|
||||
if form_data.id in request.app.state.MODELS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.DEFAULT(),
|
||||
detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
|
||||
)
|
||||
else:
|
||||
model = Models.insert_new_model(form_data, user.id)
|
||||
|
||||
if model:
|
||||
return model
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.DEFAULT(),
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
|
||||
@@ -32,6 +32,8 @@ class ERROR_MESSAGES(str, Enum):
|
||||
COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
|
||||
FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file."
|
||||
|
||||
MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string."
|
||||
|
||||
NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string."
|
||||
INVALID_TOKEN = (
|
||||
"Your session has expired or the token is invalid. Please sign in again."
|
||||
|
||||
10
backend/utils/models.py
Normal file
10
backend/utils/models.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from apps.web.models.models import Models, ModelModel, ModelForm, ModelResponse
|
||||
|
||||
|
||||
def get_model_id_from_custom_model_id(id: str):
|
||||
model = Models.get_model_by_id(id)
|
||||
|
||||
if model:
|
||||
return model.id
|
||||
else:
|
||||
return id
|
||||
Reference in New Issue
Block a user