feat: tools full integration

This commit is contained in:
Timothy J. Baek 2024-06-11 00:18:45 -07:00
parent a27175d672
commit 3d6f5f418d
4 changed files with 75 additions and 39 deletions

View File

@ -185,39 +185,48 @@ async def get_function_call_response(prompt, tool_id, template, task_model_id, u
model = app.state.MODELS[task_model_id]
response = None
if model["owned_by"] == "ollama":
response = await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
response = await generate_openai_chat_completion(payload, user=user)
try:
if model["owned_by"] == "ollama":
response = await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
response = await generate_openai_chat_completion(payload, user=user)
print(response)
content = response["choices"][0]["message"]["content"]
content = None
async for chunk in response.body_iterator:
data = json.loads(chunk.decode("utf-8"))
content = data["choices"][0]["message"]["content"]
# Parse the function response
if content != "":
result = json.loads(content)
print(result)
# Cleanup any remaining background tasks if necessary
if response.background is not None:
await response.background()
# Call the function
if "name" in result:
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
# Parse the function response
if content is not None:
result = json.loads(content)
print(result)
function = getattr(toolkit_module, result["name"])
function_result = None
try:
function_result = function(**result["parameters"])
except Exception as e:
print(e)
# Call the function
if "name" in result:
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
# Add the function result to the system prompt
if function_result:
return function_result
function = getattr(toolkit_module, result["name"])
function_result = None
try:
function_result = function(**result["parameters"])
except Exception as e:
print(e)
# Add the function result to the system prompt
if function_result:
return function_result
except Exception as e:
print(f"Error: {e}")
return None
@ -285,15 +294,18 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
print(response)
if response:
context += f"\n{response}"
context = ("\n" if context != "" else "") + response
system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt
)
if context != "":
system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt
)
data["messages"] = add_or_update_system_message(
system_prompt, data["messages"]
)
print(system_prompt)
data["messages"] = add_or_update_system_message(
f"\n{system_prompt}", data["messages"]
)
del data["tool_ids"]

View File

@ -73,6 +73,7 @@
let selectedModels = [''];
let atSelectedModel: Model | undefined;
let selectedToolIds = [];
let webSearchEnabled = false;
let chat = null;
@ -687,6 +688,7 @@
},
format: $settings.requestFormat ?? undefined,
keep_alive: $settings.keepAlive ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
docs: docs.length > 0 ? docs : undefined,
citations: docs.length > 0,
chat_id: $chatId
@ -948,6 +950,7 @@
top_p: $settings?.params?.top_p ?? undefined,
frequency_penalty: $settings?.params?.frequency_penalty ?? undefined,
max_tokens: $settings?.params?.max_tokens ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
docs: docs.length > 0 ? docs : undefined,
citations: docs.length > 0,
chat_id: $chatId
@ -1274,6 +1277,7 @@
bind:files
bind:prompt
bind:autoScroll
bind:selectedToolIds
bind:webSearchEnabled
bind:atSelectedModel
{selectedModels}

View File

@ -8,7 +8,8 @@
showSidebar,
models,
config,
showCallOverlay
showCallOverlay,
tools
} from '$lib/stores';
import { blobToFile, calculateSHA256, findWordIndices } from '$lib/utils';
@ -57,6 +58,7 @@
let chatInputPlaceholder = '';
export let files = [];
export let selectedToolIds = [];
export let webSearchEnabled = false;
@ -653,6 +655,15 @@
<div class=" ml-0.5 self-end mb-1.5 flex space-x-1">
<InputMenu
bind:webSearchEnabled
bind:selectedToolIds
tools={$tools.reduce((a, e, i, arr) => {
a[e.id] = {
name: e.name,
enabled: false
};
return a;
}, {})}
uploadFilesHandler={() => {
filesInputElement.click();
}}

View File

@ -14,6 +14,8 @@
const i18n = getContext('i18n');
export let uploadFilesHandler: Function;
export let selectedToolIds: string[] = [];
export let webSearchEnabled: boolean;
export let tools = {};
@ -44,16 +46,23 @@
transition={flyAndScale}
>
{#if Object.keys(tools).length > 0}
{#each Object.keys(tools) as tool}
{#each Object.keys(tools) as toolId}
<div
class="flex gap-2 items-center px-3 py-2 text-sm font-medium cursor-pointer rounded-xl"
>
<div class="flex-1 flex items-center gap-2">
<WrenchSolid />
<div class="flex items-center">{tool}</div>
<div class="flex items-center">{tools[toolId].name}</div>
</div>
<Switch bind:state={tools[tool]} />
<Switch
bind:state={tools[toolId].enabled}
on:change={(e) => {
selectedToolIds = e.detail
? [...selectedToolIds, toolId]
: selectedToolIds.filter((id) => id !== toolId);
}}
/>
</div>
{/each}
<hr class="border-gray-100 dark:border-gray-800 my-1" />