mirror of
https://github.com/open-webui/open-webui
synced 2025-06-08 15:37:22 +00:00
refac
This commit is contained in:
parent
e71f55e58f
commit
15f3ebba93
100
backend/main.py
100
backend/main.py
@ -218,25 +218,6 @@ origins = ["*"]
|
|||||||
##################################
|
##################################
|
||||||
|
|
||||||
|
|
||||||
async def get_body_and_model_and_user(request):
|
|
||||||
# Read the original request body
|
|
||||||
body = await request.body()
|
|
||||||
body_str = body.decode("utf-8")
|
|
||||||
body = json.loads(body_str) if body_str else {}
|
|
||||||
|
|
||||||
model_id = body["model"]
|
|
||||||
if model_id not in app.state.MODELS:
|
|
||||||
raise Exception("Model not found")
|
|
||||||
model = app.state.MODELS[model_id]
|
|
||||||
|
|
||||||
user = get_current_user(
|
|
||||||
request,
|
|
||||||
get_http_authorization_cred(request.headers.get("Authorization")),
|
|
||||||
)
|
|
||||||
|
|
||||||
return body, model, user
|
|
||||||
|
|
||||||
|
|
||||||
def get_task_model_id(default_model_id):
|
def get_task_model_id(default_model_id):
|
||||||
# Set the task model
|
# Set the task model
|
||||||
task_model_id = default_model_id
|
task_model_id = default_model_id
|
||||||
@ -283,26 +264,6 @@ def get_filter_function_ids(model):
|
|||||||
return filter_ids
|
return filter_ids
|
||||||
|
|
||||||
|
|
||||||
def get_tools_function_calling_payload(messages, task_model_id, content):
|
|
||||||
user_message = get_last_user_message(messages)
|
|
||||||
history = "\n".join(
|
|
||||||
f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
|
|
||||||
for message in messages[::-1][:4]
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt = f"History:\n{history}\nQuery: {user_message}"
|
|
||||||
|
|
||||||
return {
|
|
||||||
"model": task_model_id,
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": content},
|
|
||||||
{"role": "user", "content": f"Query: {prompt}"},
|
|
||||||
],
|
|
||||||
"stream": False,
|
|
||||||
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def chat_completion_filter_functions_handler(body, model, extra_params):
|
async def chat_completion_filter_functions_handler(body, model, extra_params):
|
||||||
skip_files = None
|
skip_files = None
|
||||||
|
|
||||||
@ -369,12 +330,32 @@ async def chat_completion_filter_functions_handler(body, model, extra_params):
|
|||||||
return body, {}
|
return body, {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_tools_function_calling_payload(messages, task_model_id, content):
|
||||||
|
user_message = get_last_user_message(messages)
|
||||||
|
history = "\n".join(
|
||||||
|
f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
|
||||||
|
for message in messages[::-1][:4]
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = f"History:\n{history}\nQuery: {user_message}"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"model": task_model_id,
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": content},
|
||||||
|
{"role": "user", "content": f"Query: {prompt}"},
|
||||||
|
],
|
||||||
|
"stream": False,
|
||||||
|
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def apply_extra_params_to_tool_function(
|
def apply_extra_params_to_tool_function(
|
||||||
function: Callable, custom_params: dict
|
function: Callable, extra_params: dict
|
||||||
) -> Callable[..., Awaitable]:
|
) -> Callable[..., Awaitable]:
|
||||||
sig = inspect.signature(function)
|
sig = inspect.signature(function)
|
||||||
extra_params = {
|
extra_params = {
|
||||||
key: value for key, value in custom_params.items() if key in sig.parameters
|
key: value for key, value in extra_params.items() if key in sig.parameters
|
||||||
}
|
}
|
||||||
is_coroutine = inspect.iscoroutinefunction(function)
|
is_coroutine = inspect.iscoroutinefunction(function)
|
||||||
|
|
||||||
@ -511,27 +492,27 @@ async def chat_completion_tools_handler(
|
|||||||
return body, {}
|
return body, {}
|
||||||
|
|
||||||
result = json.loads(content)
|
result = json.loads(content)
|
||||||
tool_name = result.get("name", None)
|
|
||||||
if tool_name not in tools:
|
tool_function_name = result.get("name", None)
|
||||||
|
if tool_function_name not in tools:
|
||||||
return body, {}
|
return body, {}
|
||||||
|
|
||||||
tool_params = result.get("parameters", {})
|
tool_function_params = result.get("parameters", {})
|
||||||
toolkit_id = tools[tool_name]["toolkit_id"]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tool_output = await tools[tool_name]["callable"](**tool_params)
|
tool_output = await tools[tool_function_name]["callable"](**tool_function_params)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
tool_output = str(e)
|
tool_output = str(e)
|
||||||
|
|
||||||
if tools[tool_name]["citation"]:
|
if tools[tool_function_name]["citation"]:
|
||||||
citations.append(
|
citations.append(
|
||||||
{
|
{
|
||||||
"source": {"name": f"TOOL:{toolkit_id}/{tool_name}"},
|
"source": {"name": f"TOOL:{tools[tool_function_name]["toolkit_id"]}/{tool_function_name}"},
|
||||||
"document": [tool_output],
|
"document": [tool_output],
|
||||||
"metadata": [{"source": tool_name}],
|
"metadata": [{"source": tool_function_name}],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if tools[tool_name]["file_handler"]:
|
if tools[tool_function_name]["file_handler"]:
|
||||||
skip_files = True
|
skip_files = True
|
||||||
|
|
||||||
if isinstance(tool_output, str):
|
if isinstance(tool_output, str):
|
||||||
@ -576,6 +557,25 @@ def is_chat_completion_request(request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_body_and_model_and_user(request):
|
||||||
|
# Read the original request body
|
||||||
|
body = await request.body()
|
||||||
|
body_str = body.decode("utf-8")
|
||||||
|
body = json.loads(body_str) if body_str else {}
|
||||||
|
|
||||||
|
model_id = body["model"]
|
||||||
|
if model_id not in app.state.MODELS:
|
||||||
|
raise Exception("Model not found")
|
||||||
|
model = app.state.MODELS[model_id]
|
||||||
|
|
||||||
|
user = get_current_user(
|
||||||
|
request,
|
||||||
|
get_http_authorization_cred(request.headers.get("Authorization")),
|
||||||
|
)
|
||||||
|
|
||||||
|
return body, model, user
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
if not is_chat_completion_request(request):
|
if not is_chat_completion_request(request):
|
||||||
|
Loading…
Reference in New Issue
Block a user