diff --git a/backend/main.py b/backend/main.py
index 3f72c5710..fa9563e13 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -185,39 +185,48 @@ async def get_function_call_response(prompt, tool_id, template, task_model_id, u
     model = app.state.MODELS[task_model_id]
 
     response = None
-    if model["owned_by"] == "ollama":
-        response = await generate_ollama_chat_completion(
-            OpenAIChatCompletionForm(**payload), user=user
-        )
-    else:
-        response = await generate_openai_chat_completion(payload, user=user)
+    try:
+        if model["owned_by"] == "ollama":
+            response = await generate_ollama_chat_completion(
+                OpenAIChatCompletionForm(**payload), user=user
+            )
+        else:
+            response = await generate_openai_chat_completion(payload, user=user)
 
-    print(response)
-    content = response["choices"][0]["message"]["content"]
+        content = None
+        async for chunk in response.body_iterator:
+            data = json.loads(chunk.decode("utf-8"))
+            content = data["choices"][0]["message"]["content"]
 
-    # Parse the function response
-    if content != "":
-        result = json.loads(content)
-        print(result)
+        # Cleanup any remaining background tasks if necessary
+        if response.background is not None:
+            await response.background()
 
-        # Call the function
-        if "name" in result:
-            if tool_id in webui_app.state.TOOLS:
-                toolkit_module = webui_app.state.TOOLS[tool_id]
-            else:
-                toolkit_module = load_toolkit_module_by_id(tool_id)
-                webui_app.state.TOOLS[tool_id] = toolkit_module
+        # Parse the function response
+        if content is not None:
+            result = json.loads(content)
+            print(result)
 
-            function = getattr(toolkit_module, result["name"])
-            function_result = None
-            try:
-                function_result = function(**result["parameters"])
-            except Exception as e:
-                print(e)
+            # Call the function
+            if "name" in result:
+                if tool_id in webui_app.state.TOOLS:
+                    toolkit_module = webui_app.state.TOOLS[tool_id]
+                else:
+                    toolkit_module = load_toolkit_module_by_id(tool_id)
+                    webui_app.state.TOOLS[tool_id] = toolkit_module
 
-            # Add the function result to the system prompt
-            if function_result:
-                return function_result
+                function = getattr(toolkit_module, result["name"])
+                function_result = None
+                try:
+                    function_result = function(**result["parameters"])
+                except Exception as e:
+                    print(e)
+
+                # Add the function result to the system prompt
+                if function_result:
+                    return function_result
+    except Exception as e:
+        print(f"Error: {e}")
 
     return None
 
@@ -285,15 +294,18 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 print(response)
 
                 if response:
-                    context += f"\n{response}"
+                    context = ("\n" if context != "" else "") + response
 
-            system_prompt = rag_template(
-                rag_app.state.config.RAG_TEMPLATE, context, prompt
-            )
+            if context != "":
+                system_prompt = rag_template(
+                    rag_app.state.config.RAG_TEMPLATE, context, prompt
+                )
 
-            data["messages"] = add_or_update_system_message(
-                system_prompt, data["messages"]
-            )
+                print(system_prompt)
+
+                data["messages"] = add_or_update_system_message(
+                    f"\n{system_prompt}", data["messages"]
+                )
 
             del data["tool_ids"]
 
diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte
index 1a0b0d894..3c4c75967 100644
--- a/src/lib/components/chat/Chat.svelte
+++ b/src/lib/components/chat/Chat.svelte
@@ -73,6 +73,7 @@
 
     let selectedModels = [''];
    let atSelectedModel: Model | undefined;
+   let selectedToolIds = [];
 
    let webSearchEnabled = false;
    let chat = null;
@@ -687,6 +688,7 @@
                },
                format: $settings.requestFormat ?? undefined,
                keep_alive: $settings.keepAlive ?? undefined,
+               tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
                docs: docs.length > 0 ? docs : undefined,
                citations: docs.length > 0,
                chat_id: $chatId
@@ -948,6 +950,7 @@
            top_p: $settings?.params?.top_p ?? undefined,
            frequency_penalty: $settings?.params?.frequency_penalty ?? undefined,
            max_tokens: $settings?.params?.max_tokens ?? undefined,
+           tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
            docs: docs.length > 0 ? docs : undefined,
            citations: docs.length > 0,
            chat_id: $chatId
@@ -1274,6 +1277,7 @@
        bind:files
        bind:prompt
        bind:autoScroll
+       bind:selectedToolIds
        bind:webSearchEnabled
        bind:atSelectedModel
        {selectedModels}
diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte
index b3ceb3e91..c5dc780ab 100644
--- a/src/lib/components/chat/MessageInput.svelte
+++ b/src/lib/components/chat/MessageInput.svelte
@@ -8,7 +8,8 @@
        showSidebar,
        models,
        config,
-       showCallOverlay
+       showCallOverlay,
+       tools
    } from '$lib/stores';
    import { blobToFile, calculateSHA256, findWordIndices } from '$lib/utils';
 
@@ -57,6 +58,7 @@
    let chatInputPlaceholder = '';
 
    export let files = [];
+   export let selectedToolIds = [];
 
    export let webSearchEnabled = false;
 
@@ -653,6 +655,15 @@
                <InputMenu
                    bind:webSearchEnabled
+                   bind:selectedToolIds
+                   tools={$tools.reduce((a, e, i, arr) => {
+                       a[e.id] = {
+                           name: e.name,
+                           enabled: false
+                       };
+
+                       return a;
+                   }, {})}
                    uploadFilesHandler={() => {
                        filesInputElement.click();
                    }}
diff --git a/src/lib/components/chat/MessageInput/InputMenu.svelte b/src/lib/components/chat/MessageInput/InputMenu.svelte
index 811e4d27d..5d43d4648 100644
--- a/src/lib/components/chat/MessageInput/InputMenu.svelte
+++ b/src/lib/components/chat/MessageInput/InputMenu.svelte
@@ -14,6 +14,8 @@
    const i18n = getContext('i18n');
 
    export let uploadFilesHandler: Function;
+
+   export let selectedToolIds: string[] = [];
 
    export let webSearchEnabled: boolean;
    export let tools = {};
@@ -44,16 +46,23 @@
        transition={flyAndScale}
    >
        {#if Object.keys(tools).length > 0}
-           {#each Object.keys(tools) as tool}
+           {#each Object.keys(tools) as toolId}
                <div>
-                   <div>{tool}</div>
+                   <div>{tools[toolId].name}</div>
 
-                   <Switch />
+                   <Switch
+                       on:change={async (e) => {
+                           selectedToolIds = e.detail
+                               ? [...selectedToolIds, toolId]
+                               : selectedToolIds.filter((id) => id !== toolId);
+                       }}
+                   />
                </div>
            {/each}
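
A note on the `backend/main.py` hunk above: `generate_ollama_chat_completion` and `generate_openai_chat_completion` return a response object exposing a `body_iterator` rather than a plain dict, so the old `response["choices"][0]["message"]["content"]` lookup is replaced by draining that iterator, decoding each chunk, and then awaiting any attached background task. A minimal sketch of that read pattern, assuming each chunk decodes to a complete OpenAI-style completion payload (the helper name `read_completion_content` is illustrative, not part of the codebase):

```python
import json


async def read_completion_content(response):
    """Drain a chat-completion response and return the last parsed message content."""
    content = None
    async for chunk in response.body_iterator:
        # Each chunk is assumed to be one JSON-encoded completion body.
        data = json.loads(chunk.decode("utf-8"))
        content = data["choices"][0]["message"]["content"]

    # Run any background task attached to the response so it is cleaned up.
    if response.background is not None:
        await response.background()

    return content
```

The parsed content is then treated as a JSON function call (`{"name": ..., "parameters": ...}`), dispatched to the matching function in the loaded toolkit module, and the result flows into the system prompt via `rag_template` in `ChatCompletionMiddleware`; that is why the frontend changes only need to send the selected `tool_ids` with the completion request.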