mirror of
https://github.com/open-webui/open-webui
synced 2024-11-07 00:59:52 +00:00
refac
This commit is contained in:
parent
e71f55e58f
commit
15f3ebba93
100
backend/main.py
100
backend/main.py
@ -218,25 +218,6 @@ origins = ["*"]
|
||||
##################################
|
||||
|
||||
|
||||
async def get_body_and_model_and_user(request):
|
||||
# Read the original request body
|
||||
body = await request.body()
|
||||
body_str = body.decode("utf-8")
|
||||
body = json.loads(body_str) if body_str else {}
|
||||
|
||||
model_id = body["model"]
|
||||
if model_id not in app.state.MODELS:
|
||||
raise Exception("Model not found")
|
||||
model = app.state.MODELS[model_id]
|
||||
|
||||
user = get_current_user(
|
||||
request,
|
||||
get_http_authorization_cred(request.headers.get("Authorization")),
|
||||
)
|
||||
|
||||
return body, model, user
|
||||
|
||||
|
||||
def get_task_model_id(default_model_id):
|
||||
# Set the task model
|
||||
task_model_id = default_model_id
|
||||
@ -283,26 +264,6 @@ def get_filter_function_ids(model):
|
||||
return filter_ids
|
||||
|
||||
|
||||
def get_tools_function_calling_payload(messages, task_model_id, content):
|
||||
user_message = get_last_user_message(messages)
|
||||
history = "\n".join(
|
||||
f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
|
||||
for message in messages[::-1][:4]
|
||||
)
|
||||
|
||||
prompt = f"History:\n{history}\nQuery: {user_message}"
|
||||
|
||||
return {
|
||||
"model": task_model_id,
|
||||
"messages": [
|
||||
{"role": "system", "content": content},
|
||||
{"role": "user", "content": f"Query: {prompt}"},
|
||||
],
|
||||
"stream": False,
|
||||
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
|
||||
}
|
||||
|
||||
|
||||
async def chat_completion_filter_functions_handler(body, model, extra_params):
|
||||
skip_files = None
|
||||
|
||||
@ -369,12 +330,32 @@ async def chat_completion_filter_functions_handler(body, model, extra_params):
|
||||
return body, {}
|
||||
|
||||
|
||||
def get_tools_function_calling_payload(messages, task_model_id, content):
|
||||
user_message = get_last_user_message(messages)
|
||||
history = "\n".join(
|
||||
f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
|
||||
for message in messages[::-1][:4]
|
||||
)
|
||||
|
||||
prompt = f"History:\n{history}\nQuery: {user_message}"
|
||||
|
||||
return {
|
||||
"model": task_model_id,
|
||||
"messages": [
|
||||
{"role": "system", "content": content},
|
||||
{"role": "user", "content": f"Query: {prompt}"},
|
||||
],
|
||||
"stream": False,
|
||||
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
|
||||
}
|
||||
|
||||
|
||||
def apply_extra_params_to_tool_function(
|
||||
function: Callable, custom_params: dict
|
||||
function: Callable, extra_params: dict
|
||||
) -> Callable[..., Awaitable]:
|
||||
sig = inspect.signature(function)
|
||||
extra_params = {
|
||||
key: value for key, value in custom_params.items() if key in sig.parameters
|
||||
key: value for key, value in extra_params.items() if key in sig.parameters
|
||||
}
|
||||
is_coroutine = inspect.iscoroutinefunction(function)
|
||||
|
||||
@ -511,27 +492,27 @@ async def chat_completion_tools_handler(
|
||||
return body, {}
|
||||
|
||||
result = json.loads(content)
|
||||
tool_name = result.get("name", None)
|
||||
if tool_name not in tools:
|
||||
|
||||
tool_function_name = result.get("name", None)
|
||||
if tool_function_name not in tools:
|
||||
return body, {}
|
||||
|
||||
tool_params = result.get("parameters", {})
|
||||
toolkit_id = tools[tool_name]["toolkit_id"]
|
||||
tool_function_params = result.get("parameters", {})
|
||||
|
||||
try:
|
||||
tool_output = await tools[tool_name]["callable"](**tool_params)
|
||||
tool_output = await tools[tool_function_name]["callable"](**tool_function_params)
|
||||
except Exception as e:
|
||||
tool_output = str(e)
|
||||
|
||||
if tools[tool_name]["citation"]:
|
||||
if tools[tool_function_name]["citation"]:
|
||||
citations.append(
|
||||
{
|
||||
"source": {"name": f"TOOL:{toolkit_id}/{tool_name}"},
|
||||
"source": {"name": f"TOOL:{tools[tool_function_name]["toolkit_id"]}/{tool_function_name}"},
|
||||
"document": [tool_output],
|
||||
"metadata": [{"source": tool_name}],
|
||||
"metadata": [{"source": tool_function_name}],
|
||||
}
|
||||
)
|
||||
if tools[tool_name]["file_handler"]:
|
||||
if tools[tool_function_name]["file_handler"]:
|
||||
skip_files = True
|
||||
|
||||
if isinstance(tool_output, str):
|
||||
@ -576,6 +557,25 @@ def is_chat_completion_request(request):
|
||||
)
|
||||
|
||||
|
||||
async def get_body_and_model_and_user(request):
|
||||
# Read the original request body
|
||||
body = await request.body()
|
||||
body_str = body.decode("utf-8")
|
||||
body = json.loads(body_str) if body_str else {}
|
||||
|
||||
model_id = body["model"]
|
||||
if model_id not in app.state.MODELS:
|
||||
raise Exception("Model not found")
|
||||
model = app.state.MODELS[model_id]
|
||||
|
||||
user = get_current_user(
|
||||
request,
|
||||
get_http_authorization_cred(request.headers.get("Authorization")),
|
||||
)
|
||||
|
||||
return body, model, user
|
||||
|
||||
|
||||
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if not is_chat_completion_request(request):
|
||||
|
Loading…
Reference in New Issue
Block a user