feat: __event_call__ support

This commit is contained in:
Timothy J. Baek 2024-07-08 21:39:06 -07:00
parent 1d979d9b75
commit 1b7ff1c5df
3 changed files with 74 additions and 11 deletions

View File

@ -302,6 +302,7 @@ async def get_function_call_response(
user, user,
model, model,
__event_emitter__=None, __event_emitter__=None,
__event_call__=None,
): ):
tool = Tools.get_tool_by_id(tool_id) tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2) tools_specs = json.dumps(tool.specs, indent=2)
@ -445,6 +446,13 @@ async def get_function_call_response(
"__event_emitter__": __event_emitter__, "__event_emitter__": __event_emitter__,
} }
if "__event_call__" in sig.parameters:
# Call the function with the '__event_call__' parameter included
params = {
**params,
"__event_call__": __event_call__,
}
if inspect.iscoroutinefunction(function): if inspect.iscoroutinefunction(function):
function_result = await function(**params) function_result = await function(**params)
else: else:
@ -468,7 +476,9 @@ async def get_function_call_response(
return None, None, False return None, None, False
async def chat_completion_functions_handler(body, model, user, __event_emitter__): async def chat_completion_functions_handler(
body, model, user, __event_emitter__, __event_call__
):
skip_files = None skip_files = None
filter_ids = get_filter_function_ids(model) filter_ids = get_filter_function_ids(model)
@ -534,12 +544,19 @@ async def chat_completion_functions_handler(body, model, user, __event_emitter__
**params, **params,
"__model__": model, "__model__": model,
} }
if "__event_emitter__" in sig.parameters: if "__event_emitter__" in sig.parameters:
params = { params = {
**params, **params,
"__event_emitter__": __event_emitter__, "__event_emitter__": __event_emitter__,
} }
if "__event_call__" in sig.parameters:
params = {
**params,
"__event_call__": __event_call__,
}
if inspect.iscoroutinefunction(inlet): if inspect.iscoroutinefunction(inlet):
body = await inlet(**params) body = await inlet(**params)
else: else:
@ -556,7 +573,9 @@ async def chat_completion_functions_handler(body, model, user, __event_emitter__
return body, {} return body, {}
async def chat_completion_tools_handler(body, model, user, __event_emitter__): async def chat_completion_tools_handler(
body, model, user, __event_emitter__, __event_call__
):
skip_files = None skip_files = None
contexts = [] contexts = []
@ -579,6 +598,7 @@ async def chat_completion_tools_handler(body, model, user, __event_emitter__):
user=user, user=user,
model=model, model=model,
__event_emitter__=__event_emitter__, __event_emitter__=__event_emitter__,
__event_call__=__event_call__,
) )
print(file_handler) print(file_handler)
@ -676,6 +696,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
to=session_id, to=session_id,
) )
async def __event_call__(data):
response = await sio.call(
"chat-events",
{"chat_id": chat_id, "message_id": message_id, "data": data},
to=session_id,
)
return response
# Initialize data_items to store additional data to be sent to the client # Initialize data_items to store additional data to be sent to the client
data_items = [] data_items = []
@ -685,7 +713,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
try: try:
body, flags = await chat_completion_functions_handler( body, flags = await chat_completion_functions_handler(
body, model, user, __event_emitter__ body, model, user, __event_emitter__, __event_call__
) )
except Exception as e: except Exception as e:
return JSONResponse( return JSONResponse(
@ -695,7 +723,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
try: try:
body, flags = await chat_completion_tools_handler( body, flags = await chat_completion_tools_handler(
body, model, user, __event_emitter__ body, model, user, __event_emitter__, __event_call__
) )
contexts.extend(flags.get("contexts", [])) contexts.extend(flags.get("contexts", []))

View File

@ -61,6 +61,7 @@
import CallOverlay from './MessageInput/CallOverlay.svelte'; import CallOverlay from './MessageInput/CallOverlay.svelte';
import { error } from '@sveltejs/kit'; import { error } from '@sveltejs/kit';
import ChatControls from './ChatControls.svelte'; import ChatControls from './ChatControls.svelte';
import EventConfirmDialog from '../common/ConfirmDialog.svelte';
const i18n: Writable<i18nType> = getContext('i18n'); const i18n: Writable<i18nType> = getContext('i18n');
@ -74,6 +75,11 @@
let processing = ''; let processing = '';
let messagesContainerElement: HTMLDivElement; let messagesContainerElement: HTMLDivElement;
let showEventConfirmation = false;
let eventConfirmationTitle = '';
let eventConfirmationMessage = '';
let eventCallback = null;
let showModelSelector = true; let showModelSelector = true;
let selectedModels = ['']; let selectedModels = [''];
@ -129,7 +135,7 @@
})(); })();
} }
const chatEventHandler = async (event) => { const chatEventHandler = async (event, cb) => {
if (event.chat_id === $chatId) { if (event.chat_id === $chatId) {
await tick(); await tick();
console.log(event); console.log(event);
@ -139,17 +145,23 @@
const data = event?.data?.data ?? null; const data = event?.data?.data ?? null;
if (type === 'status') { if (type === 'status') {
if (message.statusHistory) { if (message?.statusHistory) {
message.statusHistory.push(data); message.statusHistory.push(data);
} else { } else {
message.statusHistory = [data]; message.statusHistory = [data];
} }
} else if (type === 'citation') { } else if (type === 'citation') {
if (message.citations) { if (message?.citations) {
message.citations.push(data); message.citations.push(data);
} else { } else {
message.citations = [data]; message.citations = [data];
} }
} else if (type === 'confirmation') {
eventCallback = cb;
showEventConfirmation = true;
eventConfirmationTitle = data.title;
eventConfirmationMessage = data.message;
} else { } else {
console.log('Unknown message type', data); console.log('Unknown message type', data);
} }
@ -1392,6 +1404,18 @@
<audio id="audioElement" src="" style="display: none;" /> <audio id="audioElement" src="" style="display: none;" />
<EventConfirmDialog
bind:show={showEventConfirmation}
title={eventConfirmationTitle}
message={eventConfirmationMessage}
on:confirm={(e) => {
eventCallback(true);
}}
on:cancel={() => {
eventCallback(false);
}}
/>
{#if $showCallOverlay} {#if $showCallOverlay}
<CallOverlay <CallOverlay
{submitPrompt} {submitPrompt}

View File

@ -7,8 +7,8 @@
const dispatch = createEventDispatcher(); const dispatch = createEventDispatcher();
export let title = $i18n.t('Confirm your action'); export let title = '';
export let message = $i18n.t('This action cannot be undone. Do you wish to continue?'); export let message = '';
export let cancelLabel = $i18n.t('Cancel'); export let cancelLabel = $i18n.t('Cancel');
export let confirmLabel = $i18n.t('Confirm'); export let confirmLabel = $i18n.t('Confirm');
@ -58,11 +58,21 @@
}} }}
> >
<div class="px-[1.75rem] py-6"> <div class="px-[1.75rem] py-6">
<div class=" text-lg font-semibold dark:text-gray-200 mb-2.5">{title}</div> <div class=" text-lg font-semibold dark:text-gray-200 mb-2.5">
{#if title !== ''}
{title}
{:else}
{$i18n.t('Confirm your action')}
{/if}
</div>
<slot> <slot>
<div class=" text-sm text-gray-500"> <div class=" text-sm text-gray-500">
{#if message !== ''}
{message} {message}
{:else}
{$i18n.t('This action cannot be undone. Do you wish to continue?')}
{/if}
</div> </div>
</slot> </slot>
@ -71,6 +81,7 @@
class="bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-white font-medium w-full py-2.5 rounded-lg transition" class="bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-white font-medium w-full py-2.5 rounded-lg transition"
on:click={() => { on:click={() => {
show = false; show = false;
dispatch('cancel');
}} }}
type="button" type="button"
> >