refac: title generation

This commit is contained in:
Timothy J. Baek 2024-06-09 14:25:31 -07:00
parent 84defafc14
commit 5e7237b9cb
8 changed files with 267 additions and 124 deletions

View File

@ -41,8 +41,6 @@ from utils.utils import (
get_admin_user,
)
from utils.models import get_model_id_from_custom_model_id
from config import (
SRC_LOG_LEVELS,

View File

@ -618,6 +618,27 @@ ADMIN_EMAIL = PersistentConfig(
)
# Default prompt template used by the title-generation endpoint. The built-in
# text below is only the fallback: the TITLE_GENERATION_PROMPT_TEMPLATE
# environment variable, when set, replaces it.
# NOTE(review): "task.title.prompt_template" is presumably the key under which
# PersistentConfig stores the value -- confirm against PersistentConfig.
# `{{prompt:middletruncate:8000}}` is a placeholder, not literal text; it is
# substituted with the (middle-truncated) user prompt by
# `title_generation_template` in utils/task.py.
TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"TITLE_GENERATION_PROMPT_TEMPLATE",
"task.title.prompt_template",
os.environ.get(
"TITLE_GENERATION_PROMPT_TEMPLATE",
"""Here is the query:
{{prompt:middletruncate:8000}}
Create a concise, 3-5 word phrase with an emoji as a title for the previous query. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
Examples of titles:
📉 Stock Market Trends
🍪 Perfect Chocolate Chip Recipe
Evolution of Music Streaming
Remote Work Productivity Tips
Artificial Intelligence in Healthcare
🎮 Video Game Development Insights""",
),
)
####################################
# WEBUI_SECRET_KEY
####################################

View File

@ -53,6 +53,8 @@ from utils.utils import (
get_current_user,
get_http_authorization_cred,
)
from utils.task import title_generation_template
from apps.rag.utils import rag_messages
from config import (
@ -74,8 +76,9 @@ from config import (
SRC_LOG_LEVELS,
WEBHOOK_URL,
ENABLE_ADMIN_EXPORT,
AppConfig,
WEBUI_BUILD_HASH,
TITLE_GENERATION_PROMPT_TEMPLATE,
AppConfig,
)
from constants import ERROR_MESSAGES
@ -131,7 +134,7 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
app.state.MODELS = {}
@ -240,6 +243,78 @@ class RAGMiddleware(BaseHTTPMiddleware):
app.add_middleware(RAGMiddleware)
def filter_pipeline(payload, user):
    """Pass ``payload`` through every inlet filter that applies to its model.

    A model registered in ``app.state.MODELS`` counts as a filter when its
    ``pipeline`` entry has ``type == "filter"`` and its ``pipelines`` list is
    ``["*"]`` or contains the payload's model id. Filters run in ascending
    ``priority`` order; if the target model itself has a ``pipeline`` entry it
    is appended and runs last.

    Args:
        payload: Request body dict; must contain a ``"model"`` key that is
            present in ``app.state.MODELS``.
        user: Object exposing ``id``, ``name`` and ``role`` attributes; only
            those three values are forwarded to the filters.

    Returns:
        The payload as rewritten by the filters. If a filter request fails and
        its response body is JSON containing ``"detail"``, a ``JSONResponse``
        carrying that body and status code is returned instead, so callers
        must check the return type before treating it as a dict.

    Raises:
        KeyError: If ``payload`` has no ``"model"`` key or the model id is not
            in ``app.state.MODELS``.
    """
    user = {"id": user.id, "name": user.name, "role": user.role}
    model_id = payload["model"]

    filters = [
        model
        for model in app.state.MODELS.values()
        if "pipeline" in model
        and "type" in model["pipeline"]
        and model["pipeline"]["type"] == "filter"
        and (
            model["pipeline"]["pipelines"] == ["*"]
            or any(
                model_id == target_model_id
                for target_model_id in model["pipeline"]["pipelines"]
            )
        )
    ]
    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])

    model = app.state.MODELS[model_id]
    if "pipeline" in model:
        sorted_filters.append(model)

    # `pipeline_filter` rather than `filter`, to avoid shadowing the builtin.
    for pipeline_filter in sorted_filters:
        r = None
        try:
            urlIdx = pipeline_filter["urlIdx"]

            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

            # Always define headers. Previously an empty key left `headers`
            # unbound, so the request raised NameError and the filter was
            # silently skipped as a "connection error".
            headers = {"Authorization": f"Bearer {key}"} if key != "" else {}

            r = requests.post(
                f"{url}/{pipeline_filter['id']}/filter/inlet",
                headers=headers,
                json={
                    "user": user,
                    "body": payload,
                },
            )

            r.raise_for_status()
            payload = r.json()
        except Exception as e:
            # Best effort: a failing filter does not abort the chain unless it
            # answered with an explicit error detail.
            print(f"Connection error: {e}")

            if r is not None:
                try:
                    res = r.json()
                    if "detail" in res:
                        return JSONResponse(
                            status_code=r.status_code,
                            content=res,
                        )
                except Exception:
                    # Body was not JSON (or not a container); ignore it.
                    pass

    # These keys are only kept when the target model is itself a pipeline.
    if "pipeline" not in app.state.MODELS[model_id]:
        if "chat_id" in payload:
            del payload["chat_id"]

        if "title" in payload:
            del payload["title"]

    return payload
class PipelineMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if request.method == "POST" and (
@ -255,85 +330,10 @@ class PipelineMiddleware(BaseHTTPMiddleware):
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
model_id = data["model"]
filters = [
model
for model in app.state.MODELS.values()
if "pipeline" in model
and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter"
and (
model["pipeline"]["pipelines"] == ["*"]
or any(
model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"]
)
)
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
user = None
if len(sorted_filters) > 0:
try:
user = get_current_user(
get_http_authorization_cred(
request.headers.get("Authorization")
)
)
user = {"id": user.id, "name": user.name, "role": user.role}
except:
pass
model = app.state.MODELS[model_id]
if "pipeline" in model:
sorted_filters.append(model)
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json={
"user": user,
"body": data,
},
)
r.raise_for_status()
data = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
try:
res = r.json()
if "detail" in res:
return JSONResponse(
status_code=r.status_code,
content=res,
)
except:
pass
else:
pass
if "pipeline" not in app.state.MODELS[model_id]:
if "chat_id" in data:
del data["chat_id"]
if "title" in data:
del data["title"]
user = get_current_user(
get_http_authorization_cred(request.headers.get("Authorization"))
)
data = filter_pipeline(data, user)
modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one
@ -494,6 +494,44 @@ async def get_models(user=Depends(get_verified_user)):
return {"data": models}
@app.post("/api/title/completions")
async def generate_title(form_data: dict, user=Depends(get_verified_user)):
    """Generate a short chat title for ``form_data["prompt"]``.

    The configured title prompt template is rendered with the user's prompt,
    wrapped in a non-streaming single-message chat payload (``max_tokens`` 50),
    passed through ``filter_pipeline`` and then dispatched to the Ollama or
    OpenAI completion handler depending on the model's ``owned_by`` value.

    Args:
        form_data: Request body with ``"model"`` and ``"prompt"``; ``"chat_id"``
            is optional.
        user: The verified user resolved by the ``get_verified_user`` dependency.

    Returns:
        The completion handler's response, or the error response produced by
        ``filter_pipeline`` when a filter rejected the request.

    Raises:
        HTTPException: 404 if the model is unknown, 400 if ``"prompt"`` is
            missing or not a string.
    """
    print("generate_title")

    # `.get` so a missing key yields the 404 below instead of a KeyError/500.
    model_id = form_data.get("model")
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    prompt = form_data.get("prompt")
    if not isinstance(prompt, str):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Prompt is required",
        )

    model = app.state.MODELS[model_id]

    template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
    content = title_generation_template(template, prompt, user.model_dump())

    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "max_tokens": 50,
        "chat_id": form_data.get("chat_id", None),
        "title": True,
    }

    print(payload)
    payload = filter_pipeline(payload, user)

    # filter_pipeline returns a response object (not a dict) when a filter
    # answered with an error detail; hand that straight back to the client
    # instead of treating it as a completion payload.
    if not isinstance(payload, dict):
        return payload

    if model["owned_by"] == "ollama":
        return await generate_ollama_chat_completion(
            OpenAIChatCompletionForm(**payload), user=user
        )
    else:
        return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
