mirror of
https://github.com/open-webui/open-webui
synced 2024-11-16 21:42:58 +00:00
refac: search query task
This commit is contained in:
parent
aa1bb4fb6d
commit
591cd993c2
@ -618,6 +618,18 @@ ADMIN_EMAIL = PersistentConfig(
|
||||
)
|
||||
|
||||
|
||||
TASK_MODEL = PersistentConfig(
|
||||
"TASK_MODEL",
|
||||
"task.model.default",
|
||||
os.environ.get("TASK_MODEL", ""),
|
||||
)
|
||||
|
||||
TASK_MODEL_EXTERNAL = PersistentConfig(
|
||||
"TASK_MODEL_EXTERNAL",
|
||||
"task.model.external",
|
||||
os.environ.get("TASK_MODEL_EXTERNAL", ""),
|
||||
)
|
||||
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
|
||||
"TITLE_GENERATION_PROMPT_TEMPLATE",
|
||||
"task.title.prompt_template",
|
||||
@ -639,6 +651,19 @@ Artificial Intelligence in Healthcare
|
||||
)
|
||||
|
||||
|
||||
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
|
||||
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
|
||||
"task.search.prompt_template",
|
||||
os.environ.get(
|
||||
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
|
||||
"""You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is {{CURRENT_DATE}}.
|
||||
|
||||
Question:
|
||||
{{prompt:end:4000}}""",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
####################################
|
||||
# WEBUI_SECRET_KEY
|
||||
####################################
|
||||
|
116
backend/main.py
116
backend/main.py
@ -53,7 +53,7 @@ from utils.utils import (
|
||||
get_current_user,
|
||||
get_http_authorization_cred,
|
||||
)
|
||||
from utils.task import title_generation_template
|
||||
from utils.task import title_generation_template, search_query_generation_template
|
||||
|
||||
from apps.rag.utils import rag_messages
|
||||
|
||||
@ -77,7 +77,10 @@ from config import (
|
||||
WEBHOOK_URL,
|
||||
ENABLE_ADMIN_EXPORT,
|
||||
WEBUI_BUILD_HASH,
|
||||
TASK_MODEL,
|
||||
TASK_MODEL_EXTERNAL,
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
AppConfig,
|
||||
)
|
||||
from constants import ERROR_MESSAGES
|
||||
@ -132,9 +135,15 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
|
||||
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
|
||||
|
||||
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
||||
|
||||
|
||||
app.state.config.TASK_MODEL = TASK_MODEL
|
||||
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
|
||||
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
|
||||
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
|
||||
app.state.MODELS = {}
|
||||
|
||||
@ -494,9 +503,46 @@ async def get_models(user=Depends(get_verified_user)):
|
||||
return {"data": models}
|
||||
|
||||
|
||||
@app.get("/api/task/config")
|
||||
async def get_task_config(user=Depends(get_verified_user)):
|
||||
return {
|
||||
"TASK_MODEL": app.state.config.TASK_MODEL,
|
||||
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
|
||||
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
}
|
||||
|
||||
|
||||
class TaskConfigForm(BaseModel):
|
||||
TASK_MODEL: Optional[str]
|
||||
TASK_MODEL_EXTERNAL: Optional[str]
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE: str
|
||||
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
|
||||
|
||||
|
||||
@app.post("/api/task/config/update")
|
||||
async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)):
|
||||
app.state.config.TASK_MODEL = form_data.TASK_MODEL
|
||||
app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
|
||||
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
|
||||
return {
|
||||
"TASK_MODEL": app.state.config.TASK_MODEL,
|
||||
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
|
||||
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/api/task/title/completions")
|
||||
async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
||||
print("generate_title")
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in app.state.MODELS:
|
||||
raise HTTPException(
|
||||
@ -504,6 +550,20 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
if app.state.MODELS[model_id]["owned_by"] == "ollama":
|
||||
if app.state.config.TASK_MODEL:
|
||||
task_model_id = app.state.config.TASK_MODEL
|
||||
if task_model_id in app.state.MODELS:
|
||||
model_id = task_model_id
|
||||
else:
|
||||
if app.state.config.TASK_MODEL_EXTERNAL:
|
||||
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
||||
if task_model_id in app.state.MODELS:
|
||||
model_id = task_model_id
|
||||
|
||||
print(model_id)
|
||||
model = app.state.MODELS[model_id]
|
||||
|
||||
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
@ -532,6 +592,57 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
||||
return await generate_openai_chat_completion(payload, user=user)
|
||||
|
||||
|
||||
@app.post("/api/task/query/completions")
|
||||
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
|
||||
print("generate_search_query")
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in app.state.MODELS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
if app.state.MODELS[model_id]["owned_by"] == "ollama":
|
||||
if app.state.config.TASK_MODEL:
|
||||
task_model_id = app.state.config.TASK_MODEL
|
||||
if task_model_id in app.state.MODELS:
|
||||
model_id = task_model_id
|
||||
else:
|
||||
if app.state.config.TASK_MODEL_EXTERNAL:
|
||||
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
||||
if task_model_id in app.state.MODELS:
|
||||
model_id = task_model_id
|
||||
|
||||
print(model_id)
|
||||
model = app.state.MODELS[model_id]
|
||||
|
||||
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = search_query_generation_template(
|
||||
template, form_data["prompt"], user.model_dump()
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"max_tokens": 30,
|
||||
}
|
||||
|
||||
print(payload)
|
||||
payload = filter_pipeline(payload, user)
|
||||
|
||||
if model["owned_by"] == "ollama":
|
||||
return await generate_ollama_chat_completion(
|
||||
OpenAIChatCompletionForm(**payload), user=user
|
||||
)
|
||||
else:
|
||||
return await generate_openai_chat_completion(payload, user=user)
|
||||
|
||||
|
||||
@app.post("/api/chat/completions")
|
||||
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
|
||||
model_id = form_data["model"]
|
||||
@ -542,7 +653,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
|
||||
)
|
||||
|
||||
model = app.state.MODELS[model_id]
|
||||
|
||||
print(model)
|
||||
|
||||
if model["owned_by"] == "ollama":
|
||||
|
@ -68,3 +68,45 @@ def title_generation_template(
|
||||
)
|
||||
|
||||
return template
|
||||
|
||||
|
||||
def search_query_generation_template(
|
||||
template: str, prompt: str, user: Optional[dict] = None
|
||||
) -> str:
|
||||
|
||||
def replacement_function(match):
|
||||
full_match = match.group(0)
|
||||
start_length = match.group(1)
|
||||
end_length = match.group(2)
|
||||
middle_length = match.group(3)
|
||||
|
||||
if full_match == "{{prompt}}":
|
||||
return prompt
|
||||
elif start_length is not None:
|
||||
return prompt[: int(start_length)]
|
||||
elif end_length is not None:
|
||||
return prompt[-int(end_length) :]
|
||||
elif middle_length is not None:
|
||||
middle_length = int(middle_length)
|
||||
if len(prompt) <= middle_length:
|
||||
return prompt
|
||||
start = prompt[: math.ceil(middle_length / 2)]
|
||||
end = prompt[-math.floor(middle_length / 2) :]
|
||||
return f"{start}...{end}"
|
||||
return ""
|
||||
|
||||
template = re.sub(
|
||||
r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
|
||||
replacement_function,
|
||||
template,
|
||||
)
|
||||
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "current_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
return template
|
||||
|
@ -144,6 +144,46 @@ export const generateTitle = async (
|
||||
return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 'New Chat';
|
||||
};
|
||||
|
||||
export const generateSearchQuery = async (
|
||||
token: string = '',
|
||||
model: string,
|
||||
messages: object[],
|
||||
prompt: string
|
||||
) => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${WEBUI_BASE_URL}/api/task/query/completions`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model: model,
|
||||
messages: messages,
|
||||
prompt: prompt
|
||||
})
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw await res.json();
|
||||
return res.json();
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
if ('detail' in err) {
|
||||
error = err.detail;
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? prompt;
|
||||
};
|
||||
|
||||
export const getPipelinesList = async (token: string = '') => {
|
||||
let error = null;
|
||||
|
||||
|
@ -44,12 +44,12 @@
|
||||
getTagsById,
|
||||
updateChatById
|
||||
} from '$lib/apis/chats';
|
||||
import { generateOpenAIChatCompletion, generateSearchQuery } from '$lib/apis/openai';
|
||||
import { generateOpenAIChatCompletion } from '$lib/apis/openai';
|
||||
import { runWebSearch } from '$lib/apis/rag';
|
||||
import { createOpenAITextStream } from '$lib/apis/streaming';
|
||||
import { queryMemory } from '$lib/apis/memories';
|
||||
import { getUserSettings } from '$lib/apis/users';
|
||||
import { chatCompleted, generateTitle } from '$lib/apis';
|
||||
import { chatCompleted, generateTitle, generateSearchQuery } from '$lib/apis';
|
||||
|
||||
import Banner from '../common/Banner.svelte';
|
||||
import MessageInput from '$lib/components/chat/MessageInput.svelte';
|
||||
@ -508,7 +508,7 @@
|
||||
const prompt = history.messages[parentId].content;
|
||||
let searchQuery = prompt;
|
||||
if (prompt.length > 100) {
|
||||
searchQuery = await generateChatSearchQuery(model, prompt);
|
||||
searchQuery = await generateSearchQuery(localStorage.token, model, messages, prompt);
|
||||
if (!searchQuery) {
|
||||
toast.warning($i18n.t('No search query generated'));
|
||||
responseMessage.status = {
|
||||
@ -1129,29 +1129,6 @@
|
||||
}
|
||||
};
|
||||
|
||||
const generateChatSearchQuery = async (modelId: string, prompt: string) => {
|
||||
const model = $models.find((model) => model.id === modelId);
|
||||
const taskModelId =
|
||||
model?.owned_by === 'openai' ?? false
|
||||
? $settings?.title?.modelExternal ?? modelId
|
||||
: $settings?.title?.model ?? modelId;
|
||||
const taskModel = $models.find((model) => model.id === taskModelId);
|
||||
|
||||
const previousMessages = messages
|
||||
.filter((message) => message.role === 'user')
|
||||
.map((message) => message.content);
|
||||
|
||||
return await generateSearchQuery(
|
||||
localStorage.token,
|
||||
taskModelId,
|
||||
previousMessages,
|
||||
prompt,
|
||||
taskModel?.owned_by === 'openai' ?? false
|
||||
? `${OPENAI_API_BASE_URL}`
|
||||
: `${OLLAMA_API_BASE_URL}/v1`
|
||||
);
|
||||
};
|
||||
|
||||
const setChatTitle = async (_chatId, _title) => {
|
||||
if (_chatId === $chatId) {
|
||||
title = _title;
|
||||
|
Loading…
Reference in New Issue
Block a user