feat: model update

This commit is contained in:
Timothy J. Baek 2024-05-24 18:26:36 -07:00
parent 0a48114bd2
commit 708d755eda
8 changed files with 396 additions and 407 deletions

View File

@ -40,6 +40,9 @@ app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.MODELS = {}
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER

View File

@ -33,6 +33,8 @@ class ModelParams(BaseModel):
# ModelMeta is a model for the data stored in the meta field of the Model table # ModelMeta is a model for the data stored in the meta field of the Model table
# It isn't currently used in the backend, but it's here as a reference # It isn't currently used in the backend, but it's here as a reference
class ModelMeta(BaseModel): class ModelMeta(BaseModel):
profile_image_url: Optional[str] = "/favicon.png"
description: Optional[str] = None description: Optional[str] = None
""" """
User-facing description of the model. User-facing description of the model.
@ -84,6 +86,7 @@ class Model(pw.Model):
class ModelModel(BaseModel): class ModelModel(BaseModel):
id: str id: str
user_id: str
base_model_id: Optional[str] = None base_model_id: Optional[str] = None
name: str name: str
@ -123,18 +126,26 @@ class ModelsTable:
self.db = db self.db = db
self.db.create_tables([Model]) self.db.create_tables([Model])
def insert_new_model(self, model: ModelForm, user_id: str) -> Optional[ModelModel]: def insert_new_model(
try: self, form_data: ModelForm, user_id: str
model = Model.create( ) -> Optional[ModelModel]:
model = ModelModel(
**{ **{
**model.model_dump(), **form_data.model_dump(),
"user_id": user_id, "user_id": user_id,
"created_at": int(time.time()), "created_at": int(time.time()),
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
) )
return ModelModel(**model_to_dict(model)) try:
except: result = Model.create(**model.model_dump())
if result:
return model
else:
return None
except Exception as e:
print(e)
return None return None
def get_all_models(self) -> List[ModelModel]: def get_all_models(self) -> List[ModelModel]:

View File

@ -1,4 +1,4 @@
from fastapi import Depends, FastAPI, HTTPException, status from fastapi import Depends, FastAPI, HTTPException, status, Request
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Union, Optional from typing import List, Union, Optional
@ -65,16 +65,27 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/update", response_model=Optional[ModelModel]) @router.post("/{id}/update", response_model=Optional[ModelModel])
async def update_model_by_id( async def update_model_by_id(
id: str, form_data: ModelForm, user=Depends(get_admin_user) request: Request, id: str, form_data: ModelForm, user=Depends(get_admin_user)
): ):
model = Models.get_model_by_id(id) model = Models.get_model_by_id(id)
if model: if model:
model = Models.update_model_by_id(id, form_data) model = Models.update_model_by_id(id, form_data)
return model return model
else:
if form_data.id in request.app.state.MODELS:
model = Models.insert_new_model(form_data, user.id)
print(model)
if model:
return model
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.DEFAULT(),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
) )

View File

