feat: db migration to sqlite

This commit is contained in:
Timothy J. Baek 2023-12-25 21:44:28 -08:00
parent eadbfeb277
commit 9174331025
13 changed files with 302 additions and 60 deletions

3
backend/.gitignore vendored
View File

@ -2,3 +2,6 @@ __pycache__
.env
_old
uploads
.ipynb_checkpoints
*.db
_test

View File

@ -25,7 +25,7 @@ TARGET_SERVER_URL = OLLAMA_API_BASE_URL
def proxy(path):
# Combine the base URL of the target server with the requested path
target_url = f"{TARGET_SERVER_URL}/{path}"
print(path)
print(target_url)
# Get data from the original request
data = request.get_data()

View File

@ -0,0 +1,4 @@
from peewee import *
DB = SqliteDatabase("./ollama.db")
DB.connect()

View File

@ -1,7 +1,7 @@
from fastapi import FastAPI, Request, Depends, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from apps.web.routers import auths, users, utils
from apps.web.routers import auths, users, chats, utils
from config import WEBUI_VERSION, WEBUI_AUTH
app = FastAPI()
@ -20,6 +20,7 @@ app.add_middleware(
app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(utils.router, prefix="/utils", tags=["utils"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
@app.get("/")

View File

@ -2,6 +2,7 @@ from pydantic import BaseModel
from typing import List, Union, Optional
import time
import uuid
from peewee import *
from apps.web.models.users import UserModel, Users
@ -12,15 +13,23 @@ from utils.utils import (
create_token,
)
import config
DB = config.DB
from apps.web.internal.db import DB
####################
# DB MODEL
####################
class Auth(Model):
id = CharField(unique=True)
email = CharField()
password = CharField()
active = BooleanField()
class Meta:
database = DB
class AuthModel(BaseModel):
id: str
email: str
@ -64,7 +73,7 @@ class SignupForm(BaseModel):
class AuthsTable:
def __init__(self, db):
self.db = db
self.table = db.auths
self.db.create_tables([Auth])
def insert_new_auth(
self, email: str, password: str, name: str, role: str = "pending"
@ -76,7 +85,9 @@ class AuthsTable:
auth = AuthModel(
**{"id": id, "email": email, "password": password, "active": True}
)
result = self.table.insert_one(auth.model_dump())
result = Auth.create(**auth.model_dump())
print(result)
user = Users.insert_new_user(id, name, email, role)
print(result, user)
@ -86,14 +97,19 @@ class AuthsTable:
return None
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
print("authenticate_user")
print("authenticate_user", email)
auth = self.table.find_one({"email": email, "active": True})
auth = Auth.get(Auth.email == email, Auth.active == True)
print(auth.email)
if auth:
if verify_password(password, auth["password"]):
user = self.db.users.find_one({"id": auth["id"]})
return UserModel(**user)
print(password, str(auth.password))
print(verify_password(password, str(auth.password)))
if verify_password(password, auth.password):
user = Users.get_user_by_id(auth.id)
print(user)
return user
else:
return None
else:

View File

@ -0,0 +1,108 @@
from pydantic import BaseModel
from typing import List, Union, Optional
from peewee import *
from playhouse.shortcuts import model_to_dict
import json
import uuid
import time
from apps.web.internal.db import DB
####################
# Chat DB Schema
####################
class Chat(Model):
id = CharField(unique=True)
user_id: CharField()
title = CharField()
chat = TextField() # Save Chat JSON as Text
timestamp = DateField()
class Meta:
database = DB
class ChatModel(BaseModel):
id: str
user_id: str
title: str
chat: dict
timestamp: int # timestamp in epoch
####################
# Forms
####################
class ChatForm(BaseModel):
chat: dict
class ChatUpdateForm(ChatForm):
id: str
class ChatTitleIdResponse(BaseModel):
id: str
title: str
class ChatTable:
def __init__(self, db):
self.db = db
db.create_tables([Chat])
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
id = str(uuid.uuid4())
chat = ChatModel(
**{
"id": id,
"user_id": user_id,
"title": form_data.chat["title"],
"chat": json.dump(form_data.chat),
"timestamp": int(time.time()),
}
)
result = Chat.create(**chat.model_dump())
return chat if result else None
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
try:
query = Chat.update(chat=json.dump(chat)).where(Chat.id == id)
query.execute()
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
except:
return None
def get_chat_lists_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select(Chat.user_id == user_id).limit(limit).offset(skip)
]
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
try:
chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
return ChatModel(**model_to_dict(chat))
except:
return None
def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select().limit(limit).offset(skip)
]
Chats = ChatTable(DB)

View File

@ -1,25 +1,41 @@
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
from pymongo import ReturnDocument
import time
from utils.utils import decode_token
from utils.misc import get_gravatar_url
from config import DB
from apps.web.internal.db import DB
####################
# User DB Schema
####################
class User(Model):
id = CharField(unique=True)
name = CharField()
email = CharField()
role = CharField()
profile_image_url = CharField()
timestamp = DateField()
class Meta:
database = DB
class UserModel(BaseModel):
class Config:
orm_mode = True
id: str
name: str
email: str
role: str = "pending"
profile_image_url: str = "/user.png"
created_at: int # timestamp in epoch
timestamp: int # timestamp in epoch
####################
@ -35,7 +51,7 @@ class UserRoleUpdateForm(BaseModel):
class UsersTable:
def __init__(self, db):
self.db = db
self.table = db.users
self.db.create_tables([User])
def insert_new_user(
self, id: str, name: str, email: str, role: str = "pending"
@ -47,22 +63,27 @@ class UsersTable:
"email": email,
"role": role,
"profile_image_url": get_gravatar_url(email),
"created_at": int(time.time()),
"timestamp": int(time.time()),
}
)
result = self.table.insert_one(user.model_dump())
result = User.create(**user.model_dump())
if result:
return user
else:
return None
def get_user_by_email(self, email: str) -> Optional[UserModel]:
user = self.table.find_one({"email": email}, {"_id": False})
def get_user_by_id(self, id: str) -> Optional[UserModel]:
try:
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
except:
return None
if user:
return UserModel(**user)
else:
def get_user_by_email(self, email: str) -> Optional[UserModel]:
try:
user = User.get(User.email == email)
return UserModel(**model_to_dict(user))
except:
return None
def get_user_by_token(self, token: str) -> Optional[UserModel]:
@ -75,23 +96,22 @@ class UsersTable:
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
return [
UserModel(**user)
for user in list(
self.table.find({}, {"_id": False}).skip(skip).limit(limit)
)
UserModel(**model_to_dict(user))
for user in User.select().limit(limit).offset(skip)
]
def get_num_users(self) -> Optional[int]:
return self.table.count_documents({})
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
user = self.table.find_one_and_update(
{"id": id}, {"$set": updated}, return_document=ReturnDocument.AFTER
)
return UserModel(**user)
return User.select().count()
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
return self.update_user_by_id(id, {"role": role})
try:
query = User.update(role=role).where(User.id == id)
query.execute()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
except:
return None
Users = UsersTable(DB)

View File

@ -104,8 +104,8 @@ async def signup(form_data: SignupForm):
"profile_image_url": user.profile_image_url,
}
else:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT())
except Exception as err:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
else:
raise HTTPException(400, detail=ERROR_MESSAGES.DEFAULT())
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)

