mirror of
https://github.com/open-webui/open-webui
synced 2024-11-16 21:42:58 +00:00
refac: search query task
This commit is contained in:
parent
aa1bb4fb6d
commit
591cd993c2
@ -618,6 +618,18 @@ ADMIN_EMAIL = PersistentConfig(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
TASK_MODEL = PersistentConfig(
|
||||||
|
"TASK_MODEL",
|
||||||
|
"task.model.default",
|
||||||
|
os.environ.get("TASK_MODEL", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
TASK_MODEL_EXTERNAL = PersistentConfig(
|
||||||
|
"TASK_MODEL_EXTERNAL",
|
||||||
|
"task.model.external",
|
||||||
|
os.environ.get("TASK_MODEL_EXTERNAL", ""),
|
||||||
|
)
|
||||||
|
|
||||||
TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
|
TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
|
||||||
"TITLE_GENERATION_PROMPT_TEMPLATE",
|
"TITLE_GENERATION_PROMPT_TEMPLATE",
|
||||||
"task.title.prompt_template",
|
"task.title.prompt_template",
|
||||||
@ -639,6 +651,19 @@ Artificial Intelligence in Healthcare
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
|
||||||
|
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
|
||||||
|
"task.search.prompt_template",
|
||||||
|
os.environ.get(
|
||||||
|
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
|
||||||
|
"""You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is {{CURRENT_DATE}}.
|
||||||
|
|
||||||
|
Question:
|
||||||
|
{{prompt:end:4000}}""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# WEBUI_SECRET_KEY
|
# WEBUI_SECRET_KEY
|
||||||
####################################
|
####################################
|
||||||
|
116
backend/main.py
116
backend/main.py
@ -53,7 +53,7 @@ from utils.utils import (
|
|||||||
get_current_user,
|
get_current_user,
|
||||||
get_http_authorization_cred,
|
get_http_authorization_cred,
|
||||||
)
|
)
|
||||||
from utils.task import title_generation_template
|
from utils.task import title_generation_template, search_query_generation_template
|
||||||
|
|
||||||
from apps.rag.utils import rag_messages
|
from apps.rag.utils import rag_messages
|
||||||
|
|
||||||
@ -77,7 +77,10 @@ from config import (
|
|||||||
WEBHOOK_URL,
|
WEBHOOK_URL,
|
||||||
ENABLE_ADMIN_EXPORT,
|
ENABLE_ADMIN_EXPORT,
|
||||||
WEBUI_BUILD_HASH,
|
WEBUI_BUILD_HASH,
|
||||||
|
TASK_MODEL,
|
||||||
|
TASK_MODEL_EXTERNAL,
|
||||||
TITLE_GENERATION_PROMPT_TEMPLATE,
|
TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||||
|
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||||
AppConfig,
|
AppConfig,
|
||||||
)
|
)
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
@ -132,9 +135,15 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
|
|||||||
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||||
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||||
|
|
||||||
|
|
||||||
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
||||||
|
|
||||||
|
|
||||||
|
app.state.config.TASK_MODEL = TASK_MODEL
|
||||||
|
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
|
||||||
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
|
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
|
||||||
|
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
|
||||||
|
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||||
|
)
|
||||||
|
|
||||||
app.state.MODELS = {}
|
app.state.MODELS = {}
|
||||||
|
|
||||||
@ -494,9 +503,46 @@ async def get_models(user=Depends(get_verified_user)):
|
|||||||
return {"data": models}
|
return {"data": models}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/task/config")
|
||||||
|
async def get_task_config(user=Depends(get_verified_user)):
|
||||||
|
return {
|
||||||
|
"TASK_MODEL": app.state.config.TASK_MODEL,
|
||||||
|
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
|
||||||
|
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||||
|
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TaskConfigForm(BaseModel):
|
||||||
|
TASK_MODEL: Optional[str]
|
||||||
|
TASK_MODEL_EXTERNAL: Optional[str]
|
||||||
|
TITLE_GENERATION_PROMPT_TEMPLATE: str
|
||||||
|
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/task/config/update")
|
||||||
|
async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)):
|
||||||
|
app.state.config.TASK_MODEL = form_data.TASK_MODEL
|
||||||
|
app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
|
||||||
|
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
|
||||||
|
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
|
||||||
|
)
|
||||||
|
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
|
||||||
|
form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"TASK_MODEL": app.state.config.TASK_MODEL,
|
||||||
|
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
|
||||||
|
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||||
|
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/task/title/completions")
|
@app.post("/api/task/title/completions")
|
||||||
async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
||||||
print("generate_title")
|
print("generate_title")
|
||||||
|
|
||||||
model_id = form_data["model"]
|
model_id = form_data["model"]
|
||||||
if model_id not in app.state.MODELS:
|
if model_id not in app.state.MODELS:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@ -504,6 +550,20 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
|||||||
detail="Model not found",
|
detail="Model not found",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check if the user has a custom task model
|
||||||
|
# If the user has a custom task model, use that model
|
||||||
|
if app.state.MODELS[model_id]["owned_by"] == "ollama":
|
||||||
|
if app.state.config.TASK_MODEL:
|
||||||
|
task_model_id = app.state.config.TASK_MODEL
|
||||||
|
if task_model_id in app.state.MODELS:
|
||||||
|
model_id = task_model_id
|
||||||
|
else:
|
||||||
|
if app.state.config.TASK_MODEL_EXTERNAL:
|
||||||
|
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
||||||
|
if task_model_id in app.state.MODELS:
|
||||||
|
model_id = task_model_id
|
||||||
|
|
||||||
|
print(model_id)
|
||||||
model = app.state.MODELS[model_id]
|
model = app.state.MODELS[model_id]
|
||||||
|
|
||||||
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
|
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
|
||||||
@ -532,6 +592,57 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
|||||||
return await generate_openai_chat_completion(payload, user=user)
|
return await generate_openai_chat_completion(payload, user=user)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/task/query/completions")
|
||||||
|
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
|
||||||
|
print("generate_search_query")
|
||||||
|
|
||||||
|
model_id = form_data["model"]
|
||||||
|
if model_id not in app.state.MODELS:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Model not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if the user has a custom task model
|
||||||
|
# If the user has a custom task model, use that model
|
||||||
|
if app.state.MODELS[model_id]["owned_by"] == "ollama":
|
||||||
|
if app.state.config.TASK_MODEL:
|
||||||
|
task_model_id = app.state.config.TASK_MODEL
|
||||||
|
if task_model_id in app.state.MODELS:
|
||||||
|
model_id = task_model_id
|
||||||
|
else:
|
||||||
|
if app.state.config.TASK_MODEL_EXTERNAL:
|
||||||
|
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
||||||
|
if task_model_id in app.state.MODELS:
|
||||||
|
model_id = task_model_id
|
||||||
|
|
||||||
|
print(model_id)
|
||||||
|
model = app.state.MODELS[model_id]
|
||||||
|
|
||||||
|
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||||
|
|
||||||
|
content = search_query_generation_template(
|
||||||
|
template, form_data["prompt"], user.model_dump()
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": model_id,
|
||||||
|
"messages": [{"role": "user", "content": content}],
|
||||||
|
"stream": False,
|
||||||
|
"max_tokens": 30,
|
||||||
|
}
|
||||||
|
|
||||||
|
print(payload)
|
||||||
|
payload = filter_pipeline(payload, user)
|
||||||
|
|
||||||
|
if model["owned_by"] == "ollama":
|
||||||
|
return await generate_ollama_chat_completion(
|
||||||
|
OpenAIChatCompletionForm(**payload), user=user
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return await generate_openai_chat_completion(payload, user=user)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/chat/completions")
|
@app.post("/api/chat/completions")
|
||||||
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
|
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
|
||||||
model_id = form_data["model"]
|
model_id = form_data["model"]
|
||||||
@ -542,7 +653,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
|
|||||||
)
|
)
|
||||||
|
|
||||||
model = app.state.MODELS[model_id]
|
model = app.state.MODELS[model_id]
|
||||||
|
|
||||||
print(model)
|
print(model)
|
||||||
|
|
||||||
if model["owned_by"] == "ollama":
|
if model["owned_by"] == "ollama":
|
||||||
|
@ -68,3 +68,45 @@ def title_generation_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return template
|
return template
|
||||||
|
|
||||||
|
|
||||||
|
def search_query_generation_template(
|
||||||
|
template: str, prompt: str, user: Optional[dict] = None
|
||||||
|
) -> str:
|
||||||
|
|
||||||
|
def replacement_function(match):
|
||||||
|
full_match = match.group(0)
|
||||||
|
start_length = match.group(1)
|
||||||
|
end_length = match.group(2)
|
||||||
|
middle_length = match.group(3)
|
||||||
|
|
||||||
|
if full_match == "{{prompt}}":
|
||||||
|
return prompt
|
||||||
|
elif start_length is not None:
|
||||||
|
return prompt[: int(start_length)]
|
||||||
|
elif end_length is not None:
|
||||||
|
return prompt[-int(end_length) :]
|
||||||
|
elif middle_length is not None:
|
||||||
|
middle_length = int(middle_length)
|
||||||
|
if len(prompt) <= middle_length:
|
||||||
|
return prompt
|
||||||
|
start = prompt[: math.ceil(middle_length / 2)]
|
||||||
|
end = prompt[-math.floor(middle_length / 2) :]
|
||||||
|
return f"{start}...{end}"
|
||||||
|
return ""
|
||||||
|
|
||||||
|
template = re.sub(
|
||||||
|
r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
|
||||||
|
replacement_function,
|
||||||
|
template,
|
||||||
|
)
|
||||||
|
|
||||||
|
template = prompt_template(
|
||||||
|
template,
|
||||||
|
**(
|
||||||
|
{"user_name": user.get("name"), "current_location": user.get("location")}
|
||||||
|
if user
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return template
|
||||||
|
@ -144,6 +144,46 @@ export const generateTitle = async (
|
|||||||
return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 'New Chat';
|
return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 'New Chat';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const generateSearchQuery = async (
|
||||||
|
token: string = '',
|
||||||
|
model: string,
|
||||||
|
messages: object[],
|
||||||
|
prompt: string
|
||||||
|
) => {
|
||||||
|
let error = null;
|
||||||
|
|
||||||
|
const res = await fetch(`${WEBUI_BASE_URL}/api/task/query/completions`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
Accept: 'application/json',
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
Authorization: `Bearer ${token}`
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
model: model,
|
||||||
|
messages: messages,
|
||||||
|
prompt: prompt
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.then(async (res) => {
|
||||||
|
if (!res.ok) throw await res.json();
|
||||||
|
return res.json();
|
||||||
|
})
|
||||||
|
.catch((err) => {
|
||||||
|
console.log(err);
|
||||||
|
if ('detail' in err) {
|
||||||
|
error = err.detail;
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
|
return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? prompt;
|
||||||
|
};
|
||||||
|
|
||||||
export const getPipelinesList = async (token: string = '') => {
|
export const getPipelinesList = async (token: string = '') => {
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
|
@ -44,12 +44,12 @@
|
|||||||
getTagsById,
|
getTagsById,
|
||||||
updateChatById
|
updateChatById
|
||||||
} from '$lib/apis/chats';
|
} from '$lib/apis/chats';
|
||||||
import { generateOpenAIChatCompletion, generateSearchQuery } from '$lib/apis/openai';
|
import { generateOpenAIChatCompletion } from '$lib/apis/openai';
|
||||||
import { runWebSearch } from '$lib/apis/rag';
|
import { runWebSearch } from '$lib/apis/rag';
|
||||||
import { createOpenAITextStream } from '$lib/apis/streaming';
|
import { createOpenAITextStream } from '$lib/apis/streaming';
|
||||||
import { queryMemory } from '$lib/apis/memories';
|
import { queryMemory } from '$lib/apis/memories';
|
||||||
import { getUserSettings } from '$lib/apis/users';
|
import { getUserSettings } from '$lib/apis/users';
|
||||||
import { chatCompleted, generateTitle } from '$lib/apis';
|
import { chatCompleted, generateTitle, generateSearchQuery } from '$lib/apis';
|
||||||
|
|
||||||
import Banner from '../common/Banner.svelte';
|
import Banner from '../common/Banner.svelte';
|
||||||
import MessageInput from '$lib/components/chat/MessageInput.svelte';
|
import MessageInput from '$lib/components/chat/MessageInput.svelte';
|
||||||
@ -508,7 +508,7 @@
|
|||||||
const prompt = history.messages[parentId].content;
|
const prompt = history.messages[parentId].content;
|
||||||
let searchQuery = prompt;
|
let searchQuery = prompt;
|
||||||
if (prompt.length > 100) {
|
if (prompt.length > 100) {
|
||||||
searchQuery = await generateChatSearchQuery(model, prompt);
|
searchQuery = await generateSearchQuery(localStorage.token, model, messages, prompt);
|
||||||
if (!searchQuery) {
|
if (!searchQuery) {
|
||||||
toast.warning($i18n.t('No search query generated'));
|
toast.warning($i18n.t('No search query generated'));
|
||||||
responseMessage.status = {
|
responseMessage.status = {
|
||||||
@ -1129,29 +1129,6 @@
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const generateChatSearchQuery = async (modelId: string, prompt: string) => {
|
|
||||||
const model = $models.find((model) => model.id === modelId);
|
|
||||||
const taskModelId =
|
|
||||||
model?.owned_by === 'openai' ?? false
|
|
||||||
? $settings?.title?.modelExternal ?? modelId
|
|
||||||
: $settings?.title?.model ?? modelId;
|
|
||||||
const taskModel = $models.find((model) => model.id === taskModelId);
|
|
||||||
|
|
||||||
const previousMessages = messages
|
|
||||||
.filter((message) => message.role === 'user')
|
|
||||||
.map((message) => message.content);
|
|
||||||
|
|
||||||
return await generateSearchQuery(
|
|
||||||
localStorage.token,
|
|
||||||
taskModelId,
|
|
||||||
previousMessages,
|
|
||||||
prompt,
|
|
||||||
taskModel?.owned_by === 'openai' ?? false
|
|
||||||
? `${OPENAI_API_BASE_URL}`
|
|
||||||
: `${OLLAMA_API_BASE_URL}/v1`
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const setChatTitle = async (_chatId, _title) => {
|
const setChatTitle = async (_chatId, _title) => {
|
||||||
if (_chatId === $chatId) {
|
if (_chatId === $chatId) {
|
||||||
title = _title;
|
title = _title;
|
||||||
|
Loading…
Reference in New Issue
Block a user