@ -122,6 +122,9 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.WEBHOOK_URL = WEBHOOK_URL app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.MODELS = {}
origins = ["*"] origins = ["*"]
@ -238,6 +241,11 @@ app.add_middleware(
@app.middleware("http") @app.middleware("http")
async def check_url(request: Request, call_next): async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0:
await get_all_models()
else:
pass
start_time = int(time.time()) start_time = int(time.time())
response = await call_next(request) response = await call_next(request)
process_time = int(time.time()) - start_time process_time = int(time.time()) - start_time
@ -269,8 +277,7 @@ app.mount("/api/v1", webui_app)
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
@app.get("/api/models") async def get_all_models():
async def get_models(user=Depends(get_verified_user)):
openai_models = [] openai_models = []
ollama_models = [] ollama_models = []
@ -282,8 +289,6 @@ async def get_models(user=Depends(get_verified_user)):
if app.state.config.ENABLE_OLLAMA_API: if app.state.config.ENABLE_OLLAMA_API:
ollama_models = await get_ollama_models() ollama_models = await get_ollama_models()
print(ollama_models)
ollama_models = [ ollama_models = [
{ {
"id": model["model"], "id": model["model"],
@ -296,9 +301,6 @@ async def get_models(user=Depends(get_verified_user)):
for model in ollama_models["models"] for model in ollama_models["models"]
] ]
print("openai", openai_models)
print("ollama", ollama_models)
models = openai_models + ollama_models models = openai_models + ollama_models
custom_models = Models.get_all_models() custom_models = Models.get_all_models()
@ -330,6 +332,16 @@ async def get_models(user=Depends(get_verified_user)):
} }
) )
app.state.MODELS = {model["id"]: model for model in models}
webui_app.state.MODELS = app.state.MODELS
return models
@app.get("/api/models")
async def get_models(user=Depends(get_verified_user)):
models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER: if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user": if user.role == "user":
models = list( models = list(

View File

@ -7,6 +7,8 @@
import { WEBUI_NAME, modelfiles, models, settings, user } from '$lib/stores'; import { WEBUI_NAME, modelfiles, models, settings, user } from '$lib/stores';
import { addNewModel, deleteModelById, getModelInfos } from '$lib/apis/models'; import { addNewModel, deleteModelById, getModelInfos } from '$lib/apis/models';
import { deleteModel } from '$lib/apis/ollama';
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { getModels } from '$lib/apis'; import { getModels } from '$lib/apis';
@ -17,13 +19,42 @@
let importFiles; let importFiles;
let modelfilesImportInputElement: HTMLInputElement; let modelfilesImportInputElement: HTMLInputElement;
const deleteModelHandler = async (id) => { const deleteModelHandler = async (model) => {
const res = await deleteModelById(localStorage.token, id); if (model?.info?.base_model_id) {
const res = await deleteModelById(localStorage.token, model.id);
if (res) { if (res) {
toast.success($i18n.t(`Deleted {{tagName}}`, { id })); toast.success($i18n.t(`Deleted {{name}}`, { name: model.id }));
} }
await models.set(await getModels(localStorage.token)); await models.set(await getModels(localStorage.token));
} else if (model?.owned_by === 'ollama') {
const res = await deleteModel(localStorage.token, model.id);
if (res) {
toast.success($i18n.t(`Deleted {{name}}`, { name: model.id }));
}
await models.set(await getModels(localStorage.token));
} else {
toast.error(
$i18n.t('{{ owner }}: You cannot delete this model', {
owner: model.owned_by.toUpperCase()
})
);
}
};
const cloneModelHandler = async (model) => {
if ((model?.info?.base_model_id ?? null) === null) {
toast.error($i18n.t('You cannot clone a base model'));
return;
} else {
sessionStorage.model = JSON.stringify({
...model,
id: `${model.id}-clone`,
name: `${model.name} (Clone)`
});
goto('/workspace/models/create');
}
}; };
const shareModelHandler = async (model) => { const shareModelHandler = async (model) => {
@ -104,7 +135,7 @@
<div class=" self-center w-10"> <div class=" self-center w-10">
<div class=" rounded-full bg-stone-700"> <div class=" rounded-full bg-stone-700">
<img <img
src={model?.meta?.profile_image_url ?? '/favicon.png'} src={model?.info?.meta?.profile_image_url ?? '/favicon.png'}
alt="modelfile profile" alt="modelfile profile"
class=" rounded-full w-full h-auto object-cover" class=" rounded-full w-full h-auto object-cover"
/> />
@ -114,7 +145,7 @@
<div class=" flex-1 self-center"> <div class=" flex-1 self-center">
<div class=" font-bold capitalize">{model.name}</div> <div class=" font-bold capitalize">{model.name}</div>
<div class=" text-sm overflow-hidden text-ellipsis line-clamp-1"> <div class=" text-sm overflow-hidden text-ellipsis line-clamp-1">
{model?.meta?.description ?? 'No description'} {model?.info?.meta?.description ?? model.id}
</div> </div>
</div> </div>
</a> </a>
@ -122,7 +153,7 @@
<a <a
class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl" class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
type="button" type="button"
href={`/workspace/models/edit?tag=${encodeURIComponent(model.id)}`} href={`/workspace/models/edit?id=${encodeURIComponent(model.id)}`}
> >
<svg <svg
xmlns="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg"
@ -144,8 +175,7 @@
class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl" class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
type="button" type="button"
on:click={() => { on:click={() => {
sessionStorage.model = JSON.stringify(model); cloneModelHandler(model);
goto('/workspace/models/create');
}} }}
> >
<svg <svg
@ -191,7 +221,7 @@
class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl" class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
type="button" type="button"
on:click={() => { on:click={() => {
deleteModelHandler(model.id); deleteModelHandler(model);
}} }}
> >
<svg <svg

View File

@ -48,7 +48,7 @@ export type Model = OpenAIModel | OllamaModel;
type BaseModel = { type BaseModel = {
id: string; id: string;
name: string; name: string;
custom_info?: ModelConfig; info?: ModelConfig;
}; };
export interface OpenAIModel extends BaseModel { export interface OpenAIModel extends BaseModel {

View File

@ -5,181 +5,83 @@
import { onMount, getContext } from 'svelte'; import { onMount, getContext } from 'svelte';
import { page } from '$app/stores'; import { page } from '$app/stores';
import { settings, user, config, models } from '$lib/stores';
import { settings, user, config, modelfiles } from '$lib/stores';
import { splitStream } from '$lib/utils'; import { splitStream } from '$lib/utils';
import { createModel } from '$lib/apis/ollama';
import { getModelInfos, updateModelById } from '$lib/apis/models'; import { getModelInfos, updateModelById } from '$lib/apis/models';
import AdvancedParams from '$lib/components/chat/Settings/Advanced/AdvancedParams.svelte'; import AdvancedParams from '$lib/components/chat/Settings/Advanced/AdvancedParams.svelte';
import { getModels } from '$lib/apis';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
let loading = false; let loading = false;
let success = false;
let filesInputElement; let filesInputElement;
let inputFiles; let inputFiles;
let imageUrl = null;
let digest = ''; let digest = '';
let pullProgress = null; let pullProgress = null;
let success = false;
let modelfile = null;
// /////////// // ///////////
// Modelfile // model
// /////////// // ///////////
let title = ''; let model = null;
let tagName = ''; let info = {
let desc = ''; id: '',
base_model_id: null,
// Raw Mode name: '',
let content = ''; meta: {
profile_image_url: '/favicon.png',
let suggestions = [ description: '',
{ content: '',
content: '' suggestion_prompts: []
} },
]; params: {}
let categories = {
character: false,
assistant: false,
writing: false,
productivity: false,
programming: false,
'data analysis': false,
lifestyle: false,
education: false,
business: false
};
onMount(() => {
tagName = $page.url.searchParams.get('tag');
if (tagName) {
modelfile = $modelfiles.filter((modelfile) => modelfile.tagName === tagName)[0];
console.log(modelfile);
imageUrl = modelfile.imageUrl;
title = modelfile.title;
desc = modelfile.desc;
content = modelfile.content;
suggestions =
modelfile.suggestionPrompts.length != 0
? modelfile.suggestionPrompts
: [
{
content: ''
}
];
for (const category of modelfile.categories) {
categories[category.toLowerCase()] = true;
}
} else {
goto('/workspace/modelfiles');
}
});
const updateModelfile = async (modelfile) => {
await updateModelById(localStorage.token, modelfile.tagName, modelfile);
await modelfiles.set(await getModelInfos(localStorage.token));
}; };
const updateHandler = async () => { const updateHandler = async () => {
loading = true; loading = true;
const res = await updateModelById(localStorage.token, info.id, info);
if (Object.keys(categories).filter((category) => categories[category]).length == 0) {
toast.error(
'Uh-oh! It looks like you missed selecting a category. Please choose one to complete your modelfile.'
);
}
if (
title !== '' &&
desc !== '' &&
content !== '' &&
Object.keys(categories).filter((category) => categories[category]).length > 0
) {
const res = await createModel(localStorage.token, tagName, content);
if (res) { if (res) {
const reader = res.body await goto('/workspace/models');
.pipeThrough(new TextDecoderStream()) await models.set(await getModels(localStorage.token));
.pipeThrough(splitStream('\n'))
.getReader();
while (true) {
const { value, done } = await reader.read();
if (done) break;
try {
let lines = value.split('\n');
for (const line of lines) {
if (line !== '') {
console.log(line);
let data = JSON.parse(line);
console.log(data);
if (data.error) {
throw data.error;
}
if (data.detail) {
throw data.detail;
} }
if (data.status) {
if (
!data.digest &&
!data.status.includes('writing') &&
!data.status.includes('sha256')
) {
toast.success(data.status);
if (data.status === 'success') {
success = true;
}
} else {
if (data.digest) {
digest = data.digest;
if (data.completed) {
pullProgress = Math.round((data.completed / data.total) * 1000) / 10;
} else {
pullProgress = 100;
}
}
}
}
}
}
} catch (error) {
console.log(error);
toast.error(error);
}
}
}
if (success) {
await updateModelfile({
tagName: tagName,
imageUrl: imageUrl,
title: title,
desc: desc,
content: content,
suggestionPrompts: suggestions.filter((prompt) => prompt.content !== ''),
categories: Object.keys(categories).filter((category) => categories[category])
});
await goto('/workspace/modelfiles');
}
}
loading = false; loading = false;
success = false; success = false;
}; };
onMount(() => {
const id = $page.url.searchParams.get('id');
if (id) {
model = $models.find((m) => m.id === id);
if (model) {
info = {
...info,
...JSON.parse(
JSON.stringify(
model?.info
? model?.info
: {
id: model.id,
name: model.name
}
)
)
};
console.log(model);
} else {
goto('/workspace/models');
}
} else {
goto('/workspace/models');
}
});
</script> </script>
<div class="w-full max-h-full"> <div class="w-full max-h-full">
@ -229,7 +131,7 @@
const compressedSrc = canvas.toDataURL('image/jpeg'); const compressedSrc = canvas.toDataURL('image/jpeg');
// Display the compressed image // Display the compressed image
imageUrl = compressedSrc; info.meta.profile_image_url = compressedSrc;
inputFiles = null; inputFiles = null;
}; };
@ -270,6 +172,8 @@
</div> </div>
<div class=" self-center font-medium text-sm">{$i18n.t('Back')}</div> <div class=" self-center font-medium text-sm">{$i18n.t('Back')}</div>
</button> </button>
{#if model}
<form <form
class="flex flex-col max-w-2xl mx-auto mt-4 mb-10" class="flex flex-col max-w-2xl mx-auto mt-4 mb-10"
on:submit|preventDefault={() => { on:submit|preventDefault={() => {
@ -279,7 +183,7 @@
<div class="flex justify-center my-4"> <div class="flex justify-center my-4">
<div class="self-center"> <div class="self-center">
<button <button
class=" {imageUrl class=" {info?.meta?.profile_image_url
? '' ? ''
: 'p-6'} rounded-full dark:bg-gray-700 border border-dashed border-gray-200" : 'p-6'} rounded-full dark:bg-gray-700 border border-dashed border-gray-200"
type="button" type="button"
@ -287,9 +191,9 @@
filesInputElement.click(); filesInputElement.click();
}} }}
> >
{#if imageUrl} {#if info?.meta?.profile_image_url}
<img <img
src={imageUrl} src={info?.meta?.profile_image_url}
alt="modelfile profile" alt="modelfile profile"
class=" rounded-full w-20 h-20 object-cover" class=" rounded-full w-20 h-20 object-cover"
/> />
@ -318,21 +222,21 @@
<div> <div>
<input <input
class="px-3 py-1.5 text-sm w-full bg-transparent border dark:border-gray-600 outline-none rounded-lg" class="px-3 py-1.5 text-sm w-full bg-transparent border dark:border-gray-600 outline-none rounded-lg"
placeholder={$i18n.t('Name your modelfile')} placeholder={$i18n.t('Name your model')}
bind:value={title} bind:value={info.name}
required required
/> />
</div> </div>
</div> </div>
<div class="flex-1"> <div class="flex-1">
<div class=" text-sm font-semibold mb-2">{$i18n.t('Model Tag Name')}*</div> <div class=" text-sm font-semibold mb-2">{$i18n.t('Model ID')}*</div>
<div> <div>
<input <input
class="px-3 py-1.5 text-sm w-full bg-transparent disabled:text-gray-500 border dark:border-gray-600 outline-none rounded-lg" class="px-3 py-1.5 text-sm w-full bg-transparent disabled:text-gray-500 border dark:border-gray-600 outline-none rounded-lg"
placeholder={$i18n.t('Add a model tag name')} placeholder={$i18n.t('Add a model id')}
value={tagName} value={info.id}
disabled disabled
required required
/> />
@ -341,13 +245,13 @@
</div> </div>
<div class="my-2"> <div class="my-2">
<div class=" text-sm font-semibold mb-2">{$i18n.t('Description')}*</div> <div class=" text-sm font-semibold mb-2">{$i18n.t('description')}*</div>
<div> <div>
<input <input
class="px-3 py-1.5 text-sm w-full bg-transparent border dark:border-gray-600 outline-none rounded-lg" class="px-3 py-1.5 text-sm w-full bg-transparent border dark:border-gray-600 outline-none rounded-lg"
placeholder={$i18n.t('Add a short description about what this modelfile does')} placeholder={$i18n.t('Add a short description about what this model does')}
bind:value={desc} bind:value={info.meta.description}
required required
/> />
</div> </div>
@ -355,22 +259,22 @@
<div class="my-2"> <div class="my-2">
<div class="flex w-full justify-between"> <div class="flex w-full justify-between">
<div class=" self-center text-sm font-semibold">{$i18n.t('Modelfile')}</div> <div class=" self-center text-sm font-semibold">{$i18n.t('Model')}</div>
</div> </div>
<!-- <div class=" text-sm font-semibold mb-2"></div> --> <!-- <div class=" text-sm font-semibold mb-2"></div> -->
<div class="mt-2"> <div class="mt-2">
<div class=" text-xs font-semibold mb-2">{$i18n.t('Content')}*</div> <div class=" text-xs font-semibold mb-2">{$i18n.t('Params')}*</div>
<div> <div>
<textarea <!-- <textarea
class="px-3 py-1.5 text-sm w-full bg-transparent border dark:border-gray-600 outline-none rounded-lg" class="px-3 py-1.5 text-sm w-full bg-transparent border dark:border-gray-600 outline-none rounded-lg"
placeholder={`FROM llama2\nPARAMETER temperature 1\nSYSTEM """\nYou are Mario from Super Mario Bros, acting as an assistant.\n"""`} placeholder={`FROM llama2\nPARAMETER temperature 1\nSYSTEM """\nYou are Mario from Super Mario Bros, acting as an assistant.\n"""`}
rows="6" rows="6"
bind:value={content} bind:value={content}
required required
/> /> -->
</div> </div>
</div> </div>
</div> </div>
@ -383,8 +287,11 @@
class="p-1 px-3 text-xs flex rounded transition" class="p-1 px-3 text-xs flex rounded transition"
type="button" type="button"
on:click={() => { on:click={() => {
if (suggestions.length === 0 || suggestions.at(-1).content !== '') { if (
suggestions = [...suggestions, { content: '' }]; info.meta.suggestion_prompts.length === 0 ||
info.meta.suggestion_prompts.at(-1).content !== ''
) {
info.meta.suggestion_prompts = [...info.meta.suggestion_prompts, { content: '' }];
} }
}} }}
> >
@ -401,7 +308,7 @@
</button> </button>
</div> </div>
<div class="flex flex-col space-y-1"> <div class="flex flex-col space-y-1">
{#each suggestions as prompt, promptIdx} {#each info.meta.suggestion_prompts as prompt, promptIdx}
<div class=" flex border dark:border-gray-600 rounded-lg"> <div class=" flex border dark:border-gray-600 rounded-lg">
<input <input
class="px-3 py-1.5 text-sm w-full bg-transparent outline-none border-r dark:border-gray-600" class="px-3 py-1.5 text-sm w-full bg-transparent outline-none border-r dark:border-gray-600"
@ -413,8 +320,8 @@
class="px-2" class="px-2"
type="button" type="button"
on:click={() => { on:click={() => {
suggestions.splice(promptIdx, 1); info.meta.suggestion_prompts.splice(promptIdx, 1);
suggestions = suggestions; info.meta.suggestion_prompts = info.meta.suggestion_prompts;
}} }}
> >
<svg <svg
@ -433,20 +340,6 @@
</div> </div>
</div> </div>
<div class="my-2">
<div class=" text-sm font-semibold mb-2">{$i18n.t('Categories')}</div>
<div class="grid grid-cols-4">
{#each Object.keys(categories) as category}
<div class="flex space-x-2 text-sm">
<input type="checkbox" bind:checked={categories[category]} />
<div class=" capitalize">{category}</div>
</div>
{/each}
</div>
</div>
{#if pullProgress !== null} {#if pullProgress !== null}
<div class="my-2"> <div class="my-2">
<div class=" text-sm font-semibold mb-2">{$i18n.t('Pull Progress')}</div> <div class=" text-sm font-semibold mb-2">{$i18n.t('Pull Progress')}</div>
@ -504,4 +397,5 @@
</button> </button>
</div> </div>
</form> </form>
{/if}
</div> </div>

File diff suppressed because one or more lines are too long