diff --git a/backend/config.py b/backend/config.py
index cc8fd0cf8..2e718ce8c 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -618,6 +618,18 @@ ADMIN_EMAIL = PersistentConfig(
 )
 
+TASK_MODEL = PersistentConfig(
+    "TASK_MODEL",
+    "task.model.default",
+    os.environ.get("TASK_MODEL", ""),
+)
+
+TASK_MODEL_EXTERNAL = PersistentConfig(
+    "TASK_MODEL_EXTERNAL",
+    "task.model.external",
+    os.environ.get("TASK_MODEL_EXTERNAL", ""),
+)
+
 TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
     "TITLE_GENERATION_PROMPT_TEMPLATE",
     "task.title.prompt_template",
@@ -639,6 +651,19 @@ Artificial Intelligence in Healthcare
 )
 
+SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
+    "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
+    "task.search.prompt_template",
+    os.environ.get(
+        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
+        """You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is {{CURRENT_DATE}}.
+
+Question:
+{{prompt:end:4000}}""",
+    ),
+)
+
+
 ####################################
 # WEBUI_SECRET_KEY
 ####################################
diff --git a/backend/main.py b/backend/main.py
index 39ccf1832..abd899614 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -53,7 +53,7 @@ from utils.utils import (
     get_current_user,
     get_http_authorization_cred,
 )
-from utils.task import title_generation_template
+from utils.task import title_generation_template, search_query_generation_template
 
 from apps.rag.utils import rag_messages
 
@@ -77,7 +77,10 @@ from config import (
     WEBHOOK_URL,
     ENABLE_ADMIN_EXPORT,
     WEBUI_BUILD_HASH,
+    TASK_MODEL,
+    TASK_MODEL_EXTERNAL,
     TITLE_GENERATION_PROMPT_TEMPLATE,
+    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
     AppConfig,
 )
 from constants import ERROR_MESSAGES
@@ -132,9 +135,15 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
 app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 
-
 app.state.config.WEBHOOK_URL = WEBHOOK_URL
 
+
+app.state.config.TASK_MODEL = TASK_MODEL
+app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
 app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
+app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
+    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
+)
 
 app.state.MODELS = {}
@@ -494,9 +503,46 @@ async def get_models(user=Depends(get_verified_user)):
     return {"data": models}
 
 
+@app.get("/api/task/config")
+async def get_task_config(user=Depends(get_verified_user)):
+    return {
+        "TASK_MODEL": app.state.config.TASK_MODEL,
+        "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
+        "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
+        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
+    }
+
+
+class TaskConfigForm(BaseModel):
+    TASK_MODEL: Optional[str]
+    TASK_MODEL_EXTERNAL: Optional[str]
+    TITLE_GENERATION_PROMPT_TEMPLATE: str
+    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
+
+
+@app.post("/api/task/config/update")
+async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)):
+    app.state.config.TASK_MODEL = form_data.TASK_MODEL
+    app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
+    app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
+        form_data.TITLE_GENERATION_PROMPT_TEMPLATE
+    )
+    app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
+        form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
+    )
+
+    return {
+        "TASK_MODEL": app.state.config.TASK_MODEL,
+        "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
+        "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
+        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
+    }
+
+
 @app.post("/api/task/title/completions")
 async def generate_title(form_data: dict, user=Depends(get_verified_user)):
     print("generate_title")
+
     model_id = form_data["model"]
     if model_id not in app.state.MODELS:
         raise HTTPException(
@@ -504,6 +550,20 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
             detail="Model not found",
         )
 
+    # Check if the user has a custom task model
+    # If the user has a custom task model, use that model
+    if app.state.MODELS[model_id]["owned_by"] == "ollama":
+        if app.state.config.TASK_MODEL:
+            task_model_id = app.state.config.TASK_MODEL
+            if task_model_id in app.state.MODELS:
+                model_id = task_model_id
+    else:
+        if app.state.config.TASK_MODEL_EXTERNAL:
+            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
+            if task_model_id in app.state.MODELS:
+                model_id = task_model_id
+
+    print(model_id)
     model = app.state.MODELS[model_id]
 
     template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
@@ -532,6 +592,57 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
         return await generate_openai_chat_completion(payload, user=user)
 
 
+@app.post("/api/task/query/completions")
+async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
+    print("generate_search_query")
+
+    model_id = form_data["model"]
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+
+    # Check if the user has a custom task model
+    # If the user has a custom task model, use that model
+    if app.state.MODELS[model_id]["owned_by"] == "ollama":
+        if app.state.config.TASK_MODEL:
+            task_model_id = app.state.config.TASK_MODEL
+            if task_model_id in app.state.MODELS:
+                model_id = task_model_id
+    else:
+        if app.state.config.TASK_MODEL_EXTERNAL:
+            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
+            if task_model_id in app.state.MODELS:
+                model_id = task_model_id
+
+    print(model_id)
+    model = app.state.MODELS[model_id]
+
+    template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
+
+    content = search_query_generation_template(
+        template, form_data["prompt"], user.model_dump()
+    )
+
+    payload = {
+        "model": model_id,
+        "messages": [{"role": "user", "content": content}],
+        "stream": False,
+        "max_tokens": 30,
+    }
+
+    print(payload)
+    payload = filter_pipeline(payload, user)
+
+    if model["owned_by"] == "ollama":
+        return await generate_ollama_chat_completion(
+            OpenAIChatCompletionForm(**payload), user=user
+        )
+    else:
+        return await generate_openai_chat_completion(payload, user=user)
+
+
 @app.post("/api/chat/completions")
 async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
     model_id = form_data["model"]
@@ -542,7 +653,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
         )
 
     model = app.state.MODELS[model_id]
-    print(model)
 
     if model["owned_by"] == "ollama":
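The owned_by branch above appears verbatim in both generate_title and generate_search_query. As a minimal sketch of the shared rule (get_task_model_id is a hypothetical helper name, not part of this patch), the branch could be factored out so the two endpoints cannot drift apart:

def get_task_model_id(default_model_id: str, models: dict, config) -> str:
    # Local (Ollama) models fall back to TASK_MODEL, external ones to
    # TASK_MODEL_EXTERNAL; either fallback applies only when the configured
    # id is actually present in the loaded model table.
    if models[default_model_id]["owned_by"] == "ollama":
        candidate = config.TASK_MODEL
    else:
        candidate = config.TASK_MODEL_EXTERNAL
    if candidate and candidate in models:
        return candidate
    return default_model_id

Called as model_id = get_task_model_id(form_data["model"], app.state.MODELS, app.state.config) at the top of each task endpoint, it would replace both copies of the branch.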
diff --git a/backend/utils/task.py b/backend/utils/task.py
index b2de4a617..2239de7df 100644
--- a/backend/utils/task.py
+++ b/backend/utils/task.py
@@ -68,3 +68,45 @@ def title_generation_template(
     )
 
     return template
+
+
+def search_query_generation_template(
+    template: str, prompt: str, user: Optional[dict] = None
+) -> str:
+
+    def replacement_function(match):
+        full_match = match.group(0)
+        start_length = match.group(1)
+        end_length = match.group(2)
+        middle_length = match.group(3)
+
+        if full_match == "{{prompt}}":
+            return prompt
+        elif start_length is not None:
+            return prompt[: int(start_length)]
+        elif end_length is not None:
+            return prompt[-int(end_length) :]
+        elif middle_length is not None:
+            middle_length = int(middle_length)
+            if len(prompt) <= middle_length:
+                return prompt
+            start = prompt[: math.ceil(middle_length / 2)]
+            end = prompt[-math.floor(middle_length / 2) :]
+            return f"{start}...{end}"
+        return ""
+
+    template = re.sub(
+        r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
+        replacement_function,
+        template,
+    )
+
+    template = prompt_template(
+        template,
+        **(
+            {"user_name": user.get("name"), "current_location": user.get("location")}
+            if user
+            else {}
+        ),
+    )
+    return template
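The {{prompt:...}} placeholders accepted by search_query_generation_template truncate the user prompt before it is substituted into the template; the default template in config.py uses {{prompt:end:4000}}. A self-contained sketch of the same regex and replacement logic, reimplemented here so it runs without the rest of utils/task.py, with the expected expansions in the comments:

import math
import re

prompt = "a" * 50 + "b" * 50  # a 100-character stand-in prompt

def expand(template: str) -> str:
    def repl(match):
        if match.group(0) == "{{prompt}}":
            return prompt
        if match.group(1) is not None:  # {{prompt:start:N}}: first N chars
            return prompt[: int(match.group(1))]
        if match.group(2) is not None:  # {{prompt:end:N}}: last N chars
            return prompt[-int(match.group(2)) :]
        n = int(match.group(3))  # {{prompt:middletruncate:N}}: keep both ends
        if len(prompt) <= n:
            return prompt
        return prompt[: math.ceil(n / 2)] + "..." + prompt[-math.floor(n / 2) :]

    return re.sub(
        r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
        repl,
        template,
    )

print(expand("{{prompt:start:4}}"))  # aaaa
print(expand("{{prompt:end:4}}"))  # bbbb
print(expand("{{prompt:middletruncate:8}}"))  # aaaa...bbbb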
diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts
index 2ddbb4a1e..fc1d850b3 100644
--- a/src/lib/apis/index.ts
+++ b/src/lib/apis/index.ts
@@ -144,6 +144,46 @@ export const generateTitle = async (
     return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 'New Chat';
 };
 
+export const generateSearchQuery = async (
+    token: string = '',
+    model: string,
+    messages: object[],
+    prompt: string
+) => {
+    let error = null;
+
+    const res = await fetch(`${WEBUI_BASE_URL}/api/task/query/completions`, {
+        method: 'POST',
+        headers: {
+            Accept: 'application/json',
+            'Content-Type': 'application/json',
+            Authorization: `Bearer ${token}`
+        },
+        body: JSON.stringify({
+            model: model,
+            messages: messages,
+            prompt: prompt
+        })
+    })
+        .then(async (res) => {
+            if (!res.ok) throw await res.json();
+            return res.json();
+        })
+        .catch((err) => {
+            console.log(err);
+            if ('detail' in err) {
+                error = err.detail;
+            }
+            return null;
+        });
+
+    if (error) {
+        throw error;
+    }
+
+    return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? prompt;
+};
+
 export const getPipelinesList = async (token: string = '') => {
     let error = null;
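generateSearchQuery above wraps the new POST /api/task/query/completions route. For reference, the same request made directly from Python, as a minimal sketch; the base URL, bearer token, and model id are placeholders for a running instance, and although messages is part of the request body, the handler in main.py only reads model and prompt:

import requests

resp = requests.post(
    "http://localhost:8080/api/task/query/completions",  # assumed local instance
    headers={"Authorization": "Bearer <token>"},  # placeholder token
    json={
        "model": "llama3:latest",  # placeholder id; must be known to the server
        "messages": [],
        "prompt": "What were the major AI announcements this week?",
    },
)
# The endpoint returns an OpenAI-style chat completion; the generated query
# is the first choice's message content, exactly what the frontend extracts.
print(resp.json()["choices"][0]["message"]["content"])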
diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte
index fb846fc24..9a7839460 100644
--- a/src/lib/components/chat/Chat.svelte
+++ b/src/lib/components/chat/Chat.svelte
@@ -44,12 +44,12 @@
         getTagsById,
         updateChatById
     } from '$lib/apis/chats';
-    import { generateOpenAIChatCompletion, generateSearchQuery } from '$lib/apis/openai';
+    import { generateOpenAIChatCompletion } from '$lib/apis/openai';
     import { runWebSearch } from '$lib/apis/rag';
     import { createOpenAITextStream } from '$lib/apis/streaming';
     import { queryMemory } from '$lib/apis/memories';
     import { getUserSettings } from '$lib/apis/users';
-    import { chatCompleted, generateTitle } from '$lib/apis';
+    import { chatCompleted, generateTitle, generateSearchQuery } from '$lib/apis';
 
     import Banner from '../common/Banner.svelte';
     import MessageInput from '$lib/components/chat/MessageInput.svelte';
@@ -508,7 +508,7 @@
             const prompt = history.messages[parentId].content;
             let searchQuery = prompt;
             if (prompt.length > 100) {
-                searchQuery = await generateChatSearchQuery(model, prompt);
+                searchQuery = await generateSearchQuery(localStorage.token, model, messages, prompt);
                 if (!searchQuery) {
                     toast.warning($i18n.t('No search query generated'));
                     responseMessage.status = {
@@ -1129,29 +1129,6 @@
             }
         };
 
-        const generateChatSearchQuery = async (modelId: string, prompt: string) => {
-            const model = $models.find((model) => model.id === modelId);
-            const taskModelId =
-                model?.owned_by === 'openai' ?? false
-                    ? $settings?.title?.modelExternal ?? modelId
-                    : $settings?.title?.model ?? modelId;
-            const taskModel = $models.find((model) => model.id === taskModelId);
-
-            const previousMessages = messages
-                .filter((message) => message.role === 'user')
-                .map((message) => message.content);
-
-            return await generateSearchQuery(
-                localStorage.token,
-                taskModelId,
-                previousMessages,
-                prompt,
-                taskModel?.owned_by === 'openai' ?? false
-                    ? `${OPENAI_API_BASE_URL}`
-                    : `${OLLAMA_API_BASE_URL}/v1`
-            );
-        };
-
         const setChatTitle = async (_chatId, _title) => {
             if (_chatId === $chatId) {
                 title = _title;
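Finally, a sketch of the admin round-trip against the two new task-config endpoints; the base URL, token, and model id are placeholders, and /api/task/config/update requires an admin session. Because the GET response carries all four fields of TaskConfigForm, it can be mutated and posted back as-is:

import requests

base = "http://localhost:8080"  # assumed local instance
headers = {"Authorization": "Bearer <admin-token>"}  # placeholder token

# Read the current task config.
config = requests.get(f"{base}/api/task/config", headers=headers).json()

# Point local (Ollama) task generation at a small dedicated model.
config["TASK_MODEL"] = "llama3:8b"  # placeholder model id

# Write it back; the endpoint echoes the persisted values.
updated = requests.post(
    f"{base}/api/task/config/update", headers=headers, json=config
).json()
print(updated["TASK_MODEL"])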