refac: task flag

Co-Authored-By: Michael Poluektov <78477503+michaelpoluektov@users.noreply.github.com>
This commit is contained in:
Timothy J. Baek 2024-07-03 15:46:56 -07:00
parent d0e0aba593
commit c83704d6ca
2 changed files with 16 additions and 8 deletions

View File

@ -89,3 +89,14 @@ class ERROR_MESSAGES(str, Enum):
OLLAMA_API_DISABLED = ( OLLAMA_API_DISABLED = (
"The Ollama API is disabled. Please enable it to use this feature." "The Ollama API is disabled. Please enable it to use this feature."
) )
class TASKS(str, Enum):
def __str__(self) -> str:
return super().__str__()
DEFAULT = lambda task="": f"{task if task else 'default'}"
TITLE_GENERATION = "Title Generation"
EMOJI_GENERATION = "Emoji Generation"
QUERY_GENERATION = "Query Generation"
FUNCTION_CALLING = "Function Calling"

View File

@ -126,7 +126,7 @@ from config import (
WEBUI_SESSION_COOKIE_SECURE, WEBUI_SESSION_COOKIE_SECURE,
AppConfig, AppConfig,
) )
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
from utils.webhook import post_webhook from utils.webhook import post_webhook
if SAFE_MODE: if SAFE_MODE:
@ -311,6 +311,7 @@ async def get_function_call_response(
{"role": "user", "content": f"Query: {prompt}"}, {"role": "user", "content": f"Query: {prompt}"},
], ],
"stream": False, "stream": False,
"task": TASKS.FUNCTION_CALLING,
} }
try: try:
@ -323,7 +324,6 @@ async def get_function_call_response(
response = None response = None
try: try:
response = await generate_chat_completions(form_data=payload, user=user) response = await generate_chat_completions(form_data=payload, user=user)
content = None content = None
if hasattr(response, "body_iterator"): if hasattr(response, "body_iterator"):
@ -833,9 +833,6 @@ def filter_pipeline(payload, user):
pass pass
if "pipeline" not in app.state.MODELS[model_id]: if "pipeline" not in app.state.MODELS[model_id]:
if "title" in payload:
del payload["title"]
if "task" in payload: if "task" in payload:
del payload["task"] del payload["task"]
@ -1338,7 +1335,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
"stream": False, "stream": False,
"max_tokens": 50, "max_tokens": 50,
"chat_id": form_data.get("chat_id", None), "chat_id": form_data.get("chat_id", None),
"title": True, "task": TASKS.TITLE_GENERATION,
} }
log.debug(payload) log.debug(payload)
@ -1401,7 +1398,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
"messages": [{"role": "user", "content": content}], "messages": [{"role": "user", "content": content}],
"stream": False, "stream": False,
"max_tokens": 30, "max_tokens": 30,
"task": True, "task": TASKS.QUERY_GENERATION,
} }
print(payload) print(payload)
@ -1468,7 +1465,7 @@ Message: """{{prompt}}"""
"stream": False, "stream": False,
"max_tokens": 4, "max_tokens": 4,
"chat_id": form_data.get("chat_id", None), "chat_id": form_data.get("chat_id", None),
"task": True, "task": TASKS.EMOJI_GENERATION,
} }
log.debug(payload) log.debug(payload)