feat: merge responses

This commit is contained in:
Timothy J. Baek 2024-08-18 20:59:59 +02:00
parent 65923006a8
commit 7c81509804
13 changed files with 378 additions and 129 deletions

View File

@ -100,3 +100,4 @@ class TASKS(str, Enum):
EMOJI_GENERATION = "emoji_generation"
QUERY_GENERATION = "query_generation"
FUNCTION_CALLING = "function_calling"
MOA_RESPONSE_GENERATION = "moa_response_generation"

View File

@ -73,6 +73,7 @@ from utils.task import (
title_generation_template,
search_query_generation_template,
tools_function_calling_generation_template,
moa_response_generation_template,
)
from utils.misc import (
get_last_user_message,
@ -1570,6 +1571,58 @@ Message: """{{prompt}}"""
return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/moa/completions")
async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)):
print("generate_moa_response")
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
model_id = get_task_model_id(model_id)
print(model_id)
template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}"
Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
Responses from models: {{responses}}"""
content = moa_response_generation_template(
template,
form_data["prompt"],
form_data["responses"],
)
payload = {
"model": model_id,
"messages": [{"role": "user", "content": content}],
"stream": form_data.get("stream", False),
"chat_id": form_data.get("chat_id", None),
"metadata": {"task": str(TASKS.MOA_RESPONSE_GENERATION)},
}
log.debug(payload)
try:
payload = filter_pipeline(payload, user)
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
##################################
#
# Pipelines Endpoints

View File

@ -121,6 +121,43 @@ def search_query_generation_template(
return template
def moa_response_generation_template(
template: str, prompt: str, responses: list[str]
) -> str:
def replacement_function(match):
full_match = match.group(0)
start_length = match.group(1)
end_length = match.group(2)
middle_length = match.group(3)
if full_match == "{{prompt}}":
return prompt
elif start_length is not None:
return prompt[: int(start_length)]
elif end_length is not None:
return prompt[-int(end_length) :]
elif middle_length is not None:
middle_length = int(middle_length)
if len(prompt) <= middle_length:
return prompt
start = prompt[: math.ceil(middle_length / 2)]
end = prompt[-math.floor(middle_length / 2) :]
return f"{start}...{end}"
return ""
template = re.sub(
r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
replacement_function,
template,
)
responses = [f'"""{response}"""' for response in responses]
responses = "\n\n".join(responses)
template = template.replace("{{responses}}", responses)
return template
def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
template = template.replace("{{TOOLS}}", tools_specs)
return template

View File

@ -34,6 +34,10 @@ math {
@apply rounded-lg;
}
.markdown-prose {
@apply prose dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line;
}
.markdown a {
@apply underline;
}

View File

@ -333,6 +333,42 @@ export const generateSearchQuery = async (
return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? prompt;
};
export const generateMoACompletion = async (
token: string = '',
model: string,
prompt: string,
responses: string[]
) => {
const controller = new AbortController();
let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/task/moa/completions`, {
signal: controller.signal,
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
model: model,
prompt: prompt,
responses: responses,
stream: true
})
}).catch((err) => {
console.log(err);
error = err;
return null;
});
if (error) {
throw error;
}
return [res, controller];
};
export const getPipelinesList = async (token: string = '') => {
let error = null;

View File

@ -54,7 +54,13 @@
import { createOpenAITextStream } from '$lib/apis/streaming';
import { queryMemory } from '$lib/apis/memories';
import { getAndUpdateUserLocation, getUserSettings } from '$lib/apis/users';
import { chatCompleted, generateTitle, generateSearchQuery, chatAction } from '$lib/apis';
import {
chatCompleted,
generateTitle,
generateSearchQuery,
chatAction,
generateMoACompletion
} from '$lib/apis';
import Banner from '../common/Banner.svelte';
import MessageInput from '$lib/components/chat/MessageInput.svelte';
@ -1511,6 +1517,50 @@
return [];
});
};
const mergeResponses = async (messageId, responses) => {
console.log('mergeResponses', messageId, responses);
const message = history.messages[messageId];
const mergedResponse = {
status: true,
content: ''
};
message.merged = mergedResponse;
try {
const [res, controller] = await generateMoACompletion(
localStorage.token,
message.model,
history.messages[message.parentId].content,
responses
);
if (res && res.ok && res.body) {
const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks);
for await (const update of textStream) {
const { value, done, citations, error, usage } = update;
if (error || done) {
break;
}
if (mergedResponse.content == '' && value == '\n') {
continue;
} else {
mergedResponse.content += value;
messages = messages;
}
if (autoScroll) {
scrollToBottom();
}
}
} else {
console.error(res);
}
} catch (e) {
console.error(e);
}
};
</script>
<svelte:head>
@ -1637,6 +1687,7 @@
{sendPrompt}
{continueGeneration}
{regenerateResponse}
{mergeResponses}
{chatActionHandler}
/>
</div>

View File

@ -19,6 +19,7 @@
export let sendPrompt: Function;
export let continueGeneration: Function;
export let regenerateResponse: Function;
export let mergeResponses: Function;
export let chatActionHandler: Function;
export let user = $_user;
@ -374,6 +375,7 @@
{rateMessage}
copyToClipboard={copyToClipboardWithToast}
{continueGeneration}
{mergeResponses}
{regenerateResponse}
on:change={async () => {
await updateChatById(localStorage.token, chatId, {

View File

@ -0,0 +1,56 @@
<script lang="ts">
import CitationsModal from './CitationsModal.svelte';
export let citations = [];
let showCitationModal = false;
let selectedCitation = null;
</script>
<CitationsModal bind:show={showCitationModal} citation={selectedCitation} />
<div class="mt-1 mb-2 w-full flex gap-1 items-center flex-wrap">
{#each citations.reduce((acc, citation) => {
citation.document.forEach((document, index) => {
const metadata = citation.metadata?.[index];
const id = metadata?.source ?? 'N/A';
let source = citation?.source;
if (metadata?.name) {
source = { ...source, name: metadata.name };
}
// Check if ID looks like a URL
if (id.startsWith('http://') || id.startsWith('https://')) {
source = { name: id };
}
const existingSource = acc.find((item) => item.id === id);
if (existingSource) {
existingSource.document.push(document);
existingSource.metadata.push(metadata);
} else {
acc.push( { id: id, source: source, document: [document], metadata: metadata ? [metadata] : [] } );
}
});
return acc;
}, []) as citation, idx}
<div class="flex gap-1 text-xs font-semibold">
<button
class="flex dark:text-gray-300 py-1 px-1 bg-gray-50 hover:bg-gray-100 dark:bg-gray-850 dark:hover:bg-gray-800 transition rounded-xl"
on:click={() => {
showCitationModal = true;
selectedCitation = citation;
}}
>
<div class="bg-white dark:bg-gray-700 rounded-full size-4">
{idx + 1}
</div>
<div class="flex-1 mx-2 line-clamp-1">
{citation.source.name}
</div>
</button>
</div>
{/each}
</div>

View File

@ -0,0 +1,26 @@
<script lang="ts">
export let content = '';
</script>
<div
class="flex mt-2 mb-4 space-x-2 border px-4 py-3 border-red-800 bg-red-800/30 font-medium rounded-lg"
>
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width="1.5"
stroke="currentColor"
class="w-5 h-5 self-center"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M12 9v3.75m9-.75a9 9 0 11-18 0 9 9 0 0118 0zm-9 3.75h.008v.008H12v-.008z"
/>
</svg>
<div class=" self-center">
{content}
</div>
</div>

View File

@ -0,0 +1,32 @@
<script>
import { marked } from 'marked';
import markedKatex from '$lib/utils/marked/katex-extension';
import { replaceTokens, processResponseContent } from '$lib/utils';
import { user } from '$lib/stores';
import MarkdownTokens from './Markdown/MarkdownTokens.svelte';
export let id;
export let content;
export let model = null;
let tokens = [];
const options = {
throwOnError: false
};
marked.use(markedKatex(options));
$: (async () => {
if (content) {
tokens = marked.lexer(
replaceTokens(processResponseContent(content), model?.name, $user?.name)
);
}
})();
</script>
{#key id}
<MarkdownTokens {tokens} {id} />
{/key}

View File

@ -1,11 +1,21 @@
<script lang="ts">
import dayjs from 'dayjs';
import { onMount, tick, getContext } from 'svelte';
import { createEventDispatcher } from 'svelte';
import { mobile, settings } from '$lib/stores';
import { generateMoACompletion } from '$lib/apis';
import { updateChatById } from '$lib/apis/chats';
import { createOpenAITextStream } from '$lib/apis/streaming';
import ResponseMessage from './ResponseMessage.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte';
import Merge from '$lib/components/icons/Merge.svelte';
import { mobile } from '$lib/stores';
import Markdown from './Markdown.svelte';
import Name from './Name.svelte';
import Skeleton from './Skeleton.svelte';
const i18n = getContext('i18n');
@ -26,6 +36,7 @@
export let copyToClipboard: Function;
export let continueGeneration: Function;
export let mergeResponses: Function;
export let regenerateResponse: Function;
const dispatch = createEventDispatcher();
@ -106,6 +117,14 @@
}, {});
};
const mergeResponsesHandler = async () => {
const responses = Object.keys(groupedMessages).map((modelIdx) => {
const { messages } = groupedMessages[modelIdx];
return messages[groupedMessagesIdx[modelIdx]].content;
});
mergeResponses(currentMessageId, responses);
};
onMount(async () => {
initHandler();
});
@ -185,22 +204,55 @@
</div>
{#if !readOnly && isLastMessage}
{#if !parentMessage?.childrenIds.map((id) => history.messages[id]).find((m) => !m.done)}
<div class=" flex justify-end overflow-x-auto buttons text-gray-600 dark:text-gray-500 mt-1">
<Tooltip content={$i18n.t('Merge Responses')} placement="bottom">
<button
type="button"
id="merge-response-button"
class="{true
? 'visible'
: 'invisible group-hover:visible'} p-1 hover:bg-black/5 dark:hover:bg-white/5 rounded-lg dark:hover:text-white hover:text-black transition regenerate-response-button"
on:click={() => {
// continueGeneration();
}}
>
<Merge className=" size-5 " />
</button>
</Tooltip>
{#if !Object.keys(groupedMessages).find((modelIdx) => {
const { messages } = groupedMessages[modelIdx];
return !messages[groupedMessagesIdx[modelIdx]].done;
})}
<div class="flex justify-end">
<div class="w-full">
{#if history.messages[currentMessageId]?.merged?.status}
{@const message = history.messages[currentMessageId]?.merged}
<div class="w-full rounded-xl pl-5 pr-2 py-2">
<Name>
Merged Response
{#if message.timestamp}
<span
class=" self-center invisible group-hover:visible text-gray-400 text-xs font-medium uppercase ml-0.5 -mt-0.5"
>
{dayjs(message.timestamp * 1000).format($i18n.t('h:mm a'))}
</span>
{/if}
</Name>
<div class="mt-1 markdown-prose w-full min-w-full">
{#if (message?.content ?? '') === ''}
<Skeleton />
{:else}
<Markdown id={`merged`} content={message.content ?? ''} />
{/if}
</div>
</div>
{/if}
</div>
<div class=" flex-shrink-0 text-gray-600 dark:text-gray-500 mt-1">
<Tooltip content={$i18n.t('Merge Responses')} placement="bottom">
<button
type="button"
id="merge-response-button"
class="{true
? 'visible'
: 'invisible group-hover:visible'} p-1 hover:bg-black/5 dark:hover:bg-white/5 rounded-lg dark:hover:text-white hover:text-black transition regenerate-response-button"
on:click={() => {
mergeResponsesHandler();
}}
>
<Merge className=" size-5 " />
</button>
</Tooltip>
</div>
</div>
{/if}
{/if}

View File

@ -1,7 +1,6 @@
<script lang="ts">
import { toast } from 'svelte-sonner';
import dayjs from 'dayjs';
import { marked } from 'marked';
import { fade } from 'svelte/transition';
import { createEventDispatcher } from 'svelte';
@ -33,7 +32,9 @@
import Spinner from '$lib/components/common/Spinner.svelte';
import WebSearchResults from './ResponseMessage/WebSearchResults.svelte';
import Sparkles from '$lib/components/icons/Sparkles.svelte';
import MarkdownTokens from './Markdown/MarkdownTokens.svelte';
import Markdown from './Markdown.svelte';
import Error from './Error.svelte';
import Citations from './Citations.svelte';
export let message;
export let siblings;
@ -58,7 +59,6 @@
let edit = false;
let editedContent = '';
let editTextAreaElement: HTMLTextAreaElement;
let tooltipInstance = null;
let sentencesAudio = {};
let speaking = null;
@ -68,28 +68,6 @@
let generatingImage = false;
let showRateComment = false;
let showCitationModal = false;
let selectedCitation = null;
let tokens;
import 'katex/dist/katex.min.css';
import markedKatex from '$lib/utils/marked/katex-extension';
const options = {
throwOnError: false
};
marked.use(markedKatex(options));
$: (async () => {
if (message?.content) {
tokens = marked.lexer(
replaceTokens(processResponseContent(message?.content), model?.name, $user?.name)
);
}
})();
const playAudio = (idx) => {
return new Promise((res) => {
@ -282,8 +260,6 @@
});
</script>
<CitationsModal bind:show={showCitationModal} citation={selectedCitation} />
{#key message.id}
<div
class=" flex w-full message-{message.id}"
@ -321,9 +297,7 @@
</div>
{/if}
<div
class="prose chat-{message.role} w-full max-w-full dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line"
>
<div class="chat-{message.role} w-full min-w-full markdown-prose">
<div>
{#if (message?.statusHistory ?? [...(message?.status ? [message?.status] : [])]).length > 0}
{@const status = (
@ -408,82 +382,15 @@
{:else if message.content && message.error !== true}
<!-- always show message contents even if there's an error -->
<!-- unless message.error === true which is legacy error handling, where the error message is stored in message.content -->
{#key message.id}
<MarkdownTokens id={message.id} {tokens} />
{/key}
<Markdown id={message.id} content={message.content} {model} />
{/if}
{#if message.error}
<div
class="flex mt-2 mb-4 space-x-2 border px-4 py-3 border-red-800 bg-red-800/30 font-medium rounded-lg"
>
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width="1.5"
stroke="currentColor"
class="w-5 h-5 self-center"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M12 9v3.75m9-.75a9 9 0 11-18 0 9 9 0 0118 0zm-9 3.75h.008v.008H12v-.008z"
/>
</svg>
<div class=" self-center">
{message?.error?.content ?? message.content}
</div>
</div>
<Error content={message?.error?.content ?? message.content} />
{/if}
{#if message.citations}
<div class="mt-1 mb-2 w-full flex gap-1 items-center flex-wrap">
{#each message.citations.reduce((acc, citation) => {
citation.document.forEach((document, index) => {
const metadata = citation.metadata?.[index];
const id = metadata?.source ?? 'N/A';
let source = citation?.source;
if (metadata?.name) {
source = { ...source, name: metadata.name };
}
// Check if ID looks like a URL
if (id.startsWith('http://') || id.startsWith('https://')) {
source = { name: id };
}
const existingSource = acc.find((item) => item.id === id);
if (existingSource) {
existingSource.document.push(document);
existingSource.metadata.push(metadata);
} else {
acc.push( { id: id, source: source, document: [document], metadata: metadata ? [metadata] : [] } );
}
});
return acc;
}, []) as citation, idx}
<div class="flex gap-1 text-xs font-semibold">
<button
class="flex dark:text-gray-300 py-1 px-1 bg-gray-50 hover:bg-gray-100 dark:bg-gray-850 dark:hover:bg-gray-800 transition rounded-xl"
on:click={() => {
showCitationModal = true;
selectedCitation = citation;
}}
>
<div class="bg-white dark:bg-gray-700 rounded-full size-4">
{idx + 1}
</div>
<div class="flex-1 mx-2 line-clamp-1">
{citation.source.name}
</div>
</button>
</div>
{/each}
</div>
<Citations citations={message.citations} />
{/if}
</div>
{/if}

View File

@ -13,6 +13,7 @@
import { marked } from 'marked';
import { processResponseContent, replaceTokens } from '$lib/utils';
import MarkdownTokens from './Markdown/MarkdownTokens.svelte';
import Markdown from './Markdown.svelte';
const i18n = getContext('i18n');
@ -93,9 +94,7 @@
</div>
{/if}
<div
class="prose chat-{message.role} w-full max-w-full dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line"
>
<div class="chat-{message.role} w-full min-w-full markdown-prose">
{#if message.files}
<div class="mt-2.5 mb-1 w-full flex flex-col justify-end overflow-x-auto gap-1 flex-wrap">
{#each message.files as file}
@ -174,14 +173,7 @@
: ' w-full'}"
>
{#if message.content}
<div class="">
{#key message.id}
<MarkdownTokens
id={message.id}
tokens={marked.lexer(processResponseContent(message?.content))}
/>
{/key}
</div>
<Markdown id={message.id} content={message.content} />
{/if}
</div>
</div>