diff --git a/.gitignore b/.gitignore index 1250aef96..528e1f830 100644 --- a/.gitignore +++ b/.gitignore @@ -24,7 +24,6 @@ dist/ downloads/ eggs/ .eggs/ -lib/ lib64/ parts/ sdist/ diff --git a/README.md b/README.md index cd6558385..d77994c90 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,8 @@ Also check our sibling project, [OllamaHub](https://ollamahub.com/), where you c - ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction. +- 📜 **Prompt Preset Support**: Instantly access preset prompts using the '/' command in the chat input. Load predefined conversation starters effortlessly and expedite your interactions. Effortlessly import prompts through [OllamaHub](https://ollamahub.com/) integration. + - 📥🗑️ **Download/Delete Models**: Easily download or remove models directly from the web UI. - ⬆️ **GGUF File Model Creation**: Effortlessly create Ollama models by uploading GGUF files directly from the web UI. Streamlined process with options to upload from your machine or download GGUF files from Hugging Face. diff --git a/backend/apps/web/main.py b/backend/apps/web/main.py index f62857b79..153b5dcbb 100644 --- a/backend/apps/web/main.py +++ b/backend/apps/web/main.py @@ -1,7 +1,7 @@ from fastapi import FastAPI, Depends from fastapi.routing import APIRoute from fastapi.middleware.cors import CORSMiddleware -from apps.web.routers import auths, users, chats, modelfiles, utils +from apps.web.routers import auths, users, chats, modelfiles, prompts, configs, utils from config import WEBUI_VERSION, WEBUI_AUTH app = FastAPI() @@ -9,6 +9,7 @@ app = FastAPI() origins = ["*"] app.state.ENABLE_SIGNUP = True +app.state.DEFAULT_MODELS = None app.add_middleware( CORSMiddleware, @@ -19,13 +20,21 @@ app.add_middleware( ) app.include_router(auths.router, prefix="/auths", tags=["auths"]) - app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"]) +app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) + + +app.include_router(configs.router, prefix="/configs", tags=["configs"]) app.include_router(utils.router, prefix="/utils", tags=["utils"]) @app.get("/") async def get_status(): - return {"status": True, "version": WEBUI_VERSION, "auth": WEBUI_AUTH} + return { + "status": True, + "version": WEBUI_VERSION, + "auth": WEBUI_AUTH, + "default_models": app.state.DEFAULT_MODELS, + } diff --git a/backend/apps/web/models/modelfiles.py b/backend/apps/web/models/modelfiles.py index 4d8202db9..8231d8dff 100644 --- a/backend/apps/web/models/modelfiles.py +++ b/backend/apps/web/models/modelfiles.py @@ -12,7 +12,7 @@ from apps.web.internal.db import DB import json #################### -# User DB Schema +# Modelfile DB Schema #################### diff --git a/backend/apps/web/models/prompts.py b/backend/apps/web/models/prompts.py new file mode 100644 index 000000000..bb0710b60 --- /dev/null +++ b/backend/apps/web/models/prompts.py @@ -0,0 +1,117 @@ +from pydantic import BaseModel +from peewee import * +from playhouse.shortcuts import model_to_dict +from typing import List, Union, Optional +import time + +from utils.utils import decode_token +from utils.misc import get_gravatar_url + +from apps.web.internal.db import DB + +import json + +#################### +# Prompts DB Schema +#################### + + +class Prompt(Model): + command = CharField(unique=True) + user_id = CharField() + title = CharField() + content = TextField() + timestamp = DateField() + + class Meta: + database = DB + + +class PromptModel(BaseModel): + command: str + user_id: str + title: str + content: str + timestamp: int # timestamp in epoch + + +#################### +# Forms +#################### + + +class PromptForm(BaseModel): + command: str + title: str + content: str + + +class PromptsTable: + def __init__(self, db): + self.db = db + self.db.create_tables([Prompt]) + + def insert_new_prompt( + self, user_id: str, form_data: PromptForm + ) -> Optional[PromptModel]: + prompt = PromptModel( + **{ + "user_id": user_id, + "command": form_data.command, + "title": form_data.title, + "content": form_data.content, + "timestamp": int(time.time()), + } + ) + + try: + result = Prompt.create(**prompt.model_dump()) + if result: + return prompt + else: + return None + except: + return None + + def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: + try: + prompt = Prompt.get(Prompt.command == command) + return PromptModel(**model_to_dict(prompt)) + except: + return None + + def get_prompts(self) -> List[PromptModel]: + return [ + PromptModel(**model_to_dict(prompt)) + for prompt in Prompt.select() + # .limit(limit).offset(skip) + ] + + def update_prompt_by_command( + self, command: str, form_data: PromptForm + ) -> Optional[PromptModel]: + try: + query = Prompt.update( + title=form_data.title, + content=form_data.content, + timestamp=int(time.time()), + ).where(Prompt.command == command) + + query.execute() + + prompt = Prompt.get(Prompt.command == command) + return PromptModel(**model_to_dict(prompt)) + except: + return None + + def delete_prompt_by_command(self, command: str) -> bool: + try: + query = Prompt.delete().where((Prompt.command == command)) + query.execute() # Remove the rows, return number of rows removed. + + return True + except: + return False + + +Prompts = PromptsTable(DB) diff --git a/backend/apps/web/routers/auths.py b/backend/apps/web/routers/auths.py index fb1139898..714982e34 100644 --- a/backend/apps/web/routers/auths.py +++ b/backend/apps/web/routers/auths.py @@ -8,6 +8,7 @@ from pydantic import BaseModel import time import uuid + from apps.web.models.auths import ( SigninForm, SignupForm, @@ -20,7 +21,7 @@ from apps.web.models.users import Users from utils.utils import get_password_hash, get_current_user, create_token -from utils.misc import get_gravatar_url +from utils.misc import get_gravatar_url, validate_email_format from constants import ERROR_MESSAGES @@ -95,33 +96,38 @@ async def signin(form_data: SigninForm): @router.post("/signup", response_model=SigninResponse) async def signup(request: Request, form_data: SignupForm): if request.app.state.ENABLE_SIGNUP: - if not Users.get_user_by_email(form_data.email.lower()): - try: - role = "admin" if Users.get_num_users() == 0 else "pending" - hashed = get_password_hash(form_data.password) - user = Auths.insert_new_auth( - form_data.email.lower(), hashed, form_data.name, role - ) + if validate_email_format(form_data.email.lower()): + if not Users.get_user_by_email(form_data.email.lower()): + try: + role = "admin" if Users.get_num_users() == 0 else "pending" + hashed = get_password_hash(form_data.password) + user = Auths.insert_new_auth( + form_data.email.lower(), hashed, form_data.name, role + ) - if user: - token = create_token(data={"email": user.email}) - # response.set_cookie(key='token', value=token, httponly=True) + if user: + token = create_token(data={"email": user.email}) + # response.set_cookie(key='token', value=token, httponly=True) - return { - "token": token, - "token_type": "Bearer", - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - "profile_image_url": user.profile_image_url, - } - else: - raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) - except Exception as err: - raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) + return { + "token": token, + "token_type": "Bearer", + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + "profile_image_url": user.profile_image_url, + } + else: + raise HTTPException( + 500, detail=ERROR_MESSAGES.CREATE_USER_ERROR + ) + except Exception as err: + raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) + else: + raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) else: - raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT) else: raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) diff --git a/backend/apps/web/routers/configs.py b/backend/apps/web/routers/configs.py new file mode 100644 index 000000000..b57fae3d5 --- /dev/null +++ b/backend/apps/web/routers/configs.py @@ -0,0 +1,41 @@ +from fastapi import Response, Request +from fastapi import Depends, FastAPI, HTTPException, status +from datetime import datetime, timedelta +from typing import List, Union + +from fastapi import APIRouter +from pydantic import BaseModel +import time +import uuid + +from apps.web.models.users import Users + + +from utils.utils import get_password_hash, get_current_user, create_token +from utils.misc import get_gravatar_url, validate_email_format +from constants import ERROR_MESSAGES + +router = APIRouter() + + +class SetDefaultModelsForm(BaseModel): + models: str + + +############################ +# SetDefaultModels +############################ + + +@router.post("/default/models", response_model=str) +async def set_global_default_models( + request: Request, form_data: SetDefaultModelsForm, user=Depends(get_current_user) +): + if user.role == "admin": + request.app.state.DEFAULT_MODELS = form_data.models + return request.app.state.DEFAULT_MODELS + else: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) diff --git a/backend/apps/web/routers/prompts.py b/backend/apps/web/routers/prompts.py new file mode 100644 index 000000000..5a002c941 --- /dev/null +++ b/backend/apps/web/routers/prompts.py @@ -0,0 +1,115 @@ +from fastapi import Depends, FastAPI, HTTPException, status +from datetime import datetime, timedelta +from typing import List, Union, Optional + +from fastapi import APIRouter +from pydantic import BaseModel +import json + + +from apps.web.models.prompts import Prompts, PromptForm, PromptModel + +from utils.utils import get_current_user +from constants import ERROR_MESSAGES + +router = APIRouter() + +############################ +# GetPrompts +############################ + + +@router.get("/", response_model=List[PromptModel]) +async def get_prompts(user=Depends(get_current_user)): + return Prompts.get_prompts() + + +############################ +# CreateNewPrompt +############################ + + +@router.post("/create", response_model=Optional[PromptModel]) +async def create_new_prompt(form_data: PromptForm, user=Depends(get_current_user)): + if user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + prompt = Prompts.get_prompt_by_command(form_data.command) + if prompt == None: + prompt = Prompts.insert_new_prompt(user.id, form_data) + + if prompt: + return prompt + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.DEFAULT(), + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.COMMAND_TAKEN, + ) + + +############################ +# GetPromptByCommand +############################ + + +@router.get("/{command}", response_model=Optional[PromptModel]) +async def get_prompt_by_command(command: str, user=Depends(get_current_user)): + prompt = Prompts.get_prompt_by_command(f"/{command}") + + if prompt: + return prompt + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdatePromptByCommand +############################ + + +@router.post("/{command}/update", response_model=Optional[PromptModel]) +async def update_prompt_by_command( + command: str, form_data: PromptForm, user=Depends(get_current_user) +): + if user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) + if prompt: + return prompt + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + +############################ +# DeletePromptByCommand +############################ + + +@router.delete("/{command}/delete", response_model=bool) +async def delete_prompt_by_command(command: str, user=Depends(get_current_user)): + if user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + result = Prompts.delete_prompt_by_command(f"/{command}") + return result diff --git a/backend/constants.py b/backend/constants.py index 761507f2b..0817445b5 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -17,10 +17,12 @@ class ERROR_MESSAGES(str, Enum): USERNAME_TAKEN = ( "Uh-oh! This username is already registered. Please choose another username." ) + COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string." INVALID_TOKEN = ( "Your session has expired or the token is invalid. Please sign in again." ) INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again." + INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)." INVALID_PASSWORD = ( "The password provided is incorrect. Please check for typos and try again." ) @@ -31,5 +33,4 @@ class ERROR_MESSAGES(str, Enum): ) NOT_FOUND = "We could not find what you're looking for :/" USER_NOT_FOUND = "We could not find what you're looking for :/" - MALICIOUS = "Unusual activities detected, please try again in a few minutes." diff --git a/backend/start.sh b/backend/start.sh old mode 100644 new mode 100755 index 9fb1f576f..ba68207ec --- a/backend/start.sh +++ b/backend/start.sh @@ -1 +1,3 @@ -uvicorn main:app --host 0.0.0.0 --port 8080 --forwarded-allow-ips '*' \ No newline at end of file +#!/usr/bin/env bash + +uvicorn main:app --host 0.0.0.0 --port 8080 --forwarded-allow-ips '*' diff --git a/backend/utils/misc.py b/backend/utils/misc.py index c6508487a..5635c57ac 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -1,4 +1,5 @@ import hashlib +import re def get_gravatar_url(email): @@ -21,3 +22,9 @@ def calculate_sha256(file): for chunk in iter(lambda: file.read(8192), b""): sha256.update(chunk) return sha256.hexdigest() + + +def validate_email_format(email: str) -> bool: + if not re.match(r"[^@]+@[^@]+\.[^@]+", email): + return False + return True diff --git a/src/app.html b/src/app.html index c2268851c..9b1099b0b 100644 --- a/src/app.html +++ b/src/app.html @@ -12,11 +12,12 @@ (!('theme' in localStorage) && window.matchMedia('(prefers-color-scheme: light)').matches) ) { document.documentElement.classList.add('light'); - } else if (localStorage.theme === 'dark') { - document.documentElement.classList.add('dark'); + } else if (localStorage.theme) { + localStorage.theme.split(' ').forEach((e) => { + document.documentElement.classList.add(e); + }); } else { document.documentElement.classList.add('dark'); - document.documentElement.classList.add(localStorage.theme); } diff --git a/src/lib/apis/configs/index.ts b/src/lib/apis/configs/index.ts new file mode 100644 index 000000000..76256f4d8 --- /dev/null +++ b/src/lib/apis/configs/index.ts @@ -0,0 +1,31 @@ +import { WEBUI_API_BASE_URL } from '$lib/constants'; + +export const setDefaultModels = async (token: string, models: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/default/models`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + models: models + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/apis/prompts/index.ts b/src/lib/apis/prompts/index.ts new file mode 100644 index 000000000..7ed303b3e --- /dev/null +++ b/src/lib/apis/prompts/index.ts @@ -0,0 +1,178 @@ +import { WEBUI_API_BASE_URL } from '$lib/constants'; + +export const createNewPrompt = async ( + token: string, + command: string, + title: string, + content: string +) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/create`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + command: `/${command}`, + title: title, + content: content + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getPrompts = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getPromptByCommand = async (token: string, command: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/${command}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updatePromptByCommand = async ( + token: string, + command: string, + title: string, + content: string +) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/${command}/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + command: `/${command}`, + title: title, + content: content + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deletePromptByCommand = async (token: string, command: string) => { + let error = null; + + command = command.charAt(0) === '/' ? command.slice(1) : command; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/${command}/delete`, { + method: 'DELETE', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index d2fe8ca29..1468310d4 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -1,8 +1,11 @@ + +{#if filteredPromptCommands.length > 0} +