diff --git a/backend/constants.py b/backend/constants.py index f1eed43d3..7c366c222 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -89,3 +89,14 @@ class ERROR_MESSAGES(str, Enum): OLLAMA_API_DISABLED = ( "The Ollama API is disabled. Please enable it to use this feature." ) + + +class TASKS(str, Enum): + def __str__(self) -> str: + return super().__str__() + + DEFAULT = lambda task="": f"{task if task else 'default'}" + TITLE_GENERATION = "Title Generation" + EMOJI_GENERATION = "Emoji Generation" + QUERY_GENERATION = "Query Generation" + FUNCTION_CALLING = "Function Calling" diff --git a/backend/main.py b/backend/main.py index 0e3986f21..49e068a75 100644 --- a/backend/main.py +++ b/backend/main.py @@ -126,7 +126,7 @@ from config import ( WEBUI_SESSION_COOKIE_SECURE, AppConfig, ) -from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES +from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS from utils.webhook import post_webhook if SAFE_MODE: @@ -311,6 +311,7 @@ async def get_function_call_response( {"role": "user", "content": f"Query: {prompt}"}, ], "stream": False, + "task": TASKS.FUNCTION_CALLING, } try: @@ -323,7 +324,6 @@ async def get_function_call_response( response = None try: response = await generate_chat_completions(form_data=payload, user=user) - content = None if hasattr(response, "body_iterator"): @@ -833,9 +833,6 @@ def filter_pipeline(payload, user): pass if "pipeline" not in app.state.MODELS[model_id]: - if "title" in payload: - del payload["title"] - if "task" in payload: del payload["task"] @@ -1338,7 +1335,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): "stream": False, "max_tokens": 50, "chat_id": form_data.get("chat_id", None), - "title": True, + "task": TASKS.TITLE_GENERATION, } log.debug(payload) @@ -1401,7 +1398,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) "messages": [{"role": "user", "content": content}], "stream": False, "max_tokens": 30, - "task": True, + "task": TASKS.QUERY_GENERATION, } print(payload) @@ -1468,7 +1465,7 @@ Message: """{{prompt}}""" "stream": False, "max_tokens": 4, "chat_id": form_data.get("chat_id", None), - "task": True, + "task": TASKS.EMOJI_GENERATION, } log.debug(payload)