View File

@ -0,0 +1,100 @@
from fastapi import Response
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
from apps.web.models.users import Users
from apps.web.models.chats import (
ChatModel,
ChatForm,
ChatUpdateForm,
ChatTitleIdResponse,
Chats,
)
from utils.utils import (
bearer_scheme,
)
from constants import ERROR_MESSAGES
router = APIRouter()
############################
# GetChats
############################
@router.get("/", response_model=List[ChatTitleIdResponse])
async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
return Chats.get_chat_titles_and_ids_by_user_id(user.id, skip, limit)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################
# CreateNewChat
############################
@router.post("/new", response_model=Optional[ChatModel])
async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
return Chats.insert_new_chat(user.id, form_data)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################
# GetChatById
############################
@router.get("/{id}", response_model=Optional[ChatModel])
async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
return Chats.get_chat_by_id_and_user_id(id, user.id)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################
# UpdateChatById
############################
@router.post("/{id}", response_model=Optional[ChatModel])
async def update_chat_by_id(
id: str, form_data: ChatUpdateForm, cred=Depends(bearer_scheme)
):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
return Chats.update_chat_by_id_and_user_id(id, user.id, form_data.chat)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)

View File

@ -1,9 +1,10 @@
from dotenv import load_dotenv, find_dotenv
from pymongo import MongoClient
from constants import ERROR_MESSAGES
from secrets import token_bytes
from base64 import b64encode
import os
load_dotenv(find_dotenv("../.env"))
@ -36,25 +37,8 @@ WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.40")
# WEBUI_AUTH
####################################
WEBUI_AUTH = True if os.environ.get("WEBUI_AUTH", "FALSE") == "TRUE" else False
####################################
# WEBUI_DB (Deprecated, Should be removed)
####################################
WEBUI_DB_URL = os.environ.get("WEBUI_DB_URL", "mongodb://root:root@localhost:27017/")
if WEBUI_AUTH and WEBUI_DB_URL == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
DB_CLIENT = MongoClient(f"{WEBUI_DB_URL}?authSource=admin")
DB = DB_CLIENT["ollama-webui"]
####################################
# WEBUI_JWT_SECRET_KEY
####################################

View File

@ -11,6 +11,12 @@ class ERROR_MESSAGES(str, Enum):
DEFAULT = lambda err="": f"Something went wrong :/\n{err if err else ''}"
ENV_VAR_NOT_FOUND = "Required environment variable not found. Terminating now."
EMAIL_TAKEN = "Uh-oh! This email is already registered. Sign in with your existing account or choose another email to start anew."
USERNAME_TAKEN = (
"Uh-oh! This username is already registered. Please choose another username."
)
INVALID_TOKEN = (
"Your session has expired or the token is invalid. Please sign in again."
)

View File

@ -13,7 +13,7 @@ uuid
requests
aiohttp
pymongo
peewee
bcrypt
PyJWT

View File

@ -66,7 +66,7 @@
if (res) {
console.log(res);
toast.success(`Account creation successful."`);
toast.success(`Account creation successful.`);
localStorage.token = res.token;
await user.set(res);
goto('/');