feat: merge responses

This commit is contained in:
Timothy J. Baek 2024-08-18 20:59:59 +02:00
parent 65923006a8
commit 7c81509804
13 changed files with 378 additions and 129 deletions

View File

@ -100,3 +100,4 @@ class TASKS(str, Enum):
EMOJI_GENERATION = "emoji_generation" EMOJI_GENERATION = "emoji_generation"
QUERY_GENERATION = "query_generation" QUERY_GENERATION = "query_generation"
FUNCTION_CALLING = "function_calling" FUNCTION_CALLING = "function_calling"
MOA_RESPONSE_GENERATION = "moa_response_generation"

View File

@ -73,6 +73,7 @@ from utils.task import (
title_generation_template, title_generation_template,
search_query_generation_template, search_query_generation_template,
tools_function_calling_generation_template, tools_function_calling_generation_template,
moa_response_generation_template,
) )
from utils.misc import ( from utils.misc import (
get_last_user_message, get_last_user_message,
@ -1570,6 +1571,58 @@ Message: """{{prompt}}"""
return await generate_chat_completions(form_data=payload, user=user) return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/moa/completions")
async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)):
    """Synthesize several model responses into one answer (MoA task endpoint).

    Reads ``model``, ``prompt`` and ``responses`` from ``form_data`` (a missing
    key raises ``KeyError``), plus the optional ``stream`` and ``chat_id``
    keys. The synthesis prompt is rendered with
    ``moa_response_generation_template`` and sent as a single user message
    through ``generate_chat_completions``.

    Raises:
        HTTPException: 404 when ``form_data["model"]`` is not in
            ``app.state.MODELS``.
    """
    # Use the module logger instead of print() so the output follows the
    # configured log level, like the payload dump further down.
    log.debug("generate_moa_response")

    model_id = form_data["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    model_id = get_task_model_id(model_id)
    log.debug(model_id)

    template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}"
Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
Responses from models: {{responses}}"""

    content = moa_response_generation_template(
        template,
        form_data["prompt"],
        form_data["responses"],
    )

    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": form_data.get("stream", False),
        "chat_id": form_data.get("chat_id", None),
        "metadata": {"task": str(TASKS.MOA_RESPONSE_GENERATION)},
    }
    log.debug(payload)

    try:
        payload = filter_pipeline(payload, user)
    except Exception as e:
        # NOTE(review): assumes filter_pipeline raises with
        # (status_code, detail) as its two positional args - confirm.
        return JSONResponse(
            status_code=e.args[0],
            content={"detail": e.args[1]},
        )

    # chat_id is removed before the payload is handed to the completion call.
    payload.pop("chat_id", None)

    return await generate_chat_completions(form_data=payload, user=user)
################################## ##################################
# #
# Pipelines Endpoints # Pipelines Endpoints

View File

@ -121,6 +121,43 @@ def search_query_generation_template(
return template return template
def moa_response_generation_template(
    template: str, prompt: str, responses: list[str]
) -> str:
    """Fill the mixture-of-agents prompt template.

    Supported placeholders:

    * ``{{prompt}}`` - the full prompt.
    * ``{{prompt:start:N}}`` - the first ``N`` characters of the prompt.
    * ``{{prompt:end:N}}`` - the last ``N`` characters of the prompt.
    * ``{{prompt:middletruncate:N}}`` - the prompt unchanged when it has at
      most ``N`` characters, otherwise its first ``ceil(N / 2)`` and last
      ``floor(N / 2)`` characters joined by ``...``.
    * ``{{responses}}`` - every response wrapped in triple double quotes,
      separated by blank lines.

    All placeholders are resolved in a single pass over ``template``, so
    placeholder-looking text inside ``prompt`` or ``responses`` is inserted
    verbatim and never expanded again.

    Args:
        template: Template text containing the placeholders above.
        prompt: The user prompt to insert.
        responses: Model responses to quote and insert.

    Returns:
        The rendered template.
    """
    import math
    import re

    quoted_responses = "\n\n".join(f'"""{response}"""' for response in responses)

    def replacement_function(match):
        full_match = match.group(0)
        start_length = match.group(1)
        end_length = match.group(2)
        middle_length = match.group(3)

        if full_match == "{{responses}}":
            return quoted_responses
        if full_match == "{{prompt}}":
            return prompt
        if start_length is not None:
            return prompt[: int(start_length)]
        if end_length is not None:
            end_length = int(end_length)
            # prompt[-0:] is the whole string, so a zero length needs its own case.
            return prompt[-end_length:] if end_length > 0 else ""
        if middle_length is not None:
            middle_length = int(middle_length)
            if len(prompt) <= middle_length:
                return prompt
            head_length = math.ceil(middle_length / 2)
            tail_length = middle_length // 2
            start = prompt[:head_length]
            # Same -0 pitfall as above when the tail length rounds down to zero.
            end = prompt[-tail_length:] if tail_length > 0 else ""
            return f"{start}...{end}"
        return ""

    return re.sub(
        r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}|{{responses}}",
        replacement_function,
        template,
    )
def tools_function_calling_generation_template(template: str, tools_specs: str) -> str: def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
template = template.replace("{{TOOLS}}", tools_specs) template = template.replace("{{TOOLS}}", tools_specs)
return template return template

View File

@ -34,6 +34,10 @@ math {
@apply rounded-lg; @apply rounded-lg;
} }
/*
 * Shared typography for rendered markdown content. Applies the Tailwind
 * `prose` styles (inverted in dark mode), sets the vertical margins of
 * paragraphs, headings, images, code blocks, tables, blockquotes and lists
 * to the compact values below, and keeps author line breaks via
 * `whitespace-pre-line`.
 */
.markdown-prose {
@apply prose dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line;
}
.markdown a { .markdown a {
@apply underline; @apply underline;
} }

View File

@ -333,6 +333,42 @@ export const generateSearchQuery = async (
return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? prompt; return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? prompt;
}; };
/**
 * Starts a streaming "merge responses" (mixture-of-agents) completion.
 *
 * POSTs the model id, the original prompt and the candidate responses to
 * `/api/task/moa/completions` with `stream: true`.
 *
 * @param token - Bearer token sent in the `Authorization` header.
 * @param model - Id of the model that should produce the merged answer.
 * @param prompt - The user prompt the responses were generated for.
 * @param responses - The candidate responses to merge.
 * @returns A tuple of the raw fetch `Response` (not checked for `ok`; the
 *   caller inspects it and reads the stream) and the `AbortController` whose
 *   signal is attached to the request.
 * @throws Whatever `fetch` rejects with, after logging it.
 */
export const generateMoACompletion = async (
  token: string = '',
  model: string,
  prompt: string,
  responses: string[]
): Promise<[Response, AbortController]> => {
  const controller = new AbortController();

  try {
    const res = await fetch(`${WEBUI_BASE_URL}/api/task/moa/completions`, {
      signal: controller.signal,
      method: 'POST',
      headers: {
        Accept: 'application/json',
        'Content-Type': 'application/json',
        Authorization: `Bearer ${token}`
      },
      body: JSON.stringify({
        model: model,
        prompt: prompt,
        responses: responses,
        stream: true
      })
    });

    // Explicit tuple return type so callers can destructure `res` and
    // `controller` with their precise types instead of a union array.
    return [res, controller];
  } catch (err) {
    console.log(err);
    throw err;
  }
};
export const getPipelinesList = async (token: string = '') => { export const getPipelinesList = async (token: string = '') => {
let error = null; let error = null;

View File

@ -54,7 +54,13 @@
import { createOpenAITextStream } from '$lib/apis/streaming'; import { createOpenAITextStream } from '$lib/apis/streaming';
import { queryMemory } from '$lib/apis/memories'; import { queryMemory } from '$lib/apis/memories';
import { getAndUpdateUserLocation, getUserSettings } from '$lib/apis/users'; import { getAndUpdateUserLocation, getUserSettings } from '$lib/apis/users';
import { chatCompleted, generateTitle, generateSearchQuery, chatAction } from '$lib/apis'; import {
chatCompleted,
generateTitle,
generateSearchQuery,
chatAction,
generateMoACompletion
} from '$lib/apis';
import Banner from '../common/Banner.svelte'; import Banner from '../common/Banner.svelte';
import MessageInput from '$lib/components/chat/MessageInput.svelte'; import MessageInput from '$lib/components/chat/MessageInput.svelte';
@ -1511,6 +1517,50 @@
return []; return [];
}); });
}; };
/**
 * Streams a merged answer for the responses shown under `messageId` and
 * stores it on that message as `message.merged`.
 *
 * A placeholder `{ status: true, content: '' }` is attached first; streamed
 * text is appended to `content`, and `messages` is reassigned after each
 * chunk to trigger a re-render. Failures are logged and never thrown. If the
 * request fails or the stream ends without any text, `status` is set to
 * `false` so the UI does not keep showing an empty placeholder.
 *
 * @param messageId - Id of the response message the merge is attached to.
 * @param responses - Contents of the responses to merge.
 */
const mergeResponses = async (messageId, responses) => {
  const message = history.messages[messageId];
  const mergedResponse = {
    status: true,
    content: ''
  };
  message.merged = mergedResponse;

  try {
    const [res] = await generateMoACompletion(
      localStorage.token,
      message.model,
      history.messages[message.parentId].content,
      responses
    );

    if (res && res.ok && res.body) {
      const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks);
      for await (const update of textStream) {
        const { value, done, error } = update;
        if (error || done) {
          break;
        }

        // Skip a bare newline that arrives before any real content.
        if (mergedResponse.content === '' && value === '\n') {
          continue;
        }

        mergedResponse.content += value;
        // Self-assignment triggers Svelte reactivity for the mutated message.
        messages = messages;

        if (autoScroll) {
          scrollToBottom();
        }
      }
    } else {
      console.error(res);
    }
  } catch (e) {
    console.error(e);
  }

  // Nothing was produced (request failed, stream errored immediately, or the
  // model returned no text): retire the placeholder instead of leaving it
  // pending forever.
  if (mergedResponse.content === '') {
    mergedResponse.status = false;
    messages = messages;
  }
};
</script> </script>
<svelte:head> <svelte:head>
@ -1637,6 +1687,7 @@
{sendPrompt} {sendPrompt}
{continueGeneration} {continueGeneration}
{regenerateResponse} {regenerateResponse}
{mergeResponses}
{chatActionHandler} {chatActionHandler}
/> />
</div> </div>

View File

@ -19,6 +19,7 @@
export let sendPrompt: Function; export let sendPrompt: Function;
export let continueGeneration: Function; export let continueGeneration: Function;
export let regenerateResponse: Function; export let regenerateResponse: Function;
export let mergeResponses: Function;
export let chatActionHandler: Function; export let chatActionHandler: Function;
export let user = $_user; export let user = $_user;
@ -374,6 +375,7 @@
{rateMessage} {rateMessage}
copyToClipboard={copyToClipboardWithToast} copyToClipboard={copyToClipboardWithToast}
{continueGeneration} {continueGeneration}
{mergeResponses}
{regenerateResponse} {regenerateResponse}
on:change={async () => { on:change={async () => {
await updateChatById(localStorage.token, chatId, { await updateChatById(localStorage.token, chatId, {

View File

@ -0,0 +1,56 @@
<!--
	Citations: renders one button per distinct citation source. Clicking a
	button opens CitationsModal with the grouped citation for that source.
-->
<script lang="ts">
import CitationsModal from './CitationsModal.svelte';
// Raw citation entries; each entry's `document` array is iterated below and
// `metadata` / `source` are read when present.
export let citations = [];
// Modal visibility and the grouped citation it currently displays.
let showCitationModal = false;
let selectedCitation = null;
</script>
<CitationsModal bind:show={showCitationModal} citation={selectedCitation} />
<div class="mt-1 mb-2 w-full flex gap-1 items-center flex-wrap">
{#each citations.reduce((acc, citation) => {
// Flatten every document of every citation and group them by source id.
citation.document.forEach((document, index) => {
const metadata = citation.metadata?.[index];
// Grouping key: the document's metadata source, or 'N/A' when absent.
const id = metadata?.source ?? 'N/A';
let source = citation?.source;
// A per-document name in the metadata overrides the citation-level name.
if (metadata?.name) {
source = { ...source, name: metadata.name };
}
// Check if ID looks like a URL
if (id.startsWith('http://') || id.startsWith('https://')) {
source = { name: id };
}
const existingSource = acc.find((item) => item.id === id);
if (existingSource) {
existingSource.document.push(document);
// NOTE(review): pushes `metadata` even when it is undefined, while the
// initial entry below omits missing metadata - the two arrays can end up
// with different lengths.
existingSource.metadata.push(metadata);
} else {
acc.push( { id: id, source: source, document: [document], metadata: metadata ? [metadata] : [] } );
}
});
return acc;
}, []) as citation, idx}
<div class="flex gap-1 text-xs font-semibold">
<button
class="flex dark:text-gray-300 py-1 px-1 bg-gray-50 hover:bg-gray-100 dark:bg-gray-850 dark:hover:bg-gray-800 transition rounded-xl"
on:click={() => {
showCitationModal = true;
selectedCitation = citation;
}}
>
<div class="bg-white dark:bg-gray-700 rounded-full size-4">
{idx + 1}
</div>
<div class="flex-1 mx-2 line-clamp-1">
<!-- NOTE(review): assumes every grouped entry has a `source`; it is undefined when the citation has no `source`, no metadata name and a non-URL id. -->
{citation.source.name}
</div>
</button>
</div>
{/each}
</div>

View File

@ -0,0 +1,26 @@
<!--
	Error: red bordered alert box that shows `content` as text next to an
	inline SVG icon.
-->
<script lang="ts">
// Text to display inside the alert; rendered through a plain `{content}` interpolation.
export let content = '';
</script>
<div
class="flex mt-2 mb-4 space-x-2 border px-4 py-3 border-red-800 bg-red-800/30 font-medium rounded-lg"
>
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width="1.5"
stroke="currentColor"
class="w-5 h-5 self-center"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M12 9v3.75m9-.75a9 9 0 11-18 0 9 9 0 0118 0zm-9 3.75h.008v.008H12v-.008z"
/>
</svg>
<div class=" self-center">
{content}
</div>
</div>

View File

@ -0,0 +1,32 @@
<!--
	Markdown: lexes `content` with `marked` (KaTeX extension enabled) and
	renders the resulting tokens through MarkdownTokens.
-->
<script>
import { marked } from 'marked';
import markedKatex from '$lib/utils/marked/katex-extension';
import { replaceTokens, processResponseContent } from '$lib/utils';
import { user } from '$lib/stores';
import MarkdownTokens from './Markdown/MarkdownTokens.svelte';
// `id` keys the rendered output below and is forwarded to MarkdownTokens.
export let id;
// Markdown source text to render.
export let content;
// Optional model object; only `model?.name` is read, for `replaceTokens`.
export let model = null;
// Lexer output handed to MarkdownTokens.
let tokens = [];
const options = {
throwOnError: false
};
// NOTE(review): this registers the KaTeX extension on the imported `marked`
// object from the instance script, so it runs for every component instance -
// confirm repeated registration is harmless.
marked.use(markedKatex(options));
// Re-lex whenever the values read here change. Tokens are only recomputed
// for truthy `content`; when `content` becomes empty the previous tokens
// are kept.
$: (async () => {
if (content) {
tokens = marked.lexer(
replaceTokens(processResponseContent(content), model?.name, $user?.name)
);
}
})();
</script>
<!-- Recreate the token tree whenever `id` changes. -->
{#key id}
<MarkdownTokens {tokens} {id} />
{/key}

View File

@ -1,11 +1,21 @@
<script lang="ts"> <script lang="ts">
import dayjs from 'dayjs';
import { onMount, tick, getContext } from 'svelte'; import { onMount, tick, getContext } from 'svelte';
import { createEventDispatcher } from 'svelte'; import { createEventDispatcher } from 'svelte';
import { mobile, settings } from '$lib/stores';
import { generateMoACompletion } from '$lib/apis';
import { updateChatById } from '$lib/apis/chats'; import { updateChatById } from '$lib/apis/chats';
import { createOpenAITextStream } from '$lib/apis/streaming';
import ResponseMessage from './ResponseMessage.svelte'; import ResponseMessage from './ResponseMessage.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte'; import Tooltip from '$lib/components/common/Tooltip.svelte';
import Merge from '$lib/components/icons/Merge.svelte'; import Merge from '$lib/components/icons/Merge.svelte';
import { mobile } from '$lib/stores';
import Markdown from './Markdown.svelte';
import Name from './Name.svelte';
import Skeleton from './Skeleton.svelte';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -26,6 +36,7 @@
export let copyToClipboard: Function; export let copyToClipboard: Function;
export let continueGeneration: Function; export let continueGeneration: Function;
export let mergeResponses: Function;
export let regenerateResponse: Function; export let regenerateResponse: Function;
const dispatch = createEventDispatcher(); const dispatch = createEventDispatcher();
@ -106,6 +117,14 @@
}, {}); }, {});
}; };
// Gather the currently selected response of every model group and pass
// them, together with the current message id, to the `mergeResponses` prop.
const mergeResponsesHandler = async () => {
  const responses = [];
  for (const modelIdx of Object.keys(groupedMessages)) {
    const group = groupedMessages[modelIdx];
    const selected = group.messages[groupedMessagesIdx[modelIdx]];
    responses.push(selected.content);
  }

  mergeResponses(currentMessageId, responses);
};
onMount(async () => { onMount(async () => {
initHandler(); initHandler();
}); });
@ -185,22 +204,55 @@
</div> </div>
{#if !readOnly && isLastMessage} {#if !readOnly && isLastMessage}
{#if !parentMessage?.childrenIds.map((id) => history.messages[id]).find((m) => !m.done)} {#if !Object.keys(groupedMessages).find((modelIdx) => {
<div class=" flex justify-end overflow-x-auto buttons text-gray-600 dark:text-gray-500 mt-1"> const { messages } = groupedMessages[modelIdx];
<Tooltip content={$i18n.t('Merge Responses')} placement="bottom"> return !messages[groupedMessagesIdx[modelIdx]].done;
<button })}
type="button" <div class="flex justify-end">
id="merge-response-button" <div class="w-full">
class="{true {#if history.messages[currentMessageId]?.merged?.status}
? 'visible' {@const message = history.messages[currentMessageId]?.merged}
: 'invisible group-hover:visible'} p-1 hover:bg-black/5 dark:hover:bg-white/5 rounded-lg dark:hover:text-white hover:text-black transition regenerate-response-button"
on:click={() => { <div class="w-full rounded-xl pl-5 pr-2 py-2">
// continueGeneration(); <Name>
}} Merged Response
>
<Merge className=" size-5 " /> {#if message.timestamp}
</button> <span
</Tooltip> class=" self-center invisible group-hover:visible text-gray-400 text-xs font-medium uppercase ml-0.5 -mt-0.5"
>
{dayjs(message.timestamp * 1000).format($i18n.t('h:mm a'))}
</span>
{/if}
</Name>
<div class="mt-1 markdown-prose w-full min-w-full">
{#if (message?.content ?? '') === ''}
<Skeleton />
{:else}
<Markdown id={`merged`} content={message.content ?? ''} />
{/if}
</div>
</div>
{/if}
</div>
<div class=" flex-shrink-0 text-gray-600 dark:text-gray-500 mt-1">
<Tooltip content={$i18n.t('Merge Responses')} placement="bottom">
<button
type="button"
id="merge-response-button"
class="{true
? 'visible'
: 'invisible group-hover:visible'} p-1 hover:bg-black/5 dark:hover:bg-white/5 rounded-lg dark:hover:text-white hover:text-black transition regenerate-response-button"
on:click={() => {
mergeResponsesHandler();
}}
>
<Merge className=" size-5 " />
</button>
</Tooltip>
</div>
</div> </div>
{/if} {/if}
{/if} {/if}

View File

@ -1,7 +1,6 @@
<script lang="ts"> <script lang="ts">
import { toast } from 'svelte-sonner'; import { toast } from 'svelte-sonner';
import dayjs from 'dayjs'; import dayjs from 'dayjs';
import { marked } from 'marked';
import { fade } from 'svelte/transition'; import { fade } from 'svelte/transition';
import { createEventDispatcher } from 'svelte'; import { createEventDispatcher } from 'svelte';
@ -33,7 +32,9 @@
import Spinner from '$lib/components/common/Spinner.svelte'; import Spinner from '$lib/components/common/Spinner.svelte';
import WebSearchResults from './ResponseMessage/WebSearchResults.svelte'; import WebSearchResults from './ResponseMessage/WebSearchResults.svelte';
import Sparkles from '$lib/components/icons/Sparkles.svelte'; import Sparkles from '$lib/components/icons/Sparkles.svelte';
import MarkdownTokens from './Markdown/MarkdownTokens.svelte'; import Markdown from './Markdown.svelte';
import Error from './Error.svelte';
import Citations from './Citations.svelte';
export let message; export let message;
export let siblings; export let siblings;
@ -58,7 +59,6 @@
let edit = false; let edit = false;
let editedContent = ''; let editedContent = '';
let editTextAreaElement: HTMLTextAreaElement; let editTextAreaElement: HTMLTextAreaElement;
let tooltipInstance = null;
let sentencesAudio = {}; let sentencesAudio = {};
let speaking = null; let speaking = null;
@ -68,28 +68,6 @@
let generatingImage = false; let generatingImage = false;
let showRateComment = false; let showRateComment = false;
let showCitationModal = false;
let selectedCitation = null;
let tokens;
import 'katex/dist/katex.min.css';
import markedKatex from '$lib/utils/marked/katex-extension';
const options = {
throwOnError: false
};
marked.use(markedKatex(options));
$: (async () => {
if (message?.content) {
tokens = marked.lexer(
replaceTokens(processResponseContent(message?.content), model?.name, $user?.name)
);
}
})();
const playAudio = (idx) => { const playAudio = (idx) => {
return new Promise((res) => { return new Promise((res) => {
@ -282,8 +260,6 @@
}); });
</script> </script>
<CitationsModal bind:show={showCitationModal} citation={selectedCitation} />
{#key message.id} {#key message.id}
<div <div
class=" flex w-full message-{message.id}" class=" flex w-full message-{message.id}"
@ -321,9 +297,7 @@
</div> </div>
{/if} {/if}
<div <div class="chat-{message.role} w-full min-w-full markdown-prose">
class="prose chat-{message.role} w-full max-w-full dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line"
>
<div> <div>
{#if (message?.statusHistory ?? [...(message?.status ? [message?.status] : [])]).length > 0} {#if (message?.statusHistory ?? [...(message?.status ? [message?.status] : [])]).length > 0}
{@const status = ( {@const status = (
@ -408,82 +382,15 @@
{:else if message.content && message.error !== true} {:else if message.content && message.error !== true}
<!-- always show message contents even if there's an error --> <!-- always show message contents even if there's an error -->
<!-- unless message.error === true which is legacy error handling, where the error message is stored in message.content --> <!-- unless message.error === true which is legacy error handling, where the error message is stored in message.content -->
{#key message.id} <Markdown id={message.id} content={message.content} {model} />
<MarkdownTokens id={message.id} {tokens} />
{/key}
{/if} {/if}
{#if message.error} {#if message.error}
<div <Error content={message?.error?.content ?? message.content} />
class="flex mt-2 mb-4 space-x-2 border px-4 py-3 border-red-800 bg-red-800/30 font-medium rounded-lg"
>
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width="1.5"
stroke="currentColor"
class="w-5 h-5 self-center"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M12 9v3.75m9-.75a9 9 0 11-18 0 9 9 0 0118 0zm-9 3.75h.008v.008H12v-.008z"
/>
</svg>
<div class=" self-center">
{message?.error?.content ?? message.content}
</div>
</div>
{/if} {/if}
{#if message.citations} {#if message.citations}
<div class="mt-1 mb-2 w-full flex gap-1 items-center flex-wrap"> <Citations citations={message.citations} />
{#each message.citations.reduce((acc, citation) => {
citation.document.forEach((document, index) => {
const metadata = citation.metadata?.[index];
const id = metadata?.source ?? 'N/A';
let source = citation?.source;
if (metadata?.name) {
source = { ...source, name: metadata.name };
}
// Check if ID looks like a URL
if (id.startsWith('http://') || id.startsWith('https://')) {
source = { name: id };
}
const existingSource = acc.find((item) => item.id === id);
if (existingSource) {
existingSource.document.push(document);
existingSource.metadata.push(metadata);
} else {
acc.push( { id: id, source: source, document: [document], metadata: metadata ? [metadata] : [] } );
}
});
return acc;
}, []) as citation, idx}
<div class="flex gap-1 text-xs font-semibold">
<button
class="flex dark:text-gray-300 py-1 px-1 bg-gray-50 hover:bg-gray-100 dark:bg-gray-850 dark:hover:bg-gray-800 transition rounded-xl"
on:click={() => {
showCitationModal = true;
selectedCitation = citation;
}}
>
<div class="bg-white dark:bg-gray-700 rounded-full size-4">
{idx + 1}
</div>
<div class="flex-1 mx-2 line-clamp-1">
{citation.source.name}
</div>
</button>
</div>
{/each}
</div>
{/if} {/if}
</div> </div>
{/if} {/if}

View File

@ -13,6 +13,7 @@
import { marked } from 'marked'; import { marked } from 'marked';
import { processResponseContent, replaceTokens } from '$lib/utils'; import { processResponseContent, replaceTokens } from '$lib/utils';
import MarkdownTokens from './Markdown/MarkdownTokens.svelte'; import MarkdownTokens from './Markdown/MarkdownTokens.svelte';
import Markdown from './Markdown.svelte';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -93,9 +94,7 @@
</div> </div>
{/if} {/if}
<div <div class="chat-{message.role} w-full min-w-full markdown-prose">
class="prose chat-{message.role} w-full max-w-full dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line"
>
{#if message.files} {#if message.files}
<div class="mt-2.5 mb-1 w-full flex flex-col justify-end overflow-x-auto gap-1 flex-wrap"> <div class="mt-2.5 mb-1 w-full flex flex-col justify-end overflow-x-auto gap-1 flex-wrap">
{#each message.files as file} {#each message.files as file}
@ -174,14 +173,7 @@
: ' w-full'}" : ' w-full'}"
> >
{#if message.content} {#if message.content}
<div class=""> <Markdown id={message.id} content={message.content} />
{#key message.id}
<MarkdownTokens
id={message.id}
tokens={marked.lexer(processResponseContent(message?.content))}
/>
{/key}
</div>
{/if} {/if}
</div> </div>
</div> </div>