From 5e7237b9cb66d92646b89d0a6771114c3f09e722 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 9 Jun 2024 14:25:31 -0700 Subject: [PATCH] refac: title generation --- backend/apps/ollama/main.py | 2 - backend/config.py | 21 ++ backend/main.py | 200 +++++++++++------- backend/utils/models.py | 10 - backend/utils/task.py | 70 ++++++ src/lib/apis/index.ts | 40 ++++ src/lib/components/chat/Chat.svelte | 48 ++--- .../utils/{index.test.ts => _template_old.ts} | 0 8 files changed, 267 insertions(+), 124 deletions(-) delete mode 100644 backend/utils/models.py create mode 100644 backend/utils/task.py rename src/lib/utils/{index.test.ts => _template_old.ts} (100%) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 9ada17262..3e0674ef4 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -41,8 +41,6 @@ from utils.utils import ( get_admin_user, ) -from utils.models import get_model_id_from_custom_model_id - from config import ( SRC_LOG_LEVELS, diff --git a/backend/config.py b/backend/config.py index fcced4f61..cc8fd0cf8 100644 --- a/backend/config.py +++ b/backend/config.py @@ -618,6 +618,27 @@ ADMIN_EMAIL = PersistentConfig( ) +TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig( + "TITLE_GENERATION_PROMPT_TEMPLATE", + "task.title.prompt_template", + os.environ.get( + "TITLE_GENERATION_PROMPT_TEMPLATE", + """Here is the query: +{{prompt:middletruncate:8000}} + +Create a concise, 3-5 word phrase with an emoji as a title for the previous query. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT. + +Examples of titles: +📉 Stock Market Trends +🍪 Perfect Chocolate Chip Recipe +Evolution of Music Streaming +Remote Work Productivity Tips +Artificial Intelligence in Healthcare +🎮 Video Game Development Insights""", + ), +) + + #################################### # WEBUI_SECRET_KEY #################################### diff --git a/backend/main.py b/backend/main.py index ff87b3da7..dd8fff7b0 100644 --- a/backend/main.py +++ b/backend/main.py @@ -53,6 +53,8 @@ from utils.utils import ( get_current_user, get_http_authorization_cred, ) +from utils.task import title_generation_template + from apps.rag.utils import rag_messages from config import ( @@ -74,8 +76,9 @@ from config import ( SRC_LOG_LEVELS, WEBHOOK_URL, ENABLE_ADMIN_EXPORT, - AppConfig, WEBUI_BUILD_HASH, + TITLE_GENERATION_PROMPT_TEMPLATE, + AppConfig, ) from constants import ERROR_MESSAGES @@ -131,7 +134,7 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.config.WEBHOOK_URL = WEBHOOK_URL - +app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE app.state.MODELS = {} @@ -240,6 +243,78 @@ class RAGMiddleware(BaseHTTPMiddleware): app.add_middleware(RAGMiddleware) +def filter_pipeline(payload, user): + user = {"id": user.id, "name": user.name, "role": user.role} + model_id = payload["model"] + filters = [ + model + for model in app.state.MODELS.values() + if "pipeline" in model + and "type" in model["pipeline"] + and model["pipeline"]["type"] == "filter" + and ( + model["pipeline"]["pipelines"] == ["*"] + or any( + model_id == target_model_id + for target_model_id in model["pipeline"]["pipelines"] + ) + ) + ] + sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) + + model = app.state.MODELS[model_id] + + if "pipeline" in model: + sorted_filters.append(model) + + for filter in sorted_filters: + r = None + try: + urlIdx = 
filter["urlIdx"] + + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + if key != "": + headers = {"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/{filter['id']}/filter/inlet", + headers=headers, + json={ + "user": user, + "body": payload, + }, + ) + + r.raise_for_status() + payload = r.json() + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + if r is not None: + try: + res = r.json() + if "detail" in res: + return JSONResponse( + status_code=r.status_code, + content=res, + ) + except: + pass + + else: + pass + + if "pipeline" not in app.state.MODELS[model_id]: + if "chat_id" in payload: + del payload["chat_id"] + + if "title" in payload: + del payload["title"] + return payload + + class PipelineMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): if request.method == "POST" and ( @@ -255,85 +330,10 @@ class PipelineMiddleware(BaseHTTPMiddleware): # Parse string to JSON data = json.loads(body_str) if body_str else {} - model_id = data["model"] - filters = [ - model - for model in app.state.MODELS.values() - if "pipeline" in model - and "type" in model["pipeline"] - and model["pipeline"]["type"] == "filter" - and ( - model["pipeline"]["pipelines"] == ["*"] - or any( - model_id == target_model_id - for target_model_id in model["pipeline"]["pipelines"] - ) - ) - ] - sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) - - user = None - if len(sorted_filters) > 0: - try: - user = get_current_user( - get_http_authorization_cred( - request.headers.get("Authorization") - ) - ) - user = {"id": user.id, "name": user.name, "role": user.role} - except: - pass - - model = app.state.MODELS[model_id] - - if "pipeline" in model: - sorted_filters.append(model) - - for filter in sorted_filters: - r = None - try: - urlIdx = filter["urlIdx"] - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - if key != "": - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/inlet", - headers=headers, - json={ - "user": user, - "body": data, - }, - ) - - r.raise_for_status() - data = r.json() - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - if r is not None: - try: - res = r.json() - if "detail" in res: - return JSONResponse( - status_code=r.status_code, - content=res, - ) - except: - pass - - else: - pass - - if "pipeline" not in app.state.MODELS[model_id]: - if "chat_id" in data: - del data["chat_id"] - - if "title" in data: - del data["title"] + user = get_current_user( + get_http_authorization_cred(request.headers.get("Authorization")) + ) + data = filter_pipeline(data, user) modified_body_bytes = json.dumps(data).encode("utf-8") # Replace the request body with the modified one @@ -494,6 +494,44 @@ async def get_models(user=Depends(get_verified_user)): return {"data": models} +@app.post("/api/title/completions") +async def generate_title(form_data: dict, user=Depends(get_verified_user)): + print("generate_title") + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + model = app.state.MODELS[model_id] + + template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE + + content = title_generation_template( + template, form_data["prompt"], user.model_dump() + ) + + 
payload = { + "model": model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "max_tokens": 50, + "chat_id": form_data.get("chat_id", None), + "title": True, + } + + print(payload) + payload = filter_pipeline(payload, user) + + if model["owned_by"] == "ollama": + return await generate_ollama_chat_completion( + OpenAIChatCompletionForm(**payload), user=user + ) + else: + return await generate_openai_chat_completion(payload, user=user) + + @app.post("/api/chat/completions") async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): model_id = form_data["model"] diff --git a/backend/utils/models.py b/backend/utils/models.py deleted file mode 100644 index c4d675d29..000000000 --- a/backend/utils/models.py +++ /dev/null @@ -1,10 +0,0 @@ -from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse - - -def get_model_id_from_custom_model_id(id: str): - model = Models.get_model_by_id(id) - - if model: - return model.id - else: - return id diff --git a/backend/utils/task.py b/backend/utils/task.py new file mode 100644 index 000000000..b2de4a617 --- /dev/null +++ b/backend/utils/task.py @@ -0,0 +1,70 @@ +import re +import math + +from datetime import datetime +from typing import Optional + + +def prompt_template( + template: str, user_name: str = None, current_location: str = None +) -> str: + # Get the current date + current_date = datetime.now() + + # Format the date to YYYY-MM-DD + formatted_date = current_date.strftime("%Y-%m-%d") + + # Replace {{CURRENT_DATE}} in the template with the formatted date + template = template.replace("{{CURRENT_DATE}}", formatted_date) + + if user_name: + # Replace {{USER_NAME}} in the template with the user's name + template = template.replace("{{USER_NAME}}", user_name) + + if current_location: + # Replace {{CURRENT_LOCATION}} in the template with the current location + template = template.replace("{{CURRENT_LOCATION}}", current_location) + + return template + + +def title_generation_template( + template: str, prompt: str, user: Optional[dict] = None +) -> str: + def replacement_function(match): + full_match = match.group(0) + start_length = match.group(1) + end_length = match.group(2) + middle_length = match.group(3) + + if full_match == "{{prompt}}": + return prompt + elif start_length is not None: + return prompt[: int(start_length)] + elif end_length is not None: + return prompt[-int(end_length) :] + elif middle_length is not None: + middle_length = int(middle_length) + if len(prompt) <= middle_length: + return prompt + start = prompt[: math.ceil(middle_length / 2)] + end = prompt[-math.floor(middle_length / 2) :] + return f"{start}...{end}" + return "" + + template = re.sub( + r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}", + replacement_function, + template, + ) + + template = prompt_template( + template, + **( + {"user_name": user.get("name"), "current_location": user.get("location")} + if user + else {} + ), + ) + + return template diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 70d8b8804..fc96e0878 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -104,6 +104,46 @@ export const chatCompleted = async (token: string, body: ChatCompletedForm) => { return res; }; +export const generateTitle = async ( + token: string = '', + model: string, + prompt: string, + chat_id?: string +) => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/title/completions`, { + method: 'POST', + headers: { + 
Accept: 'application/json',
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			model: model,
+			prompt: prompt,
+			...(chat_id && { chat_id: chat_id })
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 'New Chat';
+};
+
 export const getPipelinesList = async (token: string = '') => {
 	let error = null;
 
diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte
index 8aad3ff48..fb846fc24 100644
--- a/src/lib/components/chat/Chat.svelte
+++ b/src/lib/components/chat/Chat.svelte
@@ -7,6 +7,10 @@
 	import { goto } from '$app/navigation';
 	import { page } from '$app/stores';
 
+	import type { Writable } from 'svelte/store';
+	import type { i18n as i18nType } from 'i18next';
+	import { OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
+
 	import {
 		chatId,
 		chats,
@@ -40,24 +44,17 @@
 		getTagsById,
 		updateChatById
 	} from '$lib/apis/chats';
-	import {
-		generateOpenAIChatCompletion,
-		generateSearchQuery,
-		generateTitle
-	} from '$lib/apis/openai';
+	import { generateOpenAIChatCompletion, generateSearchQuery } from '$lib/apis/openai';
+
+	import { runWebSearch } from '$lib/apis/rag';
+	import { createOpenAITextStream } from '$lib/apis/streaming';
+	import { queryMemory } from '$lib/apis/memories';
+	import { getUserSettings } from '$lib/apis/users';
+	import { chatCompleted, generateTitle } from '$lib/apis';
+
+	import Banner from '../common/Banner.svelte';
 	import MessageInput from '$lib/components/chat/MessageInput.svelte';
 	import Messages from '$lib/components/chat/Messages.svelte';
 	import Navbar from '$lib/components/layout/Navbar.svelte';
-	import { OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
-	import { createOpenAITextStream } from '$lib/apis/streaming';
-	import { queryMemory } from '$lib/apis/memories';
-	import type { Writable } from 'svelte/store';
-	import type { i18n as i18nType } from 'i18next';
-	import { runWebSearch } from '$lib/apis/rag';
-	import Banner from '../common/Banner.svelte';
-	import { getUserSettings } from '$lib/apis/users';
-	import { chatCompleted } from '$lib/apis';
 	import CallOverlay from './MessageInput/CallOverlay.svelte';
 
 	const i18n: Writable<i18nType> = getContext('i18n');
@@ -1116,26 +1113,15 @@
 
 	const generateChatTitle = async (userPrompt) => {
 		if ($settings?.title?.auto ?? true) {
-			const model = $models.find((model) => model.id === selectedModels[0]);
-
-			const titleModelId =
-				model?.owned_by === 'openai' ?? false
-					? $settings?.title?.modelExternal ?? selectedModels[0]
-					: $settings?.title?.model ?? selectedModels[0];
-			const titleModel = $models.find((model) => model.id === titleModelId);
-
-			console.log(titleModel);
 			const title = await generateTitle(
 				localStorage.token,
-				$settings?.title?.prompt ?? 
- $i18n.t( - "Create a concise, 3-5 word phrase as a header for the following query, strictly adhering to the 3-5 word limit and avoiding the use of the word 'title':" - ) + ' {{prompt}}', - titleModelId, + selectedModels[0], userPrompt, - $chatId, - `${WEBUI_BASE_URL}/api` - ); + $chatId + ).catch((error) => { + console.error(error); + return 'New Chat'; + }); return title; } else { diff --git a/src/lib/utils/index.test.ts b/src/lib/utils/_template_old.ts similarity index 100% rename from src/lib/utils/index.test.ts rename to src/lib/utils/_template_old.ts
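
Notes on the new title pipeline (illustrative sketches, not part of the patch):

The new backend/utils/task.py resolves {{prompt}}, {{prompt:start:N}},
{{prompt:end:N}} and {{prompt:middletruncate:N}} placeholders before the
title request is dispatched. A minimal sketch of that behavior, assuming
it is run from the backend/ directory so that utils.task is importable;
the prompt text and user values below are made up:

    from utils.task import title_generation_template

    # ~14,000 characters, long enough to force middle-truncation.
    prompt = "Explain attention in transformers. " * 400

    out = title_generation_template(
        "Here is the query:\n{{prompt:middletruncate:8000}}\n\nCreate a concise title.",
        prompt,
        # Optional user dict; only consulted when the template contains
        # {{USER_NAME}} or {{CURRENT_LOCATION}}. Values are made up.
        {"name": "Alice", "location": "Berlin"},
    )

    # middletruncate keeps the first 4000 and last 4000 characters of the
    # prompt, joined by "..."; prompts of 8000 characters or fewer pass
    # through unchanged.
    assert len(out) < len(prompt)

The endpoint itself takes the same bearer auth as the rest of the API, and
the frontend reads choices[0].message.content from its response, i.e. an
OpenAI-shaped body. A sketch of a direct call; host, port, model id and
token are illustrative:

    import requests

    r = requests.post(
        "http://localhost:8080/api/title/completions",
        headers={"Authorization": "Bearer <token>"},
        json={"model": "llama3", "prompt": "What is retrieval-augmented generation?"},
    )
    print(r.json()["choices"][0]["message"]["content"])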