mirror of
https://github.com/open-webui/open-webui
synced 2024-11-07 00:59:52 +00:00
minor refac
This commit is contained in:
parent
fdc89cbcee
commit
4042219b3e
@ -378,9 +378,8 @@ async def chat_completion_inlets_handler(body, model, extra_params):
|
|||||||
print(f"Error: {e}")
|
print(f"Error: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
if skip_files:
|
if skip_files and "files" in body:
|
||||||
if "files" in body:
|
del body["files"]
|
||||||
del body["files"]
|
|
||||||
|
|
||||||
return body, {}
|
return body, {}
|
||||||
|
|
||||||
@ -431,12 +430,17 @@ def get_configured_tools(
|
|||||||
)
|
)
|
||||||
|
|
||||||
for spec in toolkit.specs:
|
for spec in toolkit.specs:
|
||||||
|
# TODO: Fix hack for OpenAI API
|
||||||
|
for val in spec.get("parameters", {}).get("properties", {}).values():
|
||||||
|
if val["type"] == "str":
|
||||||
|
val["type"] = "string"
|
||||||
name = spec["name"]
|
name = spec["name"]
|
||||||
callable = getattr(module, name)
|
callable = getattr(module, name)
|
||||||
|
|
||||||
# convert to function that takes only model params and inserts custom params
|
# convert to function that takes only model params and inserts custom params
|
||||||
custom_callable = get_tool_with_custom_params(callable, extra_params)
|
custom_callable = get_tool_with_custom_params(callable, extra_params)
|
||||||
|
|
||||||
|
# TODO: This needs to be a pydantic model
|
||||||
tool_dict = {
|
tool_dict = {
|
||||||
"spec": spec,
|
"spec": spec,
|
||||||
"citation": has_citation,
|
"citation": has_citation,
|
||||||
@ -444,6 +448,7 @@ def get_configured_tools(
|
|||||||
"toolkit_id": tool_id,
|
"toolkit_id": tool_id,
|
||||||
"callable": custom_callable,
|
"callable": custom_callable,
|
||||||
}
|
}
|
||||||
|
# TODO: if collision, prepend toolkit name
|
||||||
if name in tools:
|
if name in tools:
|
||||||
log.warning(f"Tool {name} already exists in another toolkit!")
|
log.warning(f"Tool {name} already exists in another toolkit!")
|
||||||
log.warning(f"Collision between {toolkit} and {tool_id}.")
|
log.warning(f"Collision between {toolkit} and {tool_id}.")
|
||||||
@ -533,9 +538,9 @@ async def chat_completion_tools_handler(
|
|||||||
return body, {"contexts": contexts, "citations": citations}
|
return body, {"contexts": contexts, "citations": citations}
|
||||||
|
|
||||||
|
|
||||||
async def chat_completion_files_handler(body):
|
async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
|
||||||
contexts = []
|
contexts = []
|
||||||
citations = None
|
citations = []
|
||||||
|
|
||||||
if files := body.pop("files", None):
|
if files := body.pop("files", None):
|
||||||
contexts, citations = get_rag_context(
|
contexts, citations = get_rag_context(
|
||||||
@ -550,10 +555,7 @@ async def chat_completion_files_handler(body):
|
|||||||
|
|
||||||
log.debug(f"rag_contexts: {contexts}, citations: {citations}")
|
log.debug(f"rag_contexts: {contexts}, citations: {citations}")
|
||||||
|
|
||||||
return body, {
|
return body, {"contexts": contexts, "citations": citations}
|
||||||
**({"contexts": contexts} if contexts is not None else {}),
|
|
||||||
**({"citations": citations} if citations is not None else {}),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def is_chat_completion_request(request):
|
def is_chat_completion_request(request):
|
||||||
@ -618,16 +620,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
contexts.extend(flags.get("contexts", []))
|
contexts.extend(flags.get("contexts", []))
|
||||||
citations.extend(flags.get("citations", []))
|
citations.extend(flags.get("citations", []))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
log.exception(e)
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
body, flags = await chat_completion_files_handler(body)
|
body, flags = await chat_completion_files_handler(body)
|
||||||
contexts.extend(flags.get("contexts", []))
|
contexts.extend(flags.get("contexts", []))
|
||||||
citations.extend(flags.get("citations", []))
|
citations.extend(flags.get("citations", []))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
log.exception(e)
|
||||||
pass
|
|
||||||
|
|
||||||
# If context is not empty, insert it into the messages
|
# If context is not empty, insert it into the messages
|
||||||
if len(contexts) > 0:
|
if len(contexts) > 0:
|
||||||
|
Loading…
Reference in New Issue
Block a user