feat: user_location

This commit is contained in:
Timothy J. Baek 2024-06-16 15:32:26 -07:00
parent 8e62c36148
commit 4b6b33b08b
9 changed files with 275 additions and 19 deletions

View File

@ -0,0 +1,48 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Adding fields info to the 'user' table
migrator.add_fields("user", info=pw.TextField(null=True))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
# Remove the settings field
migrator.remove_fields("user", "info")

View File

@ -26,6 +26,7 @@ class User(Model):
api_key = CharField(null=True, unique=True)
settings = JSONField(null=True)
info = JSONField(null=True)
class Meta:
database = DB
@ -50,6 +51,7 @@ class UserModel(BaseModel):
api_key: Optional[str] = None
settings: Optional[UserSettings] = None
info: Optional[dict] = None
####################

View File

@ -115,6 +115,52 @@ async def update_user_settings_by_session_user(
)
############################
# GetUserInfoBySessionUser
############################
@router.get("/user/info", response_model=Optional[dict])
async def get_user_info_by_session_user(user=Depends(get_verified_user)):
user = Users.get_user_by_id(user.id)
if user:
return user.info
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
############################
# UpdateUserInfoBySessionUser
############################
@router.post("/user/info/update", response_model=Optional[dict])
async def update_user_settings_by_session_user(
form_data: dict, user=Depends(get_verified_user)
):
user = Users.get_user_by_id(user.id)
if user:
if user.info is None:
user.info = {}
user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}})
if user:
return user.info
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
############################
# GetUserById
############################

View File

