mirror of
https://github.com/open-webui/open-webui
synced 2025-01-30 14:29:55 +00:00
refac: title generation
This commit is contained in:
parent
84defafc14
commit
5e7237b9cb
@ -41,8 +41,6 @@ from utils.utils import (
|
|||||||
get_admin_user,
|
get_admin_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
from utils.models import get_model_id_from_custom_model_id
|
|
||||||
|
|
||||||
|
|
||||||
from config import (
|
from config import (
|
||||||
SRC_LOG_LEVELS,
|
SRC_LOG_LEVELS,
|
||||||
|
@ -618,6 +618,27 @@ ADMIN_EMAIL = PersistentConfig(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Default prompt template used by the /api/title/completions endpoint to ask a
# model for a short chat title. The {{prompt:middletruncate:8000}} placeholder
# is expanded by title_generation_template (backend/utils/task.py); the value
# can be overridden via the TITLE_GENERATION_PROMPT_TEMPLATE env variable.
TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    "TITLE_GENERATION_PROMPT_TEMPLATE",
    "task.title.prompt_template",
    os.environ.get(
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        """Here is the query:
{{prompt:middletruncate:8000}}

Create a concise, 3-5 word phrase with an emoji as a title for the previous query. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.

Examples of titles:
📉 Stock Market Trends
🍪 Perfect Chocolate Chip Recipe
Evolution of Music Streaming
Remote Work Productivity Tips
Artificial Intelligence in Healthcare
🎮 Video Game Development Insights""",
    ),
)
|
||||||
|
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# WEBUI_SECRET_KEY
|
# WEBUI_SECRET_KEY
|
||||||
####################################
|
####################################
|
||||||
|
200
backend/main.py
200
backend/main.py
@ -53,6 +53,8 @@ from utils.utils import (
|
|||||||
get_current_user,
|
get_current_user,
|
||||||
get_http_authorization_cred,
|
get_http_authorization_cred,
|
||||||
)
|
)
|
||||||
|
from utils.task import title_generation_template
|
||||||
|
|
||||||
from apps.rag.utils import rag_messages
|
from apps.rag.utils import rag_messages
|
||||||
|
|
||||||
from config import (
|
from config import (
|
||||||
@ -74,8 +76,9 @@ from config import (
|
|||||||
SRC_LOG_LEVELS,
|
SRC_LOG_LEVELS,
|
||||||
WEBHOOK_URL,
|
WEBHOOK_URL,
|
||||||
ENABLE_ADMIN_EXPORT,
|
ENABLE_ADMIN_EXPORT,
|
||||||
AppConfig,
|
|
||||||
WEBUI_BUILD_HASH,
|
WEBUI_BUILD_HASH,
|
||||||
|
TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||||
|
AppConfig,
|
||||||
)
|
)
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
|
||||||
@ -131,7 +134,7 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
|||||||
|
|
||||||
|
|
||||||
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
||||||
|
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
|
||||||
|
|
||||||
app.state.MODELS = {}
|
app.state.MODELS = {}
|
||||||
|
|
||||||
@ -240,6 +243,78 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
|||||||
app.add_middleware(RAGMiddleware)
|
app.add_middleware(RAGMiddleware)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_pipeline(payload, user):
    """Run `payload` through every pipeline "filter" model targeting its model.

    Filters are applied in ascending `priority` order; each filter's
    `/filter/inlet` endpoint may rewrite the payload. If a filter responds
    with an error body containing a "detail" key, that error is returned as a
    JSONResponse instead of the payload. Connection errors are logged and the
    payload passes through unchanged (best-effort, preserved behavior).
    """
    user = {"id": user.id, "name": user.name, "role": user.role}
    model_id = payload["model"]

    # Filters that apply to every model ("*") or explicitly list this model.
    filters = [
        model
        for model in app.state.MODELS.values()
        if "pipeline" in model
        and "type" in model["pipeline"]
        and model["pipeline"]["type"] == "filter"
        and (
            model["pipeline"]["pipelines"] == ["*"]
            or any(
                model_id == target_model_id
                for target_model_id in model["pipeline"]["pipelines"]
            )
        )
    ]
    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])

    # The target model itself may be a pipeline; run its inlet filter last.
    model = app.state.MODELS[model_id]
    if "pipeline" in model:
        sorted_filters.append(model)

    # NOTE: loop variable renamed from `filter` to avoid shadowing the builtin.
    for flt in sorted_filters:
        r = None
        try:
            urlIdx = flt["urlIdx"]

            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

            if key != "":
                headers = {"Authorization": f"Bearer {key}"}
                r = requests.post(
                    f"{url}/{flt['id']}/filter/inlet",
                    headers=headers,
                    json={
                        "user": user,
                        "body": payload,
                    },
                )

                r.raise_for_status()
                payload = r.json()
        except Exception as e:
            # Handle connection error here
            print(f"Connection error: {e}")

            if r is not None:
                try:
                    res = r.json()
                    if "detail" in res:
                        return JSONResponse(
                            status_code=r.status_code,
                            content=res,
                        )
                # Narrowed from a bare `except:`; still deliberately swallows
                # body-parse failures so a broken filter can't kill the request.
                except Exception:
                    pass

    # Plain (non-pipeline) backends don't understand these routing keys.
    if "pipeline" not in app.state.MODELS[model_id]:
        if "chat_id" in payload:
            del payload["chat_id"]

        if "title" in payload:
            del payload["title"]

    return payload
||||||
|
|
||||||
|
|
||||||
class PipelineMiddleware(BaseHTTPMiddleware):
|
class PipelineMiddleware(BaseHTTPMiddleware):
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
if request.method == "POST" and (
|
if request.method == "POST" and (
|
||||||
@ -255,85 +330,10 @@ class PipelineMiddleware(BaseHTTPMiddleware):
|
|||||||
# Parse string to JSON
|
# Parse string to JSON
|
||||||
data = json.loads(body_str) if body_str else {}
|
data = json.loads(body_str) if body_str else {}
|
||||||
|
|
||||||
model_id = data["model"]
|
user = get_current_user(
|
||||||
filters = [
|
get_http_authorization_cred(request.headers.get("Authorization"))
|
||||||
model
|
)
|
||||||
for model in app.state.MODELS.values()
|
data = filter_pipeline(data, user)
|
||||||
if "pipeline" in model
|
|
||||||
and "type" in model["pipeline"]
|
|
||||||
and model["pipeline"]["type"] == "filter"
|
|
||||||
and (
|
|
||||||
model["pipeline"]["pipelines"] == ["*"]
|
|
||||||
or any(
|
|
||||||
model_id == target_model_id
|
|
||||||
for target_model_id in model["pipeline"]["pipelines"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
|
||||||
|
|
||||||
user = None
|
|
||||||
if len(sorted_filters) > 0:
|
|
||||||
try:
|
|
||||||
user = get_current_user(
|
|
||||||
get_http_authorization_cred(
|
|
||||||
request.headers.get("Authorization")
|
|
||||||
)
|
|
||||||
)
|
|
||||||
user = {"id": user.id, "name": user.name, "role": user.role}
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
model = app.state.MODELS[model_id]
|
|
||||||
|
|
||||||
if "pipeline" in model:
|
|
||||||
sorted_filters.append(model)
|
|
||||||
|
|
||||||
for filter in sorted_filters:
|
|
||||||
r = None
|
|
||||||
try:
|
|
||||||
urlIdx = filter["urlIdx"]
|
|
||||||
|
|
||||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
|
||||||
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
|
||||||
|
|
||||||
if key != "":
|
|
||||||
headers = {"Authorization": f"Bearer {key}"}
|
|
||||||
r = requests.post(
|
|
||||||
f"{url}/{filter['id']}/filter/inlet",
|
|
||||||
headers=headers,
|
|
||||||
json={
|
|
||||||
"user": user,
|
|
||||||
"body": data,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
r.raise_for_status()
|
|
||||||
data = r.json()
|
|
||||||
except Exception as e:
|
|
||||||
# Handle connection error here
|
|
||||||
print(f"Connection error: {e}")
|
|
||||||
|
|
||||||
if r is not None:
|
|
||||||
try:
|
|
||||||
res = r.json()
|
|
||||||
if "detail" in res:
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=r.status_code,
|
|
||||||
content=res,
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if "pipeline" not in app.state.MODELS[model_id]:
|
|
||||||
if "chat_id" in data:
|
|
||||||
del data["chat_id"]
|
|
||||||
|
|
||||||
if "title" in data:
|
|
||||||
del data["title"]
|
|
||||||
|
|
||||||
modified_body_bytes = json.dumps(data).encode("utf-8")
|
modified_body_bytes = json.dumps(data).encode("utf-8")
|
||||||
# Replace the request body with the modified one
|
# Replace the request body with the modified one
|
||||||
@ -494,6 +494,44 @@ async def get_models(user=Depends(get_verified_user)):
|
|||||||
return {"data": models}
|
return {"data": models}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/title/completions")
async def generate_title(form_data: dict, user=Depends(get_verified_user)):
    """Generate a short chat title for the prompt in `form_data`.

    Expects `form_data["model"]` and `form_data["prompt"]` (optionally
    `chat_id`); routes the completion to Ollama or the OpenAI-compatible
    backend depending on the model's owner. Raises 404 for unknown models.
    """
    print("generate_title")
    model_id = form_data["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    target_model = app.state.MODELS[model_id]

    # Render the configured title template against the user's prompt.
    content = title_generation_template(
        app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
        form_data["prompt"],
        user.model_dump(),
    )

    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "max_tokens": 50,
        "chat_id": form_data.get("chat_id", None),
        "title": True,
    }

    print(payload)
    payload = filter_pipeline(payload, user)

    if target_model["owned_by"] == "ollama":
        return await generate_ollama_chat_completion(
            OpenAIChatCompletionForm(**payload), user=user
        )
    return await generate_openai_chat_completion(payload, user=user)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/chat/completions")
|
@app.post("/api/chat/completions")
|
||||||
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
|
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
|
||||||
model_id = form_data["model"]
|
model_id = form_data["model"]
|
||||||
|
@ -1,10 +0,0 @@
|
|||||||
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_id_from_custom_model_id(id: str):
    """Resolve a custom model id to its stored model's id.

    Falls back to the given id when no stored model matches.
    """
    stored = Models.get_model_by_id(id)
    return stored.id if stored else id
|
|
70
backend/utils/task.py
Normal file
70
backend/utils/task.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
import re
|
||||||
|
import math
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_template(
    template: str,
    user_name: Optional[str] = None,
    current_location: Optional[str] = None,
) -> str:
    """Fill the generic placeholders in a prompt template.

    Replaces ``{{CURRENT_DATE}}`` with today's date (YYYY-MM-DD) and, when the
    corresponding argument is provided (and truthy), ``{{USER_NAME}}`` /
    ``{{CURRENT_LOCATION}}`` with the given values. Unmatched placeholders are
    left untouched.

    Args:
        template: Template text containing the placeholders.
        user_name: Value for ``{{USER_NAME}}``; skipped when None/empty.
        current_location: Value for ``{{CURRENT_LOCATION}}``; skipped when
            None/empty.

    Returns:
        The template with placeholders substituted.
    """
    # Annotations fixed: `str = None` implied Optional; make it explicit.
    formatted_date = datetime.now().strftime("%Y-%m-%d")

    # Replace {{CURRENT_DATE}} in the template with the formatted date
    template = template.replace("{{CURRENT_DATE}}", formatted_date)

    if user_name:
        # Replace {{USER_NAME}} in the template with the user's name
        template = template.replace("{{USER_NAME}}", user_name)

    if current_location:
        # Replace {{CURRENT_LOCATION}} in the template with the current location
        template = template.replace("{{CURRENT_LOCATION}}", current_location)

    return template
|
||||||
|
|
||||||
|
|
||||||
|
def title_generation_template(
    template: str, prompt: str, user: Optional[dict] = None
) -> str:
    """Expand the ``{{prompt...}}`` placeholders of a title template.

    Supports ``{{prompt}}`` (full prompt), ``{{prompt:start:N}}`` (first N
    chars), ``{{prompt:end:N}}`` (last N chars) and
    ``{{prompt:middletruncate:N}}`` (keeps both ends, inserting "..." when the
    prompt exceeds N chars), then applies the generic `prompt_template`
    substitutions using the user's name/location when a user dict is given.
    """

    def _substitute(match):
        token = match.group(0)
        head_len = match.group(1)
        tail_len = match.group(2)
        mid_len = match.group(3)

        if token == "{{prompt}}":
            return prompt
        if head_len is not None:
            return prompt[: int(head_len)]
        if tail_len is not None:
            return prompt[-int(tail_len) :]
        if mid_len is not None:
            limit = int(mid_len)
            if len(prompt) <= limit:
                return prompt
            head = prompt[: math.ceil(limit / 2)]
            tail = prompt[-math.floor(limit / 2) :]
            return f"{head}...{tail}"
        return ""

    template = re.sub(
        r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
        _substitute,
        template,
    )

    user_kwargs = (
        {"user_name": user.get("name"), "current_location": user.get("location")}
        if user
        else {}
    )
    return prompt_template(template, **user_kwargs)
|
@ -104,6 +104,46 @@ export const chatCompleted = async (token: string, body: ChatCompletedForm) => {
|
|||||||
return res;
|
return res;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const generateTitle = async (
|
||||||
|
token: string = '',
|
||||||
|
model: string,
|
||||||
|
prompt: string,
|
||||||
|
chat_id?: string
|
||||||
|
) => {
|
||||||
|
let error = null;
|
||||||
|
|
||||||
|
const res = await fetch(`${WEBUI_BASE_URL}/api/title/completions`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
Accept: 'application/json',
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
Authorization: `Bearer ${token}`
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
model: model,
|
||||||
|
prompt: prompt,
|
||||||
|
...(chat_id && { chat_id: chat_id })
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.then(async (res) => {
|
||||||
|
if (!res.ok) throw await res.json();
|
||||||
|
return res.json();
|
||||||
|
})
|
||||||
|
.catch((err) => {
|
||||||
|
console.log(err);
|
||||||
|
if ('detail' in err) {
|
||||||
|
error = err.detail;
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
|
return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 'New Chat';
|
||||||
|
};
|
||||||
|
|
||||||
export const getPipelinesList = async (token: string = '') => {
|
export const getPipelinesList = async (token: string = '') => {
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
|
@ -7,6 +7,10 @@
|
|||||||
import { goto } from '$app/navigation';
|
import { goto } from '$app/navigation';
|
||||||
import { page } from '$app/stores';
|
import { page } from '$app/stores';
|
||||||
|
|
||||||
|
import type { Writable } from 'svelte/store';
|
||||||
|
import type { i18n as i18nType } from 'i18next';
|
||||||
|
import { OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
|
||||||
|
|
||||||
import {
|
import {
|
||||||
chatId,
|
chatId,
|
||||||
chats,
|
chats,
|
||||||
@ -40,24 +44,17 @@
|
|||||||
getTagsById,
|
getTagsById,
|
||||||
updateChatById
|
updateChatById
|
||||||
} from '$lib/apis/chats';
|
} from '$lib/apis/chats';
|
||||||
import {
|
import { generateOpenAIChatCompletion, generateSearchQuery } from '$lib/apis/openai';
|
||||||
generateOpenAIChatCompletion,
|
import { runWebSearch } from '$lib/apis/rag';
|
||||||
generateSearchQuery,
|
import { createOpenAITextStream } from '$lib/apis/streaming';
|
||||||
generateTitle
|
import { queryMemory } from '$lib/apis/memories';
|
||||||
} from '$lib/apis/openai';
|
import { getUserSettings } from '$lib/apis/users';
|
||||||
|
import { chatCompleted, generateTitle } from '$lib/apis';
|
||||||
|
|
||||||
|
import Banner from '../common/Banner.svelte';
|
||||||
import MessageInput from '$lib/components/chat/MessageInput.svelte';
|
import MessageInput from '$lib/components/chat/MessageInput.svelte';
|
||||||
import Messages from '$lib/components/chat/Messages.svelte';
|
import Messages from '$lib/components/chat/Messages.svelte';
|
||||||
import Navbar from '$lib/components/layout/Navbar.svelte';
|
import Navbar from '$lib/components/layout/Navbar.svelte';
|
||||||
import { OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
|
|
||||||
import { createOpenAITextStream } from '$lib/apis/streaming';
|
|
||||||
import { queryMemory } from '$lib/apis/memories';
|
|
||||||
import type { Writable } from 'svelte/store';
|
|
||||||
import type { i18n as i18nType } from 'i18next';
|
|
||||||
import { runWebSearch } from '$lib/apis/rag';
|
|
||||||
import Banner from '../common/Banner.svelte';
|
|
||||||
import { getUserSettings } from '$lib/apis/users';
|
|
||||||
import { chatCompleted } from '$lib/apis';
|
|
||||||
import CallOverlay from './MessageInput/CallOverlay.svelte';
|
import CallOverlay from './MessageInput/CallOverlay.svelte';
|
||||||
|
|
||||||
const i18n: Writable<i18nType> = getContext('i18n');
|
const i18n: Writable<i18nType> = getContext('i18n');
|
||||||
@ -1116,26 +1113,15 @@
|
|||||||
|
|
||||||
const generateChatTitle = async (userPrompt) => {
|
const generateChatTitle = async (userPrompt) => {
|
||||||
if ($settings?.title?.auto ?? true) {
|
if ($settings?.title?.auto ?? true) {
|
||||||
const model = $models.find((model) => model.id === selectedModels[0]);
|
|
||||||
|
|
||||||
const titleModelId =
|
|
||||||
model?.owned_by === 'openai' ?? false
|
|
||||||
? $settings?.title?.modelExternal ?? selectedModels[0]
|
|
||||||
: $settings?.title?.model ?? selectedModels[0];
|
|
||||||
const titleModel = $models.find((model) => model.id === titleModelId);
|
|
||||||
|
|
||||||
console.log(titleModel);
|
|
||||||
const title = await generateTitle(
|
const title = await generateTitle(
|
||||||
localStorage.token,
|
localStorage.token,
|
||||||
$settings?.title?.prompt ??
|
selectedModels[0],
|
||||||
$i18n.t(
|
|
||||||
"Create a concise, 3-5 word phrase as a header for the following query, strictly adhering to the 3-5 word limit and avoiding the use of the word 'title':"
|
|
||||||
) + ' {{prompt}}',
|
|
||||||
titleModelId,
|
|
||||||
userPrompt,
|
userPrompt,
|
||||||
$chatId,
|
$chatId
|
||||||
`${WEBUI_BASE_URL}/api`
|
).catch((error) => {
|
||||||
);
|
console.error(error);
|
||||||
|
return 'New Chat';
|
||||||
|
});
|
||||||
|
|
||||||
return title;
|
return title;
|
||||||
} else {
|
} else {
|
||||||
|
Loading…
Reference in New Issue
Block a user