mirror of
				https://github.com/open-webui/open-webui
				synced 2025-06-26 18:26:48 +00:00 
			
		
		
		
	refac: search query task
This commit is contained in:
		
							parent
							
								
									aa1bb4fb6d
								
							
						
					
					
						commit
						591cd993c2
					
				@ -618,6 +618,18 @@ ADMIN_EMAIL = PersistentConfig(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
TASK_MODEL = PersistentConfig(
 | 
			
		||||
    "TASK_MODEL",
 | 
			
		||||
    "task.model.default",
 | 
			
		||||
    os.environ.get("TASK_MODEL", ""),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
TASK_MODEL_EXTERNAL = PersistentConfig(
 | 
			
		||||
    "TASK_MODEL_EXTERNAL",
 | 
			
		||||
    "task.model.external",
 | 
			
		||||
    os.environ.get("TASK_MODEL_EXTERNAL", ""),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
 | 
			
		||||
    "TITLE_GENERATION_PROMPT_TEMPLATE",
 | 
			
		||||
    "task.title.prompt_template",
 | 
			
		||||
@ -639,6 +651,19 @@ Artificial Intelligence in Healthcare
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
 | 
			
		||||
    "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
 | 
			
		||||
    "task.search.prompt_template",
 | 
			
		||||
    os.environ.get(
 | 
			
		||||
        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
 | 
			
		||||
        """You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is {{CURRENT_DATE}}.
 | 
			
		||||
        
 | 
			
		||||
Question:
 | 
			
		||||
{{prompt:end:4000}}""",
 | 
			
		||||
    ),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
####################################
 | 
			
		||||
# WEBUI_SECRET_KEY
 | 
			
		||||
####################################
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										116
									
								
								backend/main.py
									
									
									
									
									
								
							
							
						
						
									
										116
									
								
								backend/main.py
									
									
									
									
									
								
							@ -53,7 +53,7 @@ from utils.utils import (
 | 
			
		||||
    get_current_user,
 | 
			
		||||
    get_http_authorization_cred,
 | 
			
		||||
)
 | 
			
		||||
from utils.task import title_generation_template
 | 
			
		||||
from utils.task import title_generation_template, search_query_generation_template
 | 
			
		||||
 | 
			
		||||
from apps.rag.utils import rag_messages
 | 
			
		||||
 | 
			
		||||
@ -77,7 +77,10 @@ from config import (
 | 
			
		||||
    WEBHOOK_URL,
 | 
			
		||||
    ENABLE_ADMIN_EXPORT,
 | 
			
		||||
    WEBUI_BUILD_HASH,
 | 
			
		||||
    TASK_MODEL,
 | 
			
		||||
    TASK_MODEL_EXTERNAL,
 | 
			
		||||
    TITLE_GENERATION_PROMPT_TEMPLATE,
 | 
			
		||||
    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
 | 
			
		||||
    AppConfig,
 | 
			
		||||
)
 | 
			
		||||
from constants import ERROR_MESSAGES
 | 
			
		||||
@ -132,9 +135,15 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
 | 
			
		||||
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 | 
			
		||||
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
app.state.config.WEBHOOK_URL = WEBHOOK_URL
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
app.state.config.TASK_MODEL = TASK_MODEL
 | 
			
		||||
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
 | 
			
		||||
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
 | 
			
		||||
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
 | 
			
		||||
    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
app.state.MODELS = {}
 | 
			
		||||
 | 
			
		||||
@ -494,9 +503,46 @@ async def get_models(user=Depends(get_verified_user)):
 | 
			
		||||
    return {"data": models}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.get("/api/task/config")
 | 
			
		||||
async def get_task_config(user=Depends(get_verified_user)):
 | 
			
		||||
    return {
 | 
			
		||||
        "TASK_MODEL": app.state.config.TASK_MODEL,
 | 
			
		||||
        "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
 | 
			
		||||
        "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
 | 
			
		||||
        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TaskConfigForm(BaseModel):
 | 
			
		||||
    TASK_MODEL: Optional[str]
 | 
			
		||||
    TASK_MODEL_EXTERNAL: Optional[str]
 | 
			
		||||
    TITLE_GENERATION_PROMPT_TEMPLATE: str
 | 
			
		||||
    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/api/task/config/update")
 | 
			
		||||
async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)):
 | 
			
		||||
    app.state.config.TASK_MODEL = form_data.TASK_MODEL
 | 
			
		||||
    app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
 | 
			
		||||
    app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
 | 
			
		||||
        form_data.TITLE_GENERATION_PROMPT_TEMPLATE
 | 
			
		||||
    )
 | 
			
		||||
    app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
 | 
			
		||||
        form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    return {
 | 
			
		||||
        "TASK_MODEL": app.state.config.TASK_MODEL,
 | 
			
		||||
        "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
 | 
			
		||||
        "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
 | 
			
		||||
        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/api/task/title/completions")
 | 
			
		||||
async def generate_title(form_data: dict, user=Depends(get_verified_user)):
 | 
			
		||||
    print("generate_title")
 | 
			
		||||
 | 
			
		||||
    model_id = form_data["model"]
 | 
			
		||||
    if model_id not in app.state.MODELS:
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
@ -504,6 +550,20 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
 | 
			
		||||
            detail="Model not found",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # Check if the user has a custom task model
 | 
			
		||||
    # If the user has a custom task model, use that model
 | 
			
		||||
    if app.state.MODELS[model_id]["owned_by"] == "ollama":
 | 
			
		||||
        if app.state.config.TASK_MODEL:
 | 
			
		||||
            task_model_id = app.state.config.TASK_MODEL
 | 
			
		||||
            if task_model_id in app.state.MODELS:
 | 
			
		||||
                model_id = task_model_id
 | 
			
		||||
    else:
 | 
			
		||||
        if app.state.config.TASK_MODEL_EXTERNAL:
 | 
			
		||||
            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
 | 
			
		||||
            if task_model_id in app.state.MODELS:
 | 
			
		||||
                model_id = task_model_id
 | 
			
		||||
 | 
			
		||||
    print(model_id)
 | 
			
		||||
    model = app.state.MODELS[model_id]
 | 
			
		||||
 | 
			
		||||
    template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
 | 
			
		||||
@ -532,6 +592,57 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
 | 
			
		||||
        return await generate_openai_chat_completion(payload, user=user)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/api/task/query/completions")
 | 
			
		||||
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
 | 
			
		||||
    print("generate_search_query")
 | 
			
		||||
 | 
			
		||||
    model_id = form_data["model"]
 | 
			
		||||
    if model_id not in app.state.MODELS:
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
            status_code=status.HTTP_404_NOT_FOUND,
 | 
			
		||||
            detail="Model not found",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # Check if the user has a custom task model
 | 
			
		||||
    # If the user has a custom task model, use that model
 | 
			
		||||
    if app.state.MODELS[model_id]["owned_by"] == "ollama":
 | 
			
		||||
        if app.state.config.TASK_MODEL:
 | 
			
		||||
            task_model_id = app.state.config.TASK_MODEL
 | 
			
		||||
            if task_model_id in app.state.MODELS:
 | 
			
		||||
                model_id = task_model_id
 | 
			
		||||
    else:
 | 
			
		||||
        if app.state.config.TASK_MODEL_EXTERNAL:
 | 
			
		||||
            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
 | 
			
		||||
            if task_model_id in app.state.MODELS:
 | 
			
		||||
                model_id = task_model_id
 | 
			
		||||
 | 
			
		||||
    print(model_id)
 | 
			
		||||
    model = app.state.MODELS[model_id]
 | 
			
		||||
 | 
			
		||||
    template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
 | 
			
		||||
 | 
			
		||||
    content = search_query_generation_template(
 | 
			
		||||
        template, form_data["prompt"], user.model_dump()
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    payload = {
 | 
			
		||||
        "model": model_id,
 | 
			
		||||
        "messages": [{"role": "user", "content": content}],
 | 
			
		||||
        "stream": False,
 | 
			
		||||
        "max_tokens": 30,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    print(payload)
 | 
			
		||||
    payload = filter_pipeline(payload, user)
 | 
			
		||||
 | 
			
		||||
    if model["owned_by"] == "ollama":
 | 
			
		||||
        return await generate_ollama_chat_completion(
 | 
			
		||||
            OpenAIChatCompletionForm(**payload), user=user
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        return await generate_openai_chat_completion(payload, user=user)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/api/chat/completions")
 | 
			
		||||
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
 | 
			
		||||
    model_id = form_data["model"]
 | 
			
		||||
@ -542,7 +653,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    model = app.state.MODELS[model_id]
 | 
			
		||||
 | 
			
		||||
    print(model)
 | 
			
		||||
 | 
			
		||||
    if model["owned_by"] == "ollama":
 | 
			
		||||
 | 
			
		||||
@ -68,3 +68,45 @@ def title_generation_template(
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    return template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def search_query_generation_template(
 | 
			
		||||
    template: str, prompt: str, user: Optional[dict] = None
 | 
			
		||||
) -> str:
 | 
			
		||||
 | 
			
		||||
    def replacement_function(match):
 | 
			
		||||
        full_match = match.group(0)
 | 
			
		||||
        start_length = match.group(1)
 | 
			
		||||
        end_length = match.group(2)
 | 
			
		||||
        middle_length = match.group(3)
 | 
			
		||||
 | 
			
		||||
        if full_match == "{{prompt}}":
 | 
			
		||||
            return prompt
 | 
			
		||||
        elif start_length is not None:
 | 
			
		||||
            return prompt[: int(start_length)]
 | 
			
		||||
        elif end_length is not None:
 | 
			
		||||
            return prompt[-int(end_length) :]
 | 
			
		||||
        elif middle_length is not None:
 | 
			
		||||
            middle_length = int(middle_length)
 | 
			
		||||
            if len(prompt) <= middle_length:
 | 
			
		||||
                return prompt
 | 
			
		||||
            start = prompt[: math.ceil(middle_length / 2)]
 | 
			
		||||
            end = prompt[-math.floor(middle_length / 2) :]
 | 
			
		||||
            return f"{start}...{end}"
 | 
			
		||||
        return ""
 | 
			
		||||
 | 
			
		||||
    template = re.sub(
 | 
			
		||||
        r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
 | 
			
		||||
        replacement_function,
 | 
			
		||||
        template,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    template = prompt_template(
 | 
			
		||||
        template,
 | 
			
		||||
        **(
 | 
			
		||||
            {"user_name": user.get("name"), "current_location": user.get("location")}
 | 
			
		||||
            if user
 | 
			
		||||
            else {}
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    return template
 | 
			
		||||
 | 
			
		||||
@ -144,6 +144,46 @@ export const generateTitle = async (
 | 
			
		||||
	return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 'New Chat';
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
export const generateSearchQuery = async (
 | 
			
		||||
	token: string = '',
 | 
			
		||||
	model: string,
 | 
			
		||||
	messages: object[],
 | 
			
		||||
	prompt: string
 | 
			
		||||
) => {
 | 
			
		||||
	let error = null;
 | 
			
		||||
 | 
			
		||||
	const res = await fetch(`${WEBUI_BASE_URL}/api/task/query/completions`, {
 | 
			
		||||
		method: 'POST',
 | 
			
		||||
		headers: {
 | 
			
		||||
			Accept: 'application/json',
 | 
			
		||||
			'Content-Type': 'application/json',
 | 
			
		||||
			Authorization: `Bearer ${token}`
 | 
			
		||||
		},
 | 
			
		||||
		body: JSON.stringify({
 | 
			
		||||
			model: model,
 | 
			
		||||
			messages: messages,
 | 
			
		||||
			prompt: prompt
 | 
			
		||||
		})
 | 
			
		||||
	})
 | 
			
		||||
		.then(async (res) => {
 | 
			
		||||
			if (!res.ok) throw await res.json();
 | 
			
		||||
			return res.json();
 | 
			
		||||
		})
 | 
			
		||||
		.catch((err) => {
 | 
			
		||||
			console.log(err);
 | 
			
		||||
			if ('detail' in err) {
 | 
			
		||||
				error = err.detail;
 | 
			
		||||
			}
 | 
			
		||||
			return null;
 | 
			
		||||
		});
 | 
			
		||||
 | 
			
		||||
	if (error) {
 | 
			
		||||
		throw error;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? prompt;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
export const getPipelinesList = async (token: string = '') => {
 | 
			
		||||
	let error = null;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -44,12 +44,12 @@
 | 
			
		||||
		getTagsById,
 | 
			
		||||
		updateChatById
 | 
			
		||||
	} from '$lib/apis/chats';
 | 
			
		||||
	import { generateOpenAIChatCompletion, generateSearchQuery } from '$lib/apis/openai';
 | 
			
		||||
	import { generateOpenAIChatCompletion } from '$lib/apis/openai';
 | 
			
		||||
	import { runWebSearch } from '$lib/apis/rag';
 | 
			
		||||
	import { createOpenAITextStream } from '$lib/apis/streaming';
 | 
			
		||||
	import { queryMemory } from '$lib/apis/memories';
 | 
			
		||||
	import { getUserSettings } from '$lib/apis/users';
 | 
			
		||||
	import { chatCompleted, generateTitle } from '$lib/apis';
 | 
			
		||||
	import { chatCompleted, generateTitle, generateSearchQuery } from '$lib/apis';
 | 
			
		||||
 | 
			
		||||
	import Banner from '../common/Banner.svelte';
 | 
			
		||||
	import MessageInput from '$lib/components/chat/MessageInput.svelte';
 | 
			
		||||
@ -508,7 +508,7 @@
 | 
			
		||||
		const prompt = history.messages[parentId].content;
 | 
			
		||||
		let searchQuery = prompt;
 | 
			
		||||
		if (prompt.length > 100) {
 | 
			
		||||
			searchQuery = await generateChatSearchQuery(model, prompt);
 | 
			
		||||
			searchQuery = await generateSearchQuery(localStorage.token, model, messages, prompt);
 | 
			
		||||
			if (!searchQuery) {
 | 
			
		||||
				toast.warning($i18n.t('No search query generated'));
 | 
			
		||||
				responseMessage.status = {
 | 
			
		||||
@ -1129,29 +1129,6 @@
 | 
			
		||||
		}
 | 
			
		||||
	};
 | 
			
		||||
 | 
			
		||||
	const generateChatSearchQuery = async (modelId: string, prompt: string) => {
 | 
			
		||||
		const model = $models.find((model) => model.id === modelId);
 | 
			
		||||
		const taskModelId =
 | 
			
		||||
			model?.owned_by === 'openai' ?? false
 | 
			
		||||
				? $settings?.title?.modelExternal ?? modelId
 | 
			
		||||
				: $settings?.title?.model ?? modelId;
 | 
			
		||||
		const taskModel = $models.find((model) => model.id === taskModelId);
 | 
			
		||||
 | 
			
		||||
		const previousMessages = messages
 | 
			
		||||
			.filter((message) => message.role === 'user')
 | 
			
		||||
			.map((message) => message.content);
 | 
			
		||||
 | 
			
		||||
		return await generateSearchQuery(
 | 
			
		||||
			localStorage.token,
 | 
			
		||||
			taskModelId,
 | 
			
		||||
			previousMessages,
 | 
			
		||||
			prompt,
 | 
			
		||||
			taskModel?.owned_by === 'openai' ?? false
 | 
			
		||||
				? `${OPENAI_API_BASE_URL}`
 | 
			
		||||
				: `${OLLAMA_API_BASE_URL}/v1`
 | 
			
		||||
		);
 | 
			
		||||
	};
 | 
			
		||||
 | 
			
		||||
	const setChatTitle = async (_chatId, _title) => {
 | 
			
		||||
		if (_chatId === $chatId) {
 | 
			
		||||
			title = _title;
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user