model_id = form_data["model"]

View File

@ -1,10 +0,0 @@
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
def get_model_id_from_custom_model_id(id: str):
    """Return the stored model's ``id`` for ``id``, or ``id`` itself.

    Looks the identifier up via ``Models.get_model_by_id``; when that lookup
    yields a falsy result the identifier is returned unchanged.
    """
    found = Models.get_model_by_id(id)
    return found.id if found else id

70
backend/utils/task.py Normal file
View File

@ -0,0 +1,70 @@
import re
import math
from datetime import datetime
from typing import Optional
def prompt_template(
    template: str,
    user_name: Optional[str] = None,
    current_location: Optional[str] = None,
) -> str:
    """Fill the generic placeholders of a prompt template.

    Args:
        template: Text that may contain ``{{CURRENT_DATE}}``, ``{{USER_NAME}}``
            and ``{{CURRENT_LOCATION}}``.
        user_name: Replacement for ``{{USER_NAME}}``; the placeholder is left
            untouched when this is empty or ``None``.
        current_location: Replacement for ``{{CURRENT_LOCATION}}``; the
            placeholder is left untouched when this is empty or ``None``.

    Returns:
        The template with ``{{CURRENT_DATE}}`` replaced by today's local date
        (``YYYY-MM-DD``) and the other placeholders replaced where a value was
        supplied.
    """
    formatted_date = datetime.now().strftime("%Y-%m-%d")
    template = template.replace("{{CURRENT_DATE}}", formatted_date)

    if user_name:
        template = template.replace("{{USER_NAME}}", user_name)

    if current_location:
        template = template.replace("{{CURRENT_LOCATION}}", current_location)

    return template


def title_generation_template(
    template: str, prompt: str, user: Optional[dict] = None
) -> str:
    """Render a title-generation template for ``prompt``.

    Supported prompt placeholders:

    * ``{{prompt}}`` - the whole prompt.
    * ``{{prompt:start:N}}`` - the first ``N`` characters.
    * ``{{prompt:end:N}}`` - the last ``N`` characters.
    * ``{{prompt:middletruncate:N}}`` - the prompt unchanged when it has at
      most ``N`` characters, otherwise its first ``ceil(N/2)`` and last
      ``floor(N/2)`` characters joined by ``...``.

    After the prompt is inserted, the result is passed through
    ``prompt_template`` with the user's ``name`` and ``location`` (when a
    ``user`` dict is given).

    Args:
        template: Template text containing the placeholders above.
        prompt: The user's query.
        user: Optional dict; its ``"name"`` and ``"location"`` entries are used.

    Returns:
        The fully rendered prompt string.
    """

    def last_chars(count: int) -> str:
        # `prompt[-count:]` is wrong for count == 0: `-0` is `0`, so the slice
        # would return the whole prompt instead of an empty string.
        return prompt[max(len(prompt) - count, 0) :]

    def replacement_function(match):
        full_match = match.group(0)
        start_length = match.group(1)
        end_length = match.group(2)
        middle_length = match.group(3)

        if full_match == "{{prompt}}":
            return prompt
        elif start_length is not None:
            return prompt[: int(start_length)]
        elif end_length is not None:
            return last_chars(int(end_length))
        elif middle_length is not None:
            middle_length = int(middle_length)
            if len(prompt) <= middle_length:
                return prompt
            start = prompt[: math.ceil(middle_length / 2)]
            end = last_chars(math.floor(middle_length / 2))
            return f"{start}...{end}"
        return ""

    template = re.sub(
        r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
        replacement_function,
        template,
    )

    template = prompt_template(
        template,
        **(
            {"user_name": user.get("name"), "current_location": user.get("location")}
            if user
            else {}
        ),
    )

    return template

View File

