From 7c81509804e280cddcd0188d21e9ffb0caa8c242 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 18 Aug 2024 20:59:59 +0200 Subject: [PATCH] feat: merge responses --- backend/constants.py | 1 + backend/main.py | 53 +++++++++ backend/utils/task.py | 37 ++++++ src/app.css | 4 + src/lib/apis/index.ts | 36 ++++++ src/lib/components/chat/Chat.svelte | 53 ++++++++- src/lib/components/chat/Messages.svelte | 2 + .../components/chat/Messages/Citations.svelte | 56 +++++++++ src/lib/components/chat/Messages/Error.svelte | 26 +++++ .../components/chat/Messages/Markdown.svelte | 32 ++++++ .../Messages/MultiResponseMessages.svelte | 86 +++++++++++--- .../chat/Messages/ResponseMessage.svelte | 107 ++---------------- .../chat/Messages/UserMessage.svelte | 14 +-- 13 files changed, 378 insertions(+), 129 deletions(-) create mode 100644 src/lib/components/chat/Messages/Citations.svelte create mode 100644 src/lib/components/chat/Messages/Error.svelte create mode 100644 src/lib/components/chat/Messages/Markdown.svelte diff --git a/backend/constants.py b/backend/constants.py index b9c7fc430..d55216bb5 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -100,3 +100,4 @@ class TASKS(str, Enum): EMOJI_GENERATION = "emoji_generation" QUERY_GENERATION = "query_generation" FUNCTION_CALLING = "function_calling" + MOA_RESPONSE_GENERATION = "moa_response_generation" diff --git a/backend/main.py b/backend/main.py index 6c75164dc..d539834ed 100644 --- a/backend/main.py +++ b/backend/main.py @@ -73,6 +73,7 @@ from utils.task import ( title_generation_template, search_query_generation_template, tools_function_calling_generation_template, + moa_response_generation_template, ) from utils.misc import ( get_last_user_message, @@ -1570,6 +1571,58 @@ Message: """{{prompt}}""" return await generate_chat_completions(form_data=payload, user=user) +@app.post("/api/task/moa/completions") +async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)): + print("generate_moa_response") + + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + model_id = get_task_model_id(model_id) + print(model_id) + + template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}" + +Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability. + +Responses from models: {{responses}}""" + + content = moa_response_generation_template( + template, + form_data["prompt"], + form_data["responses"], + ) + + payload = { + "model": model_id, + "messages": [{"role": "user", "content": content}], + "stream": form_data.get("stream", False), + "chat_id": form_data.get("chat_id", None), + "metadata": {"task": str(TASKS.MOA_RESPONSE_GENERATION)}, + } + + log.debug(payload) + + try: + payload = filter_pipeline(payload, user) + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + + if "chat_id" in payload: + del payload["chat_id"] + + return await generate_chat_completions(form_data=payload, user=user) + + ################################## # # Pipelines Endpoints diff --git a/backend/utils/task.py b/backend/utils/task.py index 1b2276c9c..ea9254c4f 100644 --- a/backend/utils/task.py +++ b/backend/utils/task.py @@ -121,6 +121,43 @@ def search_query_generation_template( return template +def moa_response_generation_template( + template: str, prompt: str, responses: list[str] +) -> str: + def replacement_function(match): + full_match = match.group(0) + start_length = match.group(1) + end_length = match.group(2) + middle_length = match.group(3) + + if full_match == "{{prompt}}": + return prompt + elif start_length is not None: + return prompt[: int(start_length)] + elif end_length is not None: + return prompt[-int(end_length) :] + elif middle_length is not None: + middle_length = int(middle_length) + if len(prompt) <= middle_length: + return prompt + start = prompt[: math.ceil(middle_length / 2)] + end = prompt[-math.floor(middle_length / 2) :] + return f"{start}...{end}" + return "" + + template = re.sub( + r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}", + replacement_function, + template, + ) + + responses = [f'"""{response}"""' for response in responses] + responses = "\n\n".join(responses) + + template = template.replace("{{responses}}", responses) + return template + + def tools_function_calling_generation_template(template: str, tools_specs: str) -> str: template = template.replace("{{TOOLS}}", tools_specs) return template diff --git a/src/app.css b/src/app.css index 4345bb377..a421d90ae 100644 --- a/src/app.css +++ b/src/app.css @@ -34,6 +34,10 @@ math { @apply rounded-lg; } +.markdown-prose { + @apply prose dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; +} + .markdown a { @apply underline; } diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index c4778cadb..fc01c209d 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -333,6 +333,42 @@ export const generateSearchQuery = async ( return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? prompt; }; +export const generateMoACompletion = async ( + token: string = '', + model: string, + prompt: string, + responses: string[] +) => { + const controller = new AbortController(); + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/task/moa/completions`, { + signal: controller.signal, + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + model: model, + prompt: prompt, + responses: responses, + stream: true + }) + }).catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return [res, controller]; +}; + export const getPipelinesList = async (token: string = '') => { let error = null; diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 3da6eab03..2703d6578 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -54,7 +54,13 @@ import { createOpenAITextStream } from '$lib/apis/streaming'; import { queryMemory } from '$lib/apis/memories'; import { getAndUpdateUserLocation, getUserSettings } from '$lib/apis/users'; - import { chatCompleted, generateTitle, generateSearchQuery, chatAction } from '$lib/apis'; + import { + chatCompleted, + generateTitle, + generateSearchQuery, + chatAction, + generateMoACompletion + } from '$lib/apis'; import Banner from '../common/Banner.svelte'; import MessageInput from '$lib/components/chat/MessageInput.svelte'; @@ -1511,6 +1517,50 @@ return []; }); }; + + const mergeResponses = async (messageId, responses) => { + console.log('mergeResponses', messageId, responses); + const message = history.messages[messageId]; + const mergedResponse = { + status: true, + content: '' + }; + + message.merged = mergedResponse; + try { + const [res, controller] = await generateMoACompletion( + localStorage.token, + message.model, + history.messages[message.parentId].content, + responses + ); + + if (res && res.ok && res.body) { + const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks); + for await (const update of textStream) { + const { value, done, citations, error, usage } = update; + if (error || done) { + break; + } + + if (mergedResponse.content == '' && value == '\n') { + continue; + } else { + mergedResponse.content += value; + messages = messages; + } + + if (autoScroll) { + scrollToBottom(); + } + } + } else { + console.error(res); + } + } catch (e) { + console.error(e); + } + }; @@ -1637,6 +1687,7 @@ {sendPrompt} {continueGeneration} {regenerateResponse} + {mergeResponses} {chatActionHandler} /> diff --git a/src/lib/components/chat/Messages.svelte b/src/lib/components/chat/Messages.svelte index e1e30059e..512014cb4 100644 --- a/src/lib/components/chat/Messages.svelte +++ b/src/lib/components/chat/Messages.svelte @@ -19,6 +19,7 @@ export let sendPrompt: Function; export let continueGeneration: Function; export let regenerateResponse: Function; + export let mergeResponses: Function; export let chatActionHandler: Function; export let user = $_user; @@ -374,6 +375,7 @@ {rateMessage} copyToClipboard={copyToClipboardWithToast} {continueGeneration} + {mergeResponses} {regenerateResponse} on:change={async () => { await updateChatById(localStorage.token, chatId, { diff --git a/src/lib/components/chat/Messages/Citations.svelte b/src/lib/components/chat/Messages/Citations.svelte new file mode 100644 index 000000000..8112d37f4 --- /dev/null +++ b/src/lib/components/chat/Messages/Citations.svelte @@ -0,0 +1,56 @@ + + + + +
+ {#each citations.reduce((acc, citation) => { + citation.document.forEach((document, index) => { + const metadata = citation.metadata?.[index]; + const id = metadata?.source ?? 'N/A'; + let source = citation?.source; + + if (metadata?.name) { + source = { ...source, name: metadata.name }; + } + + // Check if ID looks like a URL + if (id.startsWith('http://') || id.startsWith('https://')) { + source = { name: id }; + } + + const existingSource = acc.find((item) => item.id === id); + + if (existingSource) { + existingSource.document.push(document); + existingSource.metadata.push(metadata); + } else { + acc.push( { id: id, source: source, document: [document], metadata: metadata ? [metadata] : [] } ); + } + }); + return acc; + }, []) as citation, idx} +
+ +
+ {/each} +
diff --git a/src/lib/components/chat/Messages/Error.svelte b/src/lib/components/chat/Messages/Error.svelte new file mode 100644 index 000000000..a1fed2f42 --- /dev/null +++ b/src/lib/components/chat/Messages/Error.svelte @@ -0,0 +1,26 @@ + + +
+ + + + +
+ {content} +
+
diff --git a/src/lib/components/chat/Messages/Markdown.svelte b/src/lib/components/chat/Messages/Markdown.svelte new file mode 100644 index 000000000..2c2f74d76 --- /dev/null +++ b/src/lib/components/chat/Messages/Markdown.svelte @@ -0,0 +1,32 @@ + + +{#key id} + +{/key} diff --git a/src/lib/components/chat/Messages/MultiResponseMessages.svelte b/src/lib/components/chat/Messages/MultiResponseMessages.svelte index d0be97d82..7a760f07a 100644 --- a/src/lib/components/chat/Messages/MultiResponseMessages.svelte +++ b/src/lib/components/chat/Messages/MultiResponseMessages.svelte @@ -1,11 +1,21 @@ - - {#key message.id}
{/if} -
+
{#if (message?.statusHistory ?? [...(message?.status ? [message?.status] : [])]).length > 0} {@const status = ( @@ -408,82 +382,15 @@ {:else if message.content && message.error !== true} - {#key message.id} - - {/key} + {/if} {#if message.error} -
- - - - -
- {message?.error?.content ?? message.content} -
-
+ {/if} {#if message.citations} -
- {#each message.citations.reduce((acc, citation) => { - citation.document.forEach((document, index) => { - const metadata = citation.metadata?.[index]; - const id = metadata?.source ?? 'N/A'; - let source = citation?.source; - - if (metadata?.name) { - source = { ...source, name: metadata.name }; - } - - // Check if ID looks like a URL - if (id.startsWith('http://') || id.startsWith('https://')) { - source = { name: id }; - } - - const existingSource = acc.find((item) => item.id === id); - - if (existingSource) { - existingSource.document.push(document); - existingSource.metadata.push(metadata); - } else { - acc.push( { id: id, source: source, document: [document], metadata: metadata ? [metadata] : [] } ); - } - }); - return acc; - }, []) as citation, idx} -
- -
- {/each} -
+ {/if}
{/if} diff --git a/src/lib/components/chat/Messages/UserMessage.svelte b/src/lib/components/chat/Messages/UserMessage.svelte index 11d14523f..67f682520 100644 --- a/src/lib/components/chat/Messages/UserMessage.svelte +++ b/src/lib/components/chat/Messages/UserMessage.svelte @@ -13,6 +13,7 @@ import { marked } from 'marked'; import { processResponseContent, replaceTokens } from '$lib/utils'; import MarkdownTokens from './Markdown/MarkdownTokens.svelte'; + import Markdown from './Markdown.svelte'; const i18n = getContext('i18n'); @@ -93,9 +94,7 @@
{/if} -
+
{#if message.files}
{#each message.files as file} @@ -174,14 +173,7 @@ : ' w-full'}" > {#if message.content} -
- {#key message.id} - - {/key} -
+ {/if}