feat: __event_call__ support

This commit is contained in:
Timothy J. Baek 2024-07-08 21:39:06 -07:00
parent 1d979d9b75
commit 1b7ff1c5df
3 changed files with 74 additions and 11 deletions

View File

@ -302,6 +302,7 @@ async def get_function_call_response(
user,
model,
__event_emitter__=None,
__event_call__=None,
):
tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2)
@ -445,6 +446,13 @@ async def get_function_call_response(
"__event_emitter__": __event_emitter__,
}
if "__event_call__" in sig.parameters:
# Call the function with the '__event_call__' parameter included
params = {
**params,
"__event_call__": __event_call__,
}
if inspect.iscoroutinefunction(function):
function_result = await function(**params)
else:
@ -468,7 +476,9 @@ async def get_function_call_response(
return None, None, False
async def chat_completion_functions_handler(body, model, user, __event_emitter__):
async def chat_completion_functions_handler(
body, model, user, __event_emitter__, __event_call__
):
skip_files = None
filter_ids = get_filter_function_ids(model)
@ -534,12 +544,19 @@ async def chat_completion_functions_handler(body, model, user, __event_emitter__
**params,
"__model__": model,
}
if "__event_emitter__" in sig.parameters:
params = {
**params,
"__event_emitter__": __event_emitter__,
}
if "__event_call__" in sig.parameters:
params = {
**params,
"__event_call__": __event_call__,
}
if inspect.iscoroutinefunction(inlet):
body = await inlet(**params)
else:
@ -556,7 +573,9 @@ async def chat_completion_functions_handler(body, model, user, __event_emitter__
return body, {}
async def chat_completion_tools_handler(body, model, user, __event_emitter__):
async def chat_completion_tools_handler(
body, model, user, __event_emitter__, __event_call__
):
skip_files = None
contexts = []
@ -579,6 +598,7 @@ async def chat_completion_tools_handler(body, model, user, __event_emitter__):
user=user,
model=model,
__event_emitter__=__event_emitter__,
__event_call__=__event_call__,
)
print(file_handler)
@ -676,6 +696,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
to=session_id,
)
async def __event_call__(data):
response = await sio.call(
"chat-events",
{"chat_id": chat_id, "message_id": message_id, "data": data},
to=session_id,
)
return response
# Initialize data_items to store additional data to be sent to the client
data_items = []
@ -685,7 +713,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
try:
body, flags = await chat_completion_functions_handler(
body, model, user, __event_emitter__
body, model, user, __event_emitter__, __event_call__
)
except Exception as e:
return JSONResponse(
@ -695,7 +723,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
try:
body, flags = await chat_completion_tools_handler(
body, model, user, __event_emitter__
body, model, user, __event_emitter__, __event_call__
)
contexts.extend(flags.get("contexts", []))

View File

@ -61,6 +61,7 @@
import CallOverlay from './MessageInput/CallOverlay.svelte';
import { error } from '@sveltejs/kit';
import ChatControls from './ChatControls.svelte';
import EventConfirmDialog from '../common/ConfirmDialog.svelte';
const i18n: Writable<i18nType> = getContext('i18n');
@ -74,6 +75,11 @@
let processing = '';
let messagesContainerElement: HTMLDivElement;
let showEventConfirmation = false;
let eventConfirmationTitle = '';
let eventConfirmationMessage = '';
let eventCallback = null;
let showModelSelector = true;
let selectedModels = [''];
@ -129,7 +135,7 @@
})();
}
const chatEventHandler = async (event) => {
const chatEventHandler = async (event, cb) => {
if (event.chat_id === $chatId) {
await tick();
console.log(event);
@ -139,17 +145,23 @@
const data = event?.data?.data ?? null;
if (type === 'status') {
if (message.statusHistory) {
if (message?.statusHistory) {
message.statusHistory.push(data);
} else {
message.statusHistory = [data];
}
} else if (type === 'citation') {
if (message.citations) {
if (message?.citations) {
message.citations.push(data);
} else {
message.citations = [data];
}
} else if (type === 'confirmation') {
eventCallback = cb;
showEventConfirmation = true;
eventConfirmationTitle = data.title;
eventConfirmationMessage = data.message;
} else {
console.log('Unknown message type', data);
}
@ -1392,6 +1404,18 @@
<audio id="audioElement" src="" style="display: none;" />
<EventConfirmDialog
bind:show={showEventConfirmation}
title={eventConfirmationTitle}
message={eventConfirmationMessage}
on:confirm={(e) => {
eventCallback(true);
}}
on:cancel={() => {
eventCallback(false);
}}
/>
{#if $showCallOverlay}
<CallOverlay
{submitPrompt}

View File

@ -7,8 +7,8 @@
const dispatch = createEventDispatcher();
export let title = $i18n.t('Confirm your action');
export let message = $i18n.t('This action cannot be undone. Do you wish to continue?');
export let title = '';
export let message = '';
export let cancelLabel = $i18n.t('Cancel');
export let confirmLabel = $i18n.t('Confirm');
@ -58,11 +58,21 @@
}}
>
<div class="px-[1.75rem] py-6">
<div class=" text-lg font-semibold dark:text-gray-200 mb-2.5">{title}</div>
<div class=" text-lg font-semibold dark:text-gray-200 mb-2.5">
{#if title !== ''}
{title}
{:else}
{$i18n.t('Confirm your action')}
{/if}
</div>
<slot>
<div class=" text-sm text-gray-500">
{#if message !== ''}
{message}
{:else}
{$i18n.t('This action cannot be undone. Do you wish to continue?')}
{/if}
</div>
</slot>
@ -71,6 +81,7 @@
class="bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-white font-medium w-full py-2.5 rounded-lg transition"
on:click={() => {
show = false;
dispatch('cancel');
}}
type="button"
>