Merge branch 'open-webui:main' into feature-external-db-reconnect

This commit is contained in:
perfectra1n
2024-06-18 12:11:28 -07:00
committed by GitHub
92 changed files with 3224 additions and 1284 deletions

View File

@@ -37,6 +37,10 @@ from config import (
ENABLE_IMAGE_GENERATION,
AUTOMATIC1111_BASE_URL,
COMFYUI_BASE_URL,
COMFYUI_CFG_SCALE,
COMFYUI_SAMPLER,
COMFYUI_SCHEDULER,
COMFYUI_SD3,
IMAGES_OPENAI_API_BASE_URL,
IMAGES_OPENAI_API_KEY,
IMAGE_GENERATION_MODEL,
@@ -78,6 +82,10 @@ app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.config.IMAGE_SIZE = IMAGE_SIZE
app.state.config.IMAGE_STEPS = IMAGE_STEPS
app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
app.state.config.COMFYUI_SAMPLER = COMFYUI_SAMPLER
app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
app.state.config.COMFYUI_SD3 = COMFYUI_SD3
@app.get("/config")
@@ -457,6 +465,18 @@ def generate_image(
if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt
if app.state.config.COMFYUI_CFG_SCALE:
data["cfg_scale"] = app.state.config.COMFYUI_CFG_SCALE
if app.state.config.COMFYUI_SAMPLER is not None:
data["sampler"] = app.state.config.COMFYUI_SAMPLER
if app.state.config.COMFYUI_SCHEDULER is not None:
data["scheduler"] = app.state.config.COMFYUI_SCHEDULER
if app.state.config.COMFYUI_SD3 is not None:
data["sd3"] = app.state.config.COMFYUI_SD3
data = ImageGenerationPayload(**data)
res = comfyui_generate_image(

View File

@@ -190,6 +190,10 @@ class ImageGenerationPayload(BaseModel):
width: int
height: int
n: int = 1
cfg_scale: Optional[float] = None
sampler: Optional[str] = None
scheduler: Optional[str] = None
sd3: Optional[bool] = None
def comfyui_generate_image(
@@ -199,6 +203,18 @@ def comfyui_generate_image(
comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)
if payload.cfg_scale:
comfyui_prompt["3"]["inputs"]["cfg"] = payload.cfg_scale
if payload.sampler:
comfyui_prompt["3"]["inputs"]["sampler"] = payload.sampler
if payload.scheduler:
comfyui_prompt["3"]["inputs"]["scheduler"] = payload.scheduler
if payload.sd3:
comfyui_prompt["5"]["class_type"] = "EmptySD3LatentImage"
comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
comfyui_prompt["5"]["inputs"]["width"] = payload.width

View File

@@ -46,6 +46,7 @@ from config import (
SRC_LOG_LEVELS,
OLLAMA_BASE_URLS,
ENABLE_OLLAMA_API,
AIOHTTP_CLIENT_TIMEOUT,
ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST,
UPLOAD_DIR,
@@ -154,7 +155,9 @@ async def cleanup_response(
async def post_streaming_url(url: str, payload: str):
r = None
try:
session = aiohttp.ClientSession(trust_env=True)
session = aiohttp.ClientSession(
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
)
r = await session.post(url, data=payload)
r.raise_for_status()
@@ -751,6 +754,14 @@ async def generate_chat_completion(
if model_info.params.get("num_ctx", None):
payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
if model_info.params.get("num_batch", None):
payload["options"]["num_batch"] = model_info.params.get(
"num_batch", None
)
if model_info.params.get("num_keep", None):
payload["options"]["num_keep"] = model_info.params.get("num_keep", None)
if model_info.params.get("repeat_last_n", None):
payload["options"]["repeat_last_n"] = model_info.params.get(
"repeat_last_n", None
@@ -839,8 +850,7 @@ async def generate_chat_completion(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
print(payload)
log.debug(payload)
return await post_streaming_url(f"{url}/api/chat", json.dumps(payload))

View File

@@ -430,13 +430,11 @@ async def generate_chat_completion(
# Convert the modified body back to JSON
payload = json.dumps(payload)
print(payload)
log.debug(payload)
url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx]
print(payload)
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"

View File

@@ -73,6 +73,7 @@ from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack
from apps.rag.search.serply import search_serply
from apps.rag.search.duckduckgo import search_duckduckgo
from apps.rag.search.tavily import search_tavily
from utils.misc import (
calculate_sha256,
@@ -119,6 +120,7 @@ from config import (
SERPSTACK_HTTPS,
SERPER_API_KEY,
SERPLY_API_KEY,
TAVILY_API_KEY,
RAG_WEB_SEARCH_RESULT_COUNT,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
RAG_EMBEDDING_OPENAI_BATCH_SIZE,
@@ -172,6 +174,7 @@ app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
@@ -400,6 +403,7 @@ async def get_rag_config(user=Depends(get_admin_user)):
"serpstack_https": app.state.config.SERPSTACK_HTTPS,
"serper_api_key": app.state.config.SERPER_API_KEY,
"serply_api_key": app.state.config.SERPLY_API_KEY,
"tavily_api_key": app.state.config.TAVILY_API_KEY,
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
},
@@ -428,6 +432,7 @@ class WebSearchConfig(BaseModel):
serpstack_https: Optional[bool] = None
serper_api_key: Optional[str] = None
serply_api_key: Optional[str] = None
tavily_api_key: Optional[str] = None
result_count: Optional[int] = None
concurrent_requests: Optional[int] = None
@@ -479,6 +484,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key
app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
form_data.web.search.concurrent_requests
@@ -508,6 +514,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
"serpstack_https": app.state.config.SERPSTACK_HTTPS,
"serper_api_key": app.state.config.SERPER_API_KEY,
"serply_api_key": app.state.config.SERPLY_API_KEY,
"tavily_api_key": app.state.config.TAVILY_API_KEY,
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
},
@@ -756,7 +763,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
- SERPSTACK_API_KEY
- SERPER_API_KEY
- SERPLY_API_KEY
- TAVILY_API_KEY
Args:
query (str): The query to search for
"""
@@ -825,6 +832,15 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
raise Exception("No SERPLY_API_KEY found in environment variables")
elif engine == "duckduckgo":
return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
elif engine == "tavily":
if app.state.config.TAVILY_API_KEY:
return search_tavily(
app.state.config.TAVILY_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
)
else:
raise Exception("No TAVILY_API_KEY found in environment variables")
else:
raise Exception("No search engine API key found in environment variables")

View File

@@ -0,0 +1,39 @@
import logging
import requests
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
"""Search using Tavily's Search API and return the results as a list of SearchResult objects.
Args:
api_key (str): A Tavily Search API key
query (str): The query to search for
Returns:
List[SearchResult]: A list of search results
"""
url = "https://api.tavily.com/search"
data = {"query": query, "api_key": api_key}
response = requests.post(url, json=data)
response.raise_for_status()
json_response = response.json()
raw_search_results = json_response.get("results", [])
return [
SearchResult(
link=result["url"],
title=result.get("title", ""),
snippet=result.get("content"),
)
for result in raw_search_results[:count]
]

View File

@@ -0,0 +1,48 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Adding fields info to the 'user' table
migrator.add_fields("user", info=pw.TextField(null=True))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
# Remove the settings field
migrator.remove_fields("user", "info")

View File

@@ -25,6 +25,7 @@ from config import (
USER_PERMISSIONS,
WEBHOOK_URL,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
JWT_EXPIRES_IN,
WEBUI_BANNERS,
ENABLE_COMMUNITY_SHARING,
@@ -40,6 +41,7 @@ app.state.config = AppConfig()
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS

View File

@@ -65,6 +65,20 @@ class MemoriesTable:
else:
return None
def update_memory_by_id(
self,
id: str,
content: str,
) -> Optional[MemoryModel]:
try:
memory = Memory.get(Memory.id == id)
memory.content = content
memory.updated_at = int(time.time())
memory.save()
return MemoryModel(**model_to_dict(memory))
except:
return None
def get_memories(self) -> List[MemoryModel]:
try:
memories = Memory.select()

View File

@@ -26,6 +26,7 @@ class User(Model):
api_key = CharField(null=True, unique=True)
settings = JSONField(null=True)
info = JSONField(null=True)
class Meta:
database = DB
@@ -50,6 +51,7 @@ class UserModel(BaseModel):
api_key: Optional[str] = None
settings: Optional[UserSettings] = None
info: Optional[dict] = None
####################

View File

@@ -33,7 +33,11 @@ from utils.utils import (
from utils.misc import parse_duration, validate_email_format
from utils.webhook import post_webhook
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER
from config import (
WEBUI_AUTH,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
)
router = APIRouter()
@@ -110,11 +114,16 @@ async def signin(request: Request, form_data: SigninForm):
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
trusted_name = trusted_email
if WEBUI_AUTH_TRUSTED_NAME_HEADER:
trusted_name = request.headers.get(
WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email
)
if not Users.get_user_by_email(trusted_email.lower()):
await signup(
request,
SignupForm(
email=trusted_email, password=str(uuid.uuid4()), name=trusted_email
email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
),
)
user = Auths.authenticate_user_by_trusted_header(trusted_email)

View File

@@ -44,6 +44,10 @@ class AddMemoryForm(BaseModel):
content: str
class MemoryUpdateModel(BaseModel):
content: Optional[str] = None
@router.post("/add", response_model=Optional[MemoryModel])
async def add_memory(
request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user)
@@ -62,6 +66,34 @@ async def add_memory(
return memory
@router.post("/{memory_id}/update", response_model=Optional[MemoryModel])
async def update_memory_by_id(
memory_id: str,
request: Request,
form_data: MemoryUpdateModel,
user=Depends(get_verified_user),
):
memory = Memories.update_memory_by_id(memory_id, form_data.content)
if memory is None:
raise HTTPException(status_code=404, detail="Memory not found")
if form_data.content is not None:
memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
collection = CHROMA_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
collection.upsert(
documents=[form_data.content],
ids=[memory.id],
embeddings=[memory_embedding],
metadatas=[
{"created_at": memory.created_at, "updated_at": memory.updated_at}
],
)
return memory
############################
# QueryMemory
############################

View File

@@ -115,6 +115,52 @@ async def update_user_settings_by_session_user(
)
############################
# GetUserInfoBySessionUser
############################
@router.get("/user/info", response_model=Optional[dict])
async def get_user_info_by_session_user(user=Depends(get_verified_user)):
user = Users.get_user_by_id(user.id)
if user:
return user.info
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
############################
# UpdateUserInfoBySessionUser
############################
@router.post("/user/info/update", response_model=Optional[dict])
async def update_user_settings_by_session_user(
form_data: dict, user=Depends(get_verified_user)
):
user = Users.get_user_by_id(user.id)
if user:
if user.info is None:
user.info = {}
user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}})
if user:
return user.info
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
############################
# GetUserById
############################

View File

@@ -294,6 +294,7 @@ WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)
JWT_EXPIRES_IN = PersistentConfig(
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
)
@@ -425,6 +426,14 @@ OLLAMA_API_BASE_URL = os.environ.get(
)
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "300")
if AIOHTTP_CLIENT_TIMEOUT == "":
AIOHTTP_CLIENT_TIMEOUT = None
else:
AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT)
K8S_FLAG = os.environ.get("K8S_FLAG", "")
USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false")
@@ -942,6 +951,11 @@ SERPLY_API_KEY = PersistentConfig(
os.getenv("SERPLY_API_KEY", ""),
)
TAVILY_API_KEY = PersistentConfig(
"TAVILY_API_KEY",
"rag.web.search.tavily_api_key",
os.getenv("TAVILY_API_KEY", ""),
)
RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
"RAG_WEB_SEARCH_RESULT_COUNT",
@@ -994,6 +1008,30 @@ COMFYUI_BASE_URL = PersistentConfig(
os.getenv("COMFYUI_BASE_URL", ""),
)
COMFYUI_CFG_SCALE = PersistentConfig(
"COMFYUI_CFG_SCALE",
"image_generation.comfyui.cfg_scale",
os.getenv("COMFYUI_CFG_SCALE", ""),
)
COMFYUI_SAMPLER = PersistentConfig(
"COMFYUI_SAMPLER",
"image_generation.comfyui.sampler",
os.getenv("COMFYUI_SAMPLER", ""),
)
COMFYUI_SCHEDULER = PersistentConfig(
"COMFYUI_SCHEDULER",
"image_generation.comfyui.scheduler",
os.getenv("COMFYUI_SCHEDULER", ""),
)
COMFYUI_SD3 = PersistentConfig(
"COMFYUI_SD3",
"image_generation.comfyui.sd3",
os.environ.get("COMFYUI_SD3", "").lower() == "true",
)
IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
"IMAGES_OPENAI_API_BASE_URL",
"image_generation.openai.api_base_url",

View File

@@ -494,6 +494,9 @@ def filter_pipeline(payload, user):
if "title" in payload:
del payload["title"]
if "task" in payload:
del payload["task"]
return payload
@@ -761,7 +764,12 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
content = title_generation_template(
template, form_data["prompt"], user.model_dump()
template,
form_data["prompt"],
{
"name": user.name,
"location": user.info.get("location") if user.info else None,
},
)
payload = {
@@ -773,7 +781,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
"title": True,
}
print(payload)
log.debug(payload)
try:
payload = filter_pipeline(payload, user)
@@ -827,7 +835,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
content = search_query_generation_template(
template, form_data["prompt"], user.model_dump()
template, form_data["prompt"], {"name": user.name}
)
payload = {
@@ -835,6 +843,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": 30,
"task": True,
}
print(payload)
@@ -855,6 +864,75 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/task/emoji/completions")
async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
print("generate_emoji")
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
print(model_id)
model = app.state.MODELS[model_id]
template = '''
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
Message: """{{prompt}}"""
'''
content = title_generation_template(
template,
form_data["prompt"],
{
"name": user.name,
"location": user.info.get("location") if user.info else None,
},
)
payload = {
"model": model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": 4,
"chat_id": form_data.get("chat_id", None),
"task": True,
}
log.debug(payload)
try:
payload = filter_pipeline(payload, user)
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/task/tools/completions")
async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
print("get_tools_function_calling")

View File

@@ -6,24 +6,28 @@ from typing import Optional
def prompt_template(
template: str, user_name: str = None, current_location: str = None
template: str, user_name: str = None, user_location: str = None
) -> str:
# Get the current date
current_date = datetime.now()
# Format the date to YYYY-MM-DD
formatted_date = current_date.strftime("%Y-%m-%d")
formatted_time = current_date.strftime("%I:%M:%S %p")
# Replace {{CURRENT_DATE}} in the template with the formatted date
template = template.replace("{{CURRENT_DATE}}", formatted_date)
template = template.replace("{{CURRENT_TIME}}", formatted_time)
template = template.replace(
"{{CURRENT_DATETIME}}", f"{formatted_date} {formatted_time}"
)
if user_name:
# Replace {{USER_NAME}} in the template with the user's name
template = template.replace("{{USER_NAME}}", user_name)
if current_location:
# Replace {{CURRENT_LOCATION}} in the template with the current location
template = template.replace("{{CURRENT_LOCATION}}", current_location)
if user_location:
# Replace {{USER_LOCATION}} in the template with the current location
template = template.replace("{{USER_LOCATION}}", user_location)
return template
@@ -61,7 +65,7 @@ def title_generation_template(
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "current_location": user.get("location")}
{"user_name": user.get("name"), "user_location": user.get("location")}
if user
else {}
),
@@ -104,7 +108,7 @@ def search_query_generation_template(
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "current_location": user.get("location")}
{"user_name": user.get("name"), "user_location": user.get("location")}
if user
else {}
),