mirror of
https://github.com/open-webui/open-webui
synced 2024-11-16 13:40:55 +00:00
refac: task flag
Co-Authored-By: Michael Poluektov <78477503+michaelpoluektov@users.noreply.github.com>
This commit is contained in:
parent
d0e0aba593
commit
c83704d6ca
@ -89,3 +89,14 @@ class ERROR_MESSAGES(str, Enum):
|
|||||||
OLLAMA_API_DISABLED = (
|
OLLAMA_API_DISABLED = (
|
||||||
"The Ollama API is disabled. Please enable it to use this feature."
|
"The Ollama API is disabled. Please enable it to use this feature."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TASKS(str, Enum):
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return super().__str__()
|
||||||
|
|
||||||
|
DEFAULT = lambda task="": f"{task if task else 'default'}"
|
||||||
|
TITLE_GENERATION = "Title Generation"
|
||||||
|
EMOJI_GENERATION = "Emoji Generation"
|
||||||
|
QUERY_GENERATION = "Query Generation"
|
||||||
|
FUNCTION_CALLING = "Function Calling"
|
||||||
|
@ -126,7 +126,7 @@ from config import (
|
|||||||
WEBUI_SESSION_COOKIE_SECURE,
|
WEBUI_SESSION_COOKIE_SECURE,
|
||||||
AppConfig,
|
AppConfig,
|
||||||
)
|
)
|
||||||
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
|
||||||
from utils.webhook import post_webhook
|
from utils.webhook import post_webhook
|
||||||
|
|
||||||
if SAFE_MODE:
|
if SAFE_MODE:
|
||||||
@ -311,6 +311,7 @@ async def get_function_call_response(
|
|||||||
{"role": "user", "content": f"Query: {prompt}"},
|
{"role": "user", "content": f"Query: {prompt}"},
|
||||||
],
|
],
|
||||||
"stream": False,
|
"stream": False,
|
||||||
|
"task": TASKS.FUNCTION_CALLING,
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -323,7 +324,6 @@ async def get_function_call_response(
|
|||||||
response = None
|
response = None
|
||||||
try:
|
try:
|
||||||
response = await generate_chat_completions(form_data=payload, user=user)
|
response = await generate_chat_completions(form_data=payload, user=user)
|
||||||
|
|
||||||
content = None
|
content = None
|
||||||
|
|
||||||
if hasattr(response, "body_iterator"):
|
if hasattr(response, "body_iterator"):
|
||||||
@ -833,9 +833,6 @@ def filter_pipeline(payload, user):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
if "pipeline" not in app.state.MODELS[model_id]:
|
if "pipeline" not in app.state.MODELS[model_id]:
|
||||||
if "title" in payload:
|
|
||||||
del payload["title"]
|
|
||||||
|
|
||||||
if "task" in payload:
|
if "task" in payload:
|
||||||
del payload["task"]
|
del payload["task"]
|
||||||
|
|
||||||
@ -1338,7 +1335,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
|||||||
"stream": False,
|
"stream": False,
|
||||||
"max_tokens": 50,
|
"max_tokens": 50,
|
||||||
"chat_id": form_data.get("chat_id", None),
|
"chat_id": form_data.get("chat_id", None),
|
||||||
"title": True,
|
"task": TASKS.TITLE_GENERATION,
|
||||||
}
|
}
|
||||||
|
|
||||||
log.debug(payload)
|
log.debug(payload)
|
||||||
@ -1401,7 +1398,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
|
|||||||
"messages": [{"role": "user", "content": content}],
|
"messages": [{"role": "user", "content": content}],
|
||||||
"stream": False,
|
"stream": False,
|
||||||
"max_tokens": 30,
|
"max_tokens": 30,
|
||||||
"task": True,
|
"task": TASKS.QUERY_GENERATION,
|
||||||
}
|
}
|
||||||
|
|
||||||
print(payload)
|
print(payload)
|
||||||
@ -1468,7 +1465,7 @@ Message: """{{prompt}}"""
|
|||||||
"stream": False,
|
"stream": False,
|
||||||
"max_tokens": 4,
|
"max_tokens": 4,
|
||||||
"chat_id": form_data.get("chat_id", None),
|
"chat_id": form_data.get("chat_id", None),
|
||||||
"task": True,
|
"task": TASKS.EMOJI_GENERATION,
|
||||||
}
|
}
|
||||||
|
|
||||||
log.debug(payload)
|
log.debug(payload)
|
||||||
|
Loading…
Reference in New Issue
Block a user