From 1d20c27553f019477f01d7233ebe40b11d31e479 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Tue, 9 Jul 2024 16:08:54 +0100 Subject: [PATCH] refac: use get_task_model_id() --- backend/main.py | 44 ++++---------------------------------------- 1 file changed, 4 insertions(+), 40 deletions(-) diff --git a/backend/main.py b/backend/main.py index eb1e3ffb8..89252e164 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1293,16 +1293,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): # Check if the user has a custom task model # If the user has a custom task model, use that model - if app.state.MODELS[model_id]["owned_by"] == "ollama": - if app.state.config.TASK_MODEL: - task_model_id = app.state.config.TASK_MODEL - if task_model_id in app.state.MODELS: - model_id = task_model_id - else: - if app.state.config.TASK_MODEL_EXTERNAL: - task_model_id = app.state.config.TASK_MODEL_EXTERNAL - if task_model_id in app.state.MODELS: - model_id = task_model_id + model_id = get_task_model_id(model_id) print(model_id) @@ -1361,16 +1352,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) # Check if the user has a custom task model # If the user has a custom task model, use that model - if app.state.MODELS[model_id]["owned_by"] == "ollama": - if app.state.config.TASK_MODEL: - task_model_id = app.state.config.TASK_MODEL - if task_model_id in app.state.MODELS: - model_id = task_model_id - else: - if app.state.config.TASK_MODEL_EXTERNAL: - task_model_id = app.state.config.TASK_MODEL_EXTERNAL - if task_model_id in app.state.MODELS: - model_id = task_model_id + model_id = get_task_model_id(model_id) print(model_id) @@ -1417,16 +1399,7 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): # Check if the user has a custom task model # If the user has a custom task model, use that model - if app.state.MODELS[model_id]["owned_by"] == "ollama": - if app.state.config.TASK_MODEL: - task_model_id = app.state.config.TASK_MODEL - if task_model_id in app.state.MODELS: - model_id = task_model_id - else: - if app.state.config.TASK_MODEL_EXTERNAL: - task_model_id = app.state.config.TASK_MODEL_EXTERNAL - if task_model_id in app.state.MODELS: - model_id = task_model_id + model_id = get_task_model_id(model_id) print(model_id) @@ -1483,16 +1456,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ # Check if the user has a custom task model # If the user has a custom task model, use that model - if app.state.MODELS[model_id]["owned_by"] == "ollama": - if app.state.config.TASK_MODEL: - task_model_id = app.state.config.TASK_MODEL - if task_model_id in app.state.MODELS: - model_id = task_model_id - else: - if app.state.config.TASK_MODEL_EXTERNAL: - task_model_id = app.state.config.TASK_MODEL_EXTERNAL - if task_model_id in app.state.MODELS: - model_id = task_model_id + model_id = get_task_model_id(model_id) print(model_id) template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE