mirror of
https://github.com/open-webui/open-webui
synced 2024-11-17 05:53:11 +00:00
feat: user_location
This commit is contained in:
parent
8e62c36148
commit
4b6b33b08b
48
backend/apps/webui/internal/migrations/013_add_user_info.py
Normal file
48
backend/apps/webui/internal/migrations/013_add_user_info.py
Normal file
@ -0,0 +1,48 @@
|
||||
"""Peewee migrations -- 002_add_local_sharing.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
# Adding fields info to the 'user' table
|
||||
migrator.add_fields("user", info=pw.TextField(null=True))
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
# Remove the settings field
|
||||
migrator.remove_fields("user", "info")
|
@ -26,6 +26,7 @@ class User(Model):
|
||||
|
||||
api_key = CharField(null=True, unique=True)
|
||||
settings = JSONField(null=True)
|
||||
info = JSONField(null=True)
|
||||
|
||||
class Meta:
|
||||
database = DB
|
||||
@ -50,6 +51,7 @@ class UserModel(BaseModel):
|
||||
|
||||
api_key: Optional[str] = None
|
||||
settings: Optional[UserSettings] = None
|
||||
info: Optional[dict] = None
|
||||
|
||||
|
||||
####################
|
||||
|
@ -115,6 +115,52 @@ async def update_user_settings_by_session_user(
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetUserInfoBySessionUser
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/user/info", response_model=Optional[dict])
|
||||
async def get_user_info_by_session_user(user=Depends(get_verified_user)):
|
||||
user = Users.get_user_by_id(user.id)
|
||||
if user:
|
||||
return user.info
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# UpdateUserInfoBySessionUser
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/user/info/update", response_model=Optional[dict])
|
||||
async def update_user_settings_by_session_user(
|
||||
form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
user = Users.get_user_by_id(user.id)
|
||||
if user:
|
||||
if user.info is None:
|
||||
user.info = {}
|
||||
|
||||
user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}})
|
||||
if user:
|
||||
return user.info
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetUserById
|
||||
############################
|
||||
|
@ -764,7 +764,12 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
||||
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = title_generation_template(
|
||||
template, form_data["prompt"], user.model_dump()
|
||||
template,
|
||||
form_data["prompt"],
|
||||
{
|
||||
"name": user.name,
|
||||
"location": user.info.get("location") if user.info else None,
|
||||
},
|
||||
)
|
||||
|
||||
payload = {
|
||||
@ -830,7 +835,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
|
||||
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = search_query_generation_template(
|
||||
template, form_data["prompt"], user.model_dump()
|
||||
template, form_data["prompt"], {"name": user.name}
|
||||
)
|
||||
|
||||
payload = {
|
||||
@ -893,7 +898,12 @@ Message: """{{prompt}}"""
|
||||
'''
|
||||
|
||||
content = title_generation_template(
|
||||
template, form_data["prompt"], user.model_dump()
|
||||
template,
|
||||
form_data["prompt"],
|
||||
{
|
||||
"name": user.name,
|
||||
"location": user.info.get("location") if user.info else None,
|
||||
},
|
||||
)
|
||||
|
||||
payload = {
|
||||
|
@ -6,7 +6,7 @@ from typing import Optional
|
||||
|
||||
|
||||
def prompt_template(
|
||||
template: str, user_name: str = None, current_location: str = None
|
||||
template: str, user_name: str = None, user_location: str = None
|
||||
) -> str:
|
||||
# Get the current date
|
||||
current_date = datetime.now()
|
||||
@ -25,9 +25,9 @@ def prompt_template(
|
||||
# Replace {{USER_NAME}} in the template with the user's name
|
||||
template = template.replace("{{USER_NAME}}", user_name)
|
||||
|
||||
if current_location:
|
||||
# Replace {{CURRENT_LOCATION}} in the template with the current location
|
||||
template = template.replace("{{CURRENT_LOCATION}}", current_location)
|
||||
if user_location:
|
||||
# Replace {{USER_LOCATION}} in the template with the current location
|
||||
template = template.replace("{{USER_LOCATION}}", user_location)
|
||||
|
||||
return template
|
||||
|
||||
@ -65,7 +65,7 @@ def title_generation_template(
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "current_location": user.get("location")}
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
@ -108,7 +108,7 @@ def search_query_generation_template(
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "current_location": user.get("location")}
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { WEBUI_API_BASE_URL } from '$lib/constants';
|
||||
import { getUserPosition } from '$lib/utils';
|
||||
|
||||
export const getUserPermissions = async (token: string) => {
|
||||
let error = null;
|
||||
@ -198,6 +199,75 @@ export const getUserById = async (token: string, userId: string) => {
|
||||
return res;
|
||||
};
|
||||
|
||||
export const getUserInfo = async (token: string) => {
|
||||
let error = null;
|
||||
const res = await fetch(`${WEBUI_API_BASE_URL}/users/user/info`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`
|
||||
}
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw await res.json();
|
||||
return res.json();
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
error = err.detail;
|
||||
return null;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
export const updateUserInfo = async (token: string, info: object) => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${WEBUI_API_BASE_URL}/users/user/info/update`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`
|
||||
},
|
||||
body: JSON.stringify({
|
||||
...info
|
||||
})
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw await res.json();
|
||||
return res.json();
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
error = err.detail;
|
||||
return null;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
export const getAndUpdateUserLocation = async (token: string) => {
|
||||
const location = await getUserPosition().catch((err) => {
|
||||
throw err;
|
||||
});
|
||||
|
||||
if (location) {
|
||||
await updateUserInfo(token, { location: location });
|
||||
return location;
|
||||
} else {
|
||||
throw new Error('Failed to get user location');
|
||||
}
|
||||
};
|
||||
|
||||
export const deleteUserById = async (token: string, userId: string) => {
|
||||
let error = null;
|
||||
|
||||
|
@ -31,6 +31,7 @@
|
||||
convertMessagesToHistory,
|
||||
copyToClipboard,
|
||||
extractSentencesForAudio,
|
||||
getUserPosition,
|
||||
promptTemplate,
|
||||
splitStream
|
||||
} from '$lib/utils';
|
||||
@ -50,7 +51,7 @@
|
||||
import { runWebSearch } from '$lib/apis/rag';
|
||||
import { createOpenAITextStream } from '$lib/apis/streaming';
|
||||
import { queryMemory } from '$lib/apis/memories';
|
||||
import { getUserSettings } from '$lib/apis/users';
|
||||
import { getAndUpdateUserLocation, getUserSettings } from '$lib/apis/users';
|
||||
import { chatCompleted, generateTitle, generateSearchQuery } from '$lib/apis';
|
||||
|
||||
import Banner from '../common/Banner.svelte';
|
||||
@ -533,7 +534,13 @@
|
||||
$settings.system || (responseMessage?.userContext ?? null)
|
||||
? {
|
||||
role: 'system',
|
||||
content: `${promptTemplate($settings?.system ?? '', $user.name)}${
|
||||
content: `${promptTemplate(
|
||||
$settings?.system ?? '',
|
||||
$user.name,
|
||||
$settings?.userLocation
|
||||
? await getAndUpdateUserLocation(localStorage.token)
|
||||
: undefined
|
||||
)}${
|
||||
responseMessage?.userContext ?? null
|
||||
? `\n\nUser Context:\n${(responseMessage?.userContext ?? []).join('\n')}`
|
||||
: ''
|
||||
@ -871,7 +878,13 @@
|
||||
$settings.system || (responseMessage?.userContext ?? null)
|
||||
? {
|
||||
role: 'system',
|
||||
content: `${promptTemplate($settings?.system ?? '', $user.name)}${
|
||||
content: `${promptTemplate(
|
||||
$settings?.system ?? '',
|
||||
$user.name,
|
||||
$settings?.userLocation
|
||||
? await getAndUpdateUserLocation(localStorage.token)
|
||||
: undefined
|
||||
)}${
|
||||
responseMessage?.userContext ?? null
|
||||
? `\n\nUser Context:\n${(responseMessage?.userContext ?? []).join('\n')}`
|
||||
: ''
|
||||
|
@ -5,6 +5,8 @@
|
||||
import { createEventDispatcher, onMount, getContext } from 'svelte';
|
||||
import { toast } from 'svelte-sonner';
|
||||
import Tooltip from '$lib/components/common/Tooltip.svelte';
|
||||
import { updateUserInfo } from '$lib/apis/users';
|
||||
import { getUserPosition } from '$lib/utils';
|
||||
const dispatch = createEventDispatcher();
|
||||
|
||||
const i18n = getContext('i18n');
|
||||
@ -16,6 +18,7 @@
|
||||
let responseAutoCopy = false;
|
||||
let widescreenMode = false;
|
||||
let splitLargeChunks = false;
|
||||
let userLocation = false;
|
||||
|
||||
// Interface
|
||||
let defaultModelId = '';
|
||||
@ -51,6 +54,26 @@
|
||||
saveSettings({ showEmojiInCall: showEmojiInCall });
|
||||
};
|
||||
|
||||
const toggleUserLocation = async () => {
|
||||
userLocation = !userLocation;
|
||||
|
||||
if (userLocation) {
|
||||
const position = await getUserPosition().catch((error) => {
|
||||
toast.error(error.message);
|
||||
return null;
|
||||
});
|
||||
|
||||
if (position) {
|
||||
await updateUserInfo(localStorage.token, { location: position });
|
||||
toast.success('User location successfully retrieved.');
|
||||
} else {
|
||||
userLocation = false;
|
||||
}
|
||||
}
|
||||
|
||||
saveSettings({ userLocation });
|
||||
};
|
||||
|
||||
const toggleTitleAutoGenerate = async () => {
|
||||
titleAutoGenerate = !titleAutoGenerate;
|
||||
saveSettings({
|
||||
@ -106,6 +129,7 @@
|
||||
widescreenMode = $settings.widescreenMode ?? false;
|
||||
splitLargeChunks = $settings.splitLargeChunks ?? false;
|
||||
chatDirection = $settings.chatDirection ?? 'LTR';
|
||||
userLocation = $settings.userLocation ?? false;
|
||||
|
||||
defaultModelId = ($settings?.models ?? ['']).at(0);
|
||||
});
|
||||
@ -142,6 +166,26 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<div class=" py-0.5 flex w-full justify-between">
|
||||
<div class=" self-center text-xs font-medium">{$i18n.t('Widescreen Mode')}</div>
|
||||
|
||||
<button
|
||||
class="p-1 px-3 text-xs flex rounded transition"
|
||||
on:click={() => {
|
||||
togglewidescreenMode();
|
||||
}}
|
||||
type="button"
|
||||
>
|
||||
{#if widescreenMode === true}
|
||||
<span class="ml-2 self-center">{$i18n.t('On')}</span>
|
||||
{:else}
|
||||
<span class="ml-2 self-center">{$i18n.t('Off')}</span>
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<div class=" py-0.5 flex w-full justify-between">
|
||||
<div class=" self-center text-xs font-medium">{$i18n.t('Title Auto-Generation')}</div>
|
||||
@ -186,16 +230,16 @@
|
||||
|
||||
<div>
|
||||
<div class=" py-0.5 flex w-full justify-between">
|
||||
<div class=" self-center text-xs font-medium">{$i18n.t('Widescreen Mode')}</div>
|
||||
<div class=" self-center text-xs font-medium">{$i18n.t('Allow User Location')}</div>
|
||||
|
||||
<button
|
||||
class="p-1 px-3 text-xs flex rounded transition"
|
||||
on:click={() => {
|
||||
togglewidescreenMode();
|
||||
toggleUserLocation();
|
||||
}}
|
||||
type="button"
|
||||
>
|
||||
{#if widescreenMode === true}
|
||||
{#if userLocation === true}
|
||||
<span class="ml-2 self-center">{$i18n.t('On')}</span>
|
||||
{:else}
|
||||
<span class="ml-2 self-center">{$i18n.t('Off')}</span>
|
||||
|
@ -302,6 +302,29 @@ export const getImportOrigin = (_chats) => {
|
||||
return 'webui';
|
||||
};
|
||||
|
||||
export const getUserPosition = async (raw = false) => {
|
||||
// Get the user's location using the Geolocation API
|
||||
const position = await new Promise((resolve, reject) => {
|
||||
navigator.geolocation.getCurrentPosition(resolve, reject);
|
||||
}).catch((error) => {
|
||||
console.error('Error getting user location:', error);
|
||||
throw error;
|
||||
});
|
||||
|
||||
if (!position) {
|
||||
return 'Location not available';
|
||||
}
|
||||
|
||||
// Extract the latitude and longitude from the position
|
||||
const { latitude, longitude } = position.coords;
|
||||
|
||||
if (raw) {
|
||||
return { latitude, longitude };
|
||||
} else {
|
||||
return `${latitude.toFixed(3)}, ${longitude.toFixed(3)} (lat, long)`;
|
||||
}
|
||||
};
|
||||
|
||||
const convertOpenAIMessages = (convo) => {
|
||||
// Parse OpenAI chat messages and create chat dictionary for creating new chats
|
||||
const mapping = convo['mapping'];
|
||||
@ -474,7 +497,7 @@ export const blobToFile = (blob, fileName) => {
|
||||
export const promptTemplate = (
|
||||
template: string,
|
||||
user_name?: string,
|
||||
current_location?: string
|
||||
user_location?: string
|
||||
): string => {
|
||||
// Get the current date
|
||||
const currentDate = new Date();
|
||||
@ -509,9 +532,9 @@ export const promptTemplate = (
|
||||
template = template.replace('{{USER_NAME}}', user_name);
|
||||
}
|
||||
|
||||
if (current_location) {
|
||||
// Replace {{CURRENT_LOCATION}} in the template with the current location
|
||||
template = template.replace('{{CURRENT_LOCATION}}', current_location);
|
||||
if (user_location) {
|
||||
// Replace {{USER_LOCATION}} in the template with the current location
|
||||
template = template.replace('{{USER_LOCATION}}', user_location);
|
||||
}
|
||||
|
||||
return template;
|
||||
|
Loading…
Reference in New Issue
Block a user