From 4b6b33b08b4cd336d0e9b2e2cfdf70c85de7ae1c Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 16 Jun 2024 15:32:26 -0700 Subject: [PATCH] feat: user_location --- .../internal/migrations/013_add_user_info.py | 48 +++++++++++++ backend/apps/webui/models/users.py | 2 + backend/apps/webui/routers/users.py | 46 ++++++++++++ backend/main.py | 16 ++++- backend/utils/task.py | 12 ++-- src/lib/apis/users/index.ts | 70 +++++++++++++++++++ src/lib/components/chat/Chat.svelte | 19 ++++- .../components/chat/Settings/Interface.svelte | 50 ++++++++++++- src/lib/utils/index.ts | 31 ++++++-- 9 files changed, 275 insertions(+), 19 deletions(-) create mode 100644 backend/apps/webui/internal/migrations/013_add_user_info.py diff --git a/backend/apps/webui/internal/migrations/013_add_user_info.py b/backend/apps/webui/internal/migrations/013_add_user_info.py new file mode 100644 index 000000000..0f68669cc --- /dev/null +++ b/backend/apps/webui/internal/migrations/013_add_user_info.py @@ -0,0 +1,48 @@ +"""Peewee migrations -- 002_add_local_sharing.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + # Adding fields info to the 'user' table + migrator.add_fields("user", info=pw.TextField(null=True)) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + # Remove the settings field + migrator.remove_fields("user", "info") diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index 48811e8af..485a9eea4 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -26,6 +26,7 @@ class User(Model): api_key = CharField(null=True, unique=True) settings = JSONField(null=True) + info = JSONField(null=True) class Meta: database = DB @@ -50,6 +51,7 @@ class UserModel(BaseModel): api_key: Optional[str] = None settings: Optional[UserSettings] = None + info: Optional[dict] = None #################### diff --git a/backend/apps/webui/routers/users.py b/backend/apps/webui/routers/users.py index eccafde10..270d72a23 100644 --- a/backend/apps/webui/routers/users.py +++ b/backend/apps/webui/routers/users.py @@ -115,6 +115,52 @@ async def update_user_settings_by_session_user( ) +############################ +# GetUserInfoBySessionUser +############################ + + +@router.get("/user/info", response_model=Optional[dict]) +async def get_user_info_by_session_user(user=Depends(get_verified_user)): + user = Users.get_user_by_id(user.id) + if user: + return user.info + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.USER_NOT_FOUND, + ) + + +############################ +# UpdateUserInfoBySessionUser +############################ + + +@router.post("/user/info/update", response_model=Optional[dict]) +async def update_user_settings_by_session_user( + form_data: dict, user=Depends(get_verified_user) +): + user = Users.get_user_by_id(user.id) + if user: + if user.info is None: + user.info = {} + + user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}}) + if user: + return user.info + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.USER_NOT_FOUND, + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.USER_NOT_FOUND, + ) + + ############################ # GetUserById ############################ diff --git a/backend/main.py b/backend/main.py index 02fc93911..04f886162 100644 --- a/backend/main.py +++ b/backend/main.py @@ -764,7 +764,12 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE content = title_generation_template( - template, form_data["prompt"], user.model_dump() + template, + form_data["prompt"], + { + "name": user.name, + "location": user.info.get("location") if user.info else None, + }, ) payload = { @@ -830,7 +835,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE content = search_query_generation_template( - template, form_data["prompt"], user.model_dump() + template, form_data["prompt"], {"name": user.name} ) payload = { @@ -893,7 +898,12 @@ Message: """{{prompt}}""" ''' content = title_generation_template( - template, form_data["prompt"], user.model_dump() + template, + form_data["prompt"], + { + "name": user.name, + "location": user.info.get("location") if user.info else None, + }, ) payload = { diff --git a/backend/utils/task.py b/backend/utils/task.py index 06787196c..ea277eb0b 100644 --- a/backend/utils/task.py +++ b/backend/utils/task.py @@ -6,7 +6,7 @@ from typing import Optional def prompt_template( - template: str, user_name: str = None, current_location: str = None + template: str, user_name: str = None, user_location: str = None ) -> str: # Get the current date current_date = datetime.now() @@ -25,9 +25,9 @@ def prompt_template( # Replace {{USER_NAME}} in the template with the user's name template = template.replace("{{USER_NAME}}", user_name) - if current_location: - # Replace {{CURRENT_LOCATION}} in the template with the current location - template = template.replace("{{CURRENT_LOCATION}}", current_location) + if user_location: + # Replace {{USER_LOCATION}} in the template with the current location + template = template.replace("{{USER_LOCATION}}", user_location) return template @@ -65,7 +65,7 @@ def title_generation_template( template = prompt_template( template, **( - {"user_name": user.get("name"), "current_location": user.get("location")} + {"user_name": user.get("name"), "user_location": user.get("location")} if user else {} ), @@ -108,7 +108,7 @@ def search_query_generation_template( template = prompt_template( template, **( - {"user_name": user.get("name"), "current_location": user.get("location")} + {"user_name": user.get("name"), "user_location": user.get("location")} if user else {} ), diff --git a/src/lib/apis/users/index.ts b/src/lib/apis/users/index.ts index 4c97b0036..0b22b7171 100644 --- a/src/lib/apis/users/index.ts +++ b/src/lib/apis/users/index.ts @@ -1,4 +1,5 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; +import { getUserPosition } from '$lib/utils'; export const getUserPermissions = async (token: string) => { let error = null; @@ -198,6 +199,75 @@ export const getUserById = async (token: string, userId: string) => { return res; }; +export const getUserInfo = async (token: string) => { + let error = null; + const res = await fetch(`${WEBUI_API_BASE_URL}/users/user/info`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateUserInfo = async (token: string, info: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/users/user/info/update`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...info + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getAndUpdateUserLocation = async (token: string) => { + const location = await getUserPosition().catch((err) => { + throw err; + }); + + if (location) { + await updateUserInfo(token, { location: location }); + return location; + } else { + throw new Error('Failed to get user location'); + } +}; + export const deleteUserById = async (token: string, userId: string) => { let error = null; diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 73b480796..8819a0428 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -31,6 +31,7 @@ convertMessagesToHistory, copyToClipboard, extractSentencesForAudio, + getUserPosition, promptTemplate, splitStream } from '$lib/utils'; @@ -50,7 +51,7 @@ import { runWebSearch } from '$lib/apis/rag'; import { createOpenAITextStream } from '$lib/apis/streaming'; import { queryMemory } from '$lib/apis/memories'; - import { getUserSettings } from '$lib/apis/users'; + import { getAndUpdateUserLocation, getUserSettings } from '$lib/apis/users'; import { chatCompleted, generateTitle, generateSearchQuery } from '$lib/apis'; import Banner from '../common/Banner.svelte'; @@ -533,7 +534,13 @@ $settings.system || (responseMessage?.userContext ?? null) ? { role: 'system', - content: `${promptTemplate($settings?.system ?? '', $user.name)}${ + content: `${promptTemplate( + $settings?.system ?? '', + $user.name, + $settings?.userLocation + ? await getAndUpdateUserLocation(localStorage.token) + : undefined + )}${ responseMessage?.userContext ?? null ? `\n\nUser Context:\n${(responseMessage?.userContext ?? []).join('\n')}` : '' @@ -871,7 +878,13 @@ $settings.system || (responseMessage?.userContext ?? null) ? { role: 'system', - content: `${promptTemplate($settings?.system ?? '', $user.name)}${ + content: `${promptTemplate( + $settings?.system ?? '', + $user.name, + $settings?.userLocation + ? await getAndUpdateUserLocation(localStorage.token) + : undefined + )}${ responseMessage?.userContext ?? null ? `\n\nUser Context:\n${(responseMessage?.userContext ?? []).join('\n')}` : '' diff --git a/src/lib/components/chat/Settings/Interface.svelte b/src/lib/components/chat/Settings/Interface.svelte index 25a504ef5..b96a16d9d 100644 --- a/src/lib/components/chat/Settings/Interface.svelte +++ b/src/lib/components/chat/Settings/Interface.svelte @@ -5,6 +5,8 @@ import { createEventDispatcher, onMount, getContext } from 'svelte'; import { toast } from 'svelte-sonner'; import Tooltip from '$lib/components/common/Tooltip.svelte'; + import { updateUserInfo } from '$lib/apis/users'; + import { getUserPosition } from '$lib/utils'; const dispatch = createEventDispatcher(); const i18n = getContext('i18n'); @@ -16,6 +18,7 @@ let responseAutoCopy = false; let widescreenMode = false; let splitLargeChunks = false; + let userLocation = false; // Interface let defaultModelId = ''; @@ -51,6 +54,26 @@ saveSettings({ showEmojiInCall: showEmojiInCall }); }; + const toggleUserLocation = async () => { + userLocation = !userLocation; + + if (userLocation) { + const position = await getUserPosition().catch((error) => { + toast.error(error.message); + return null; + }); + + if (position) { + await updateUserInfo(localStorage.token, { location: position }); + toast.success('User location successfully retrieved.'); + } else { + userLocation = false; + } + } + + saveSettings({ userLocation }); + }; + const toggleTitleAutoGenerate = async () => { titleAutoGenerate = !titleAutoGenerate; saveSettings({ @@ -106,6 +129,7 @@ widescreenMode = $settings.widescreenMode ?? false; splitLargeChunks = $settings.splitLargeChunks ?? false; chatDirection = $settings.chatDirection ?? 'LTR'; + userLocation = $settings.userLocation ?? false; defaultModelId = ($settings?.models ?? ['']).at(0); }); @@ -142,6 +166,26 @@ +
+
+
{$i18n.t('Widescreen Mode')}
+ + +
+
+
{$i18n.t('Title Auto-Generation')}
@@ -186,16 +230,16 @@
-
{$i18n.t('Widescreen Mode')}
+
{$i18n.t('Allow User Location')}