@ -104,6 +104,46 @@ export const chatCompleted = async (token: string, body: ChatCompletedForm) => {
return res;
};
/**
 * Ask the backend to generate a chat title for `prompt`.
 *
 * POSTs `{ model, prompt, chat_id? }` to `${WEBUI_BASE_URL}/api/title/completions`.
 *
 * @param token   Bearer token sent in the Authorization header.
 * @param model   Id of the model that should produce the title.
 * @param prompt  The user's query to summarise.
 * @param chat_id Optional chat id; only included in the body when truthy.
 * @returns The first choice's message content with single and double quotes
 *          removed, or 'New Chat' when the response carries no such content.
 * @throws The `detail` value of the backend's error body, when present.
 */
export const generateTitle = async (
	token: string = '',
	model: string,
	prompt: string,
	chat_id?: string
) => {
	let error = null;

	const res = await fetch(`${WEBUI_BASE_URL}/api/title/completions`, {
		method: 'POST',
		headers: {
			Accept: 'application/json',
			'Content-Type': 'application/json',
			Authorization: `Bearer ${token}`
		},
		body: JSON.stringify({
			model: model,
			prompt: prompt,
			...(chat_id && { chat_id: chat_id })
		})
	})
		.then(async (res) => {
			if (!res.ok) throw await res.json();
			return res.json();
		})
		.catch((err) => {
			console.log(err);
			// `in` throws on non-objects (e.g. a JSON string body), so guard first.
			if (err !== null && typeof err === 'object' && 'detail' in err) {
				error = err.detail;
			}
			return null;
		});

	if (error) {
		throw error;
	}

	// Every step is optional: a body without `choices` or without `content`
	// falls back to the default title instead of throwing a TypeError.
	return res?.choices?.[0]?.message?.content?.replace(/["']/g, '') ?? 'New Chat';
};
export const getPipelinesList = async (token: string = '') => {
let error = null;

View File

@ -7,6 +7,10 @@
import { goto } from '$app/navigation';
import { page } from '$app/stores';
import type { Writable } from 'svelte/store';
import type { i18n as i18nType } from 'i18next';
import { OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
import {
chatId,
chats,
@ -40,24 +44,17 @@
getTagsById,
updateChatById
} from '$lib/apis/chats';
import {
generateOpenAIChatCompletion,
generateSearchQuery,
generateTitle
} from '$lib/apis/openai';
import { generateOpenAIChatCompletion, generateSearchQuery } from '$lib/apis/openai';
import { runWebSearch } from '$lib/apis/rag';
import { createOpenAITextStream } from '$lib/apis/streaming';
import { queryMemory } from '$lib/apis/memories';
import { getUserSettings } from '$lib/apis/users';
import { chatCompleted, generateTitle } from '$lib/apis';
import Banner from '../common/Banner.svelte';
import MessageInput from '$lib/components/chat/MessageInput.svelte';
import Messages from '$lib/components/chat/Messages.svelte';
import Navbar from '$lib/components/layout/Navbar.svelte';
import { OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
import { createOpenAITextStream } from '$lib/apis/streaming';
import { queryMemory } from '$lib/apis/memories';
import type { Writable } from 'svelte/store';
import type { i18n as i18nType } from 'i18next';
import { runWebSearch } from '$lib/apis/rag';
import Banner from '../common/Banner.svelte';
import { getUserSettings } from '$lib/apis/users';
import { chatCompleted } from '$lib/apis';
import CallOverlay from './MessageInput/CallOverlay.svelte';
const i18n: Writable<i18nType> = getContext('i18n');
@ -1116,26 +1113,15 @@
const generateChatTitle = async (userPrompt) => {
if ($settings?.title?.auto ?? true) {
const model = $models.find((model) => model.id === selectedModels[0]);
const titleModelId =
model?.owned_by === 'openai' ?? false
? $settings?.title?.modelExternal ?? selectedModels[0]
: $settings?.title?.model ?? selectedModels[0];
const titleModel = $models.find((model) => model.id === titleModelId);
console.log(titleModel);
const title = await generateTitle(
localStorage.token,
$settings?.title?.prompt ??
$i18n.t(
"Create a concise, 3-5 word phrase as a header for the following query, strictly adhering to the 3-5 word limit and avoiding the use of the word 'title':"
) + ' {{prompt}}',
titleModelId,
selectedModels[0],
userPrompt,
$chatId,
`${WEBUI_BASE_URL}/api`
);
$chatId
).catch((error) => {
console.error(error);
return 'New Chat';
});
return title;
} else {