@ -764,7 +764,12 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
content = title_generation_template(
template, form_data["prompt"], user.model_dump()
template,
form_data["prompt"],
{
"name": user.name,
"location": user.info.get("location") if user.info else None,
},
)
payload = {
@ -830,7 +835,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
content = search_query_generation_template(
template, form_data["prompt"], user.model_dump()
template, form_data["prompt"], {"name": user.name}
)
payload = {
@ -893,7 +898,12 @@ Message: """{{prompt}}"""
'''
content = title_generation_template(
template, form_data["prompt"], user.model_dump()
template,
form_data["prompt"],
{
"name": user.name,
"location": user.info.get("location") if user.info else None,
},
)
payload = {

View File

@ -6,7 +6,7 @@ from typing import Optional
def prompt_template(
template: str, user_name: str = None, current_location: str = None
template: str, user_name: str = None, user_location: str = None
) -> str:
# Get the current date
current_date = datetime.now()
@ -25,9 +25,9 @@ def prompt_template(
# Replace {{USER_NAME}} in the template with the user's name
template = template.replace("{{USER_NAME}}", user_name)
if current_location:
# Replace {{CURRENT_LOCATION}} in the template with the current location
template = template.replace("{{CURRENT_LOCATION}}", current_location)
if user_location:
# Replace {{USER_LOCATION}} in the template with the current location
template = template.replace("{{USER_LOCATION}}", user_location)
return template
@ -65,7 +65,7 @@ def title_generation_template(
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "current_location": user.get("location")}
{"user_name": user.get("name"), "user_location": user.get("location")}
if user
else {}
),
@ -108,7 +108,7 @@ def search_query_generation_template(
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "current_location": user.get("location")}
{"user_name": user.get("name"), "user_location": user.get("location")}
if user
else {}
),

View File

@ -1,4 +1,5 @@
import { WEBUI_API_BASE_URL } from '$lib/constants';
import { getUserPosition } from '$lib/utils';
export const getUserPermissions = async (token: string) => {
let error = null;
@ -198,6 +199,75 @@ export const getUserById = async (token: string, userId: string) => {
return res;
};
export const getUserInfo = async (token: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/users/user/info`, {
method: 'GET',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateUserInfo = async (token: string, info: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/users/user/info/update`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
...info
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res;
};
export const getAndUpdateUserLocation = async (token: string) => {
const location = await getUserPosition().catch((err) => {
throw err;
});
if (location) {
await updateUserInfo(token, { location: location });
return location;
} else {
throw new Error('Failed to get user location');
}
};
export const deleteUserById = async (token: string, userId: string) => {
let error = null;

View File

@ -31,6 +31,7 @@
convertMessagesToHistory,
copyToClipboard,
extractSentencesForAudio,
getUserPosition,
promptTemplate,
splitStream
} from '$lib/utils';
@ -50,7 +51,7 @@
import { runWebSearch } from '$lib/apis/rag';
import { createOpenAITextStream } from '$lib/apis/streaming';
import { queryMemory } from '$lib/apis/memories';
import { getUserSettings } from '$lib/apis/users';
import { getAndUpdateUserLocation, getUserSettings } from '$lib/apis/users';
import { chatCompleted, generateTitle, generateSearchQuery } from '$lib/apis';
import Banner from '../common/Banner.svelte';
@ -533,7 +534,13 @@
$settings.system || (responseMessage?.userContext ?? null)
? {
role: 'system',
content: `${promptTemplate($settings?.system ?? '', $user.name)}${
content: `${promptTemplate(
$settings?.system ?? '',
$user.name,
$settings?.userLocation
? await getAndUpdateUserLocation(localStorage.token)
: undefined
)}${
responseMessage?.userContext ?? null
? `\n\nUser Context:\n${(responseMessage?.userContext ?? []).join('\n')}`
: ''
@ -871,7 +878,13 @@
$settings.system || (responseMessage?.userContext ?? null)
? {
role: 'system',
content: `${promptTemplate($settings?.system ?? '', $user.name)}${
content: `${promptTemplate(
$settings?.system ?? '',
$user.name,
$settings?.userLocation
? await getAndUpdateUserLocation(localStorage.token)
: undefined
)}${
responseMessage?.userContext ?? null
? `\n\nUser Context:\n${(responseMessage?.userContext ?? []).join('\n')}`
: ''

View File

@ -5,6 +5,8 @@
import { createEventDispatcher, onMount, getContext } from 'svelte';
import { toast } from 'svelte-sonner';
import Tooltip from '$lib/components/common/Tooltip.svelte';
import { updateUserInfo } from '$lib/apis/users';
import { getUserPosition } from '$lib/utils';
const dispatch = createEventDispatcher();
const i18n = getContext('i18n');
@ -16,6 +18,7 @@
let responseAutoCopy = false;
let widescreenMode = false;
let splitLargeChunks = false;
let userLocation = false;
// Interface
let defaultModelId = '';
@ -51,6 +54,26 @@
saveSettings({ showEmojiInCall: showEmojiInCall });
};
const toggleUserLocation = async () => {
userLocation = !userLocation;
if (userLocation) {
const position = await getUserPosition().catch((error) => {
toast.error(error.message);
return null;
});
if (position) {
await updateUserInfo(localStorage.token, { location: position });
toast.success('User location successfully retrieved.');
} else {
userLocation = false;
}
}
saveSettings({ userLocation });
};
const toggleTitleAutoGenerate = async () => {
titleAutoGenerate = !titleAutoGenerate;
saveSettings({
@ -106,6 +129,7 @@
widescreenMode = $settings.widescreenMode ?? false;
splitLargeChunks = $settings.splitLargeChunks ?? false;
chatDirection = $settings.chatDirection ?? 'LTR';
userLocation = $settings.userLocation ?? false;
defaultModelId = ($settings?.models ?? ['']).at(0);
});
@ -142,6 +166,26 @@
</div>
</div>
<div>
<div class=" py-0.5 flex w-full justify-between">
<div class=" self-center text-xs font-medium">{$i18n.t('Widescreen Mode')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
on:click={() => {
togglewidescreenMode();
}}
type="button"
>
{#if widescreenMode === true}
<span class="ml-2 self-center">{$i18n.t('On')}</span>
{:else}
<span class="ml-2 self-center">{$i18n.t('Off')}</span>
{/if}
</button>
</div>
</div>
<div>
<div class=" py-0.5 flex w-full justify-between">
<div class=" self-center text-xs font-medium">{$i18n.t('Title Auto-Generation')}</div>
@ -186,16 +230,16 @@
<div>
<div class=" py-0.5 flex w-full justify-between">
<div class=" self-center text-xs font-medium">{$i18n.t('Widescreen Mode')}</div>
<div class=" self-center text-xs font-medium">{$i18n.t('Allow User Location')}</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
on:click={() => {
togglewidescreenMode();
toggleUserLocation();
}}
type="button"
>
{#if widescreenMode === true}
{#if userLocation === true}
<span class="ml-2 self-center">{$i18n.t('On')}</span>
{:else}
<span class="ml-2 self-center">{$i18n.t('Off')}</span>

View File

@ -302,6 +302,29 @@ export const getImportOrigin = (_chats) => {
return 'webui';
};
export const getUserPosition = async (raw = false) => {
// Get the user's location using the Geolocation API
const position = await new Promise((resolve, reject) => {
navigator.geolocation.getCurrentPosition(resolve, reject);
}).catch((error) => {
console.error('Error getting user location:', error);
throw error;
});
if (!position) {
return 'Location not available';
}
// Extract the latitude and longitude from the position
const { latitude, longitude } = position.coords;
if (raw) {
return { latitude, longitude };
} else {
return `${latitude.toFixed(3)}, ${longitude.toFixed(3)} (lat, long)`;
}
};
const convertOpenAIMessages = (convo) => {
// Parse OpenAI chat messages and create chat dictionary for creating new chats
const mapping = convo['mapping'];
@ -474,7 +497,7 @@ export const blobToFile = (blob, fileName) => {
export const promptTemplate = (
template: string,
user_name?: string,
current_location?: string
user_location?: string
): string => {
// Get the current date
const currentDate = new Date();
@ -509,9 +532,9 @@ export const promptTemplate = (
template = template.replace('{{USER_NAME}}', user_name);
}
if (current_location) {
// Replace {{CURRENT_LOCATION}} in the template with the current location
template = template.replace('{{CURRENT_LOCATION}}', current_location);
if (user_location) {
// Replace {{USER_LOCATION}} in the template with the current location
template = template.replace('{{USER_LOCATION}}', user_location);
}
return template;