feat: memory backend

This commit is contained in:
Timothy J. Baek 2024-05-19 08:00:07 -07:00
parent 1fb5ef99e2
commit 288d8a3e32
5 changed files with 296 additions and 2 deletions

View File

@ -0,0 +1,53 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
@migrator.create_model
class Memory(pw.Model):
id = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
content = pw.TextField(null=False)
updated_at = pw.BigIntegerField(null=False)
created_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "memory"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("memory")

View File

@ -9,6 +9,7 @@ from apps.web.routers import (
modelfiles,
prompts,
configs,
memories,
utils,
)
from config import (
@ -41,6 +42,7 @@ app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
@ -52,9 +54,12 @@ app.add_middleware(
app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])
app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(utils.router, prefix="/utils", tags=["utils"])

View File

@ -0,0 +1,109 @@
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
from apps.web.internal.db import DB
from apps.web.models.chats import Chats
import time
import uuid
####################
# Memory DB Schema
####################
class Memory(Model):
id = CharField(unique=True)
user_id = CharField()
content = TextField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Meta:
database = DB
class MemoryModel(BaseModel):
id: str
user_id: str
content: str
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
####################
# Forms
####################
class MemoriesTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Memory])
def insert_new_memory(
self,
user_id: str,
content: str,
) -> Optional[MemoryModel]:
id = str(uuid.uuid4())
memory = MemoryModel(
**{
"id": id,
"user_id": user_id,
"content": content,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
result = Memory.create(**memory.model_dump())
if result:
return memory
else:
return None
def get_memories(self) -> List[MemoryModel]:
try:
memories = Memory.select()
return [MemoryModel(**model_to_dict(memory)) for memory in memories]
except:
return None
def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]:
try:
memories = Memory.select().where(Memory.user_id == user_id)
return [MemoryModel(**model_to_dict(memory)) for memory in memories]
except:
return None
def get_memory_by_id(self, id) -> Optional[MemoryModel]:
try:
memory = Memory.get(Memory.id == id)
return MemoryModel(**model_to_dict(memory))
except:
return None
def delete_memory_by_id(self, id: str) -> bool:
try:
query = Memory.delete().where(Memory.id == id)
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try:
query = Memory.delete().where(Memory.id == id, Memory.user_id == user_id)
query.execute()
return True
except:
return False
Memories = MemoriesTable(DB)

View File

@ -0,0 +1,117 @@
from fastapi import Response, Request
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import logging
from apps.web.models.memories import Memories, MemoryModel
from utils.utils import get_verified_user
from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
@router.get("/ef")
async def get_embeddings(request: Request):
return {"result": request.app.state.EMBEDDING_FUNCTION("hello world")}
############################
# GetMemories
############################
@router.get("/", response_model=List[MemoryModel])
async def get_memories(user=Depends(get_verified_user)):
return Memories.get_memories_by_user_id(user.id)
############################
# AddMemory
############################
class AddMemoryForm(BaseModel):
content: str
@router.post("/add", response_model=Optional[MemoryModel])
async def add_memory(
request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user)
):
memory = Memories.insert_new_memory(user.id, form_data.content)
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
collection.upsert(
documents=[memory.content],
ids=[memory.id],
embeddings=[memory_embedding],
metadatas=[{"created_at": memory.created_at}],
)
return memory
############################
# QueryMemory
############################
class QueryMemoryForm(BaseModel):
content: str
@router.post("/query", response_model=Optional[MemoryModel])
async def add_memory(
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
):
query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
results = collection.query(
query_embeddings=[query_embedding],
n_results=1, # how many results to return
)
return results
############################
# ResetMemoryFromVectorDB
############################
@router.get("/reset", response_model=bool)
async def reset_memory_from_vector_db(
request: Request, user=Depends(get_verified_user)
):
CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
memories = Memories.get_memories_by_user_id(user.id)
for memory in memories:
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
collection.upsert(
documents=[memory.content],
ids=[memory.id],
embeddings=[memory_embedding],
)
return True
############################
# DeleteUserById
############################
@router.delete("/{memory_id}", response_model=bool)
async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
return Memories.delete_memory_by_id_and_user_id(memory_id, user.id)

View File

@ -238,9 +238,15 @@ async def check_url(request: Request, call_next):
return response
app.mount("/api/v1", webui_app)
app.mount("/litellm/api", litellm_app)
@app.middleware("http")
async def update_embedding_function(request: Request, call_next):
response = await call_next(request)
if "/embedding/update" in request.url.path:
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
return response
app.mount("/litellm/api", litellm_app)
app.mount("/ollama", ollama_app)
app.mount("/openai/api", openai_app)
@ -248,6 +254,10 @@ app.mount("/images/api/v1", images_app)
app.mount("/audio/api/v1", audio_app)
app.mount("/rag/api/v1", rag_app)
app.mount("/api/v1", webui_app)
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
@app.get("/api/config")
async def get_app_config():