From 288d8a3e32de1b50760054a8b6ababa94ad75ecc Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 19 May 2024 08:00:07 -0700 Subject: [PATCH] feat: memory backend --- .../web/internal/migrations/008_add_memory.py | 53 ++++++++ backend/apps/web/main.py | 5 + backend/apps/web/models/memories.py | 109 ++++++++++++++++ backend/apps/web/routers/memories.py | 117 ++++++++++++++++++ backend/main.py | 14 ++- 5 files changed, 296 insertions(+), 2 deletions(-) create mode 100644 backend/apps/web/internal/migrations/008_add_memory.py create mode 100644 backend/apps/web/models/memories.py create mode 100644 backend/apps/web/routers/memories.py diff --git a/backend/apps/web/internal/migrations/008_add_memory.py b/backend/apps/web/internal/migrations/008_add_memory.py new file mode 100644 index 000000000..9307aa4d5 --- /dev/null +++ b/backend/apps/web/internal/migrations/008_add_memory.py @@ -0,0 +1,53 @@ +"""Peewee migrations -- 002_add_local_sharing.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + @migrator.create_model + class Memory(pw.Model): + id = pw.CharField(max_length=255, unique=True) + user_id = pw.CharField(max_length=255) + content = pw.TextField(null=False) + updated_at = pw.BigIntegerField(null=False) + created_at = pw.BigIntegerField(null=False) + + class Meta: + table_name = "memory" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model("memory") diff --git a/backend/apps/web/main.py b/backend/apps/web/main.py index 755e3911b..2b6966381 100644 --- a/backend/apps/web/main.py +++ b/backend/apps/web/main.py @@ -9,6 +9,7 @@ from apps.web.routers import ( modelfiles, prompts, configs, + memories, utils, ) from config import ( @@ -41,6 +42,7 @@ app.state.config.USER_PERMISSIONS = USER_PERMISSIONS app.state.config.WEBHOOK_URL = WEBHOOK_URL app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER + app.add_middleware( CORSMiddleware, allow_origins=origins, @@ -52,9 +54,12 @@ app.add_middleware( app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(chats.router, prefix="/chats", tags=["chats"]) + app.include_router(documents.router, prefix="/documents", tags=["documents"]) app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) +app.include_router(memories.router, prefix="/memories", tags=["memories"]) + app.include_router(configs.router, prefix="/configs", tags=["configs"]) app.include_router(utils.router, prefix="/utils", tags=["utils"]) diff --git a/backend/apps/web/models/memories.py b/backend/apps/web/models/memories.py new file mode 100644 index 000000000..7b432a9ca --- /dev/null +++ b/backend/apps/web/models/memories.py @@ -0,0 +1,109 @@ +from pydantic import BaseModel +from peewee import * +from playhouse.shortcuts import model_to_dict +from typing import List, Union, Optional + +from apps.web.internal.db import DB +from apps.web.models.chats import Chats + +import time +import uuid + +#################### +# Memory DB Schema +#################### + + +class Memory(Model): + id = CharField(unique=True) + user_id = CharField() + content = TextField() + updated_at = BigIntegerField() + created_at = BigIntegerField() + + class Meta: + database = DB + + +class MemoryModel(BaseModel): + id: str + user_id: str + content: str + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + + +#################### +# Forms +#################### + + +class MemoriesTable: + def __init__(self, db): + self.db = db + self.db.create_tables([Memory]) + + def insert_new_memory( + self, + user_id: str, + content: str, + ) -> Optional[MemoryModel]: + id = str(uuid.uuid4()) + + memory = MemoryModel( + **{ + "id": id, + "user_id": user_id, + "content": content, + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + result = Memory.create(**memory.model_dump()) + if result: + return memory + else: + return None + + def get_memories(self) -> List[MemoryModel]: + try: + memories = Memory.select() + return [MemoryModel(**model_to_dict(memory)) for memory in memories] + except: + return None + + def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]: + try: + memories = Memory.select().where(Memory.user_id == user_id) + return [MemoryModel(**model_to_dict(memory)) for memory in memories] + except: + return None + + def get_memory_by_id(self, id) -> Optional[MemoryModel]: + try: + memory = Memory.get(Memory.id == id) + return MemoryModel(**model_to_dict(memory)) + except: + return None + + def delete_memory_by_id(self, id: str) -> bool: + try: + query = Memory.delete().where(Memory.id == id) + query.execute() # Remove the rows, return number of rows removed. + + return True + + except: + return False + + def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: + try: + query = Memory.delete().where(Memory.id == id, Memory.user_id == user_id) + query.execute() + + return True + except: + return False + + +Memories = MemoriesTable(DB) diff --git a/backend/apps/web/routers/memories.py b/backend/apps/web/routers/memories.py new file mode 100644 index 000000000..d234d46de --- /dev/null +++ b/backend/apps/web/routers/memories.py @@ -0,0 +1,117 @@ +from fastapi import Response, Request +from fastapi import Depends, FastAPI, HTTPException, status +from datetime import datetime, timedelta +from typing import List, Union, Optional + +from fastapi import APIRouter +from pydantic import BaseModel +import logging + +from apps.web.models.memories import Memories, MemoryModel + +from utils.utils import get_verified_user +from constants import ERROR_MESSAGES + +from config import SRC_LOG_LEVELS, CHROMA_CLIENT + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + +router = APIRouter() + + +@router.get("/ef") +async def get_embeddings(request: Request): + return {"result": request.app.state.EMBEDDING_FUNCTION("hello world")} + + +############################ +# GetMemories +############################ + + +@router.get("/", response_model=List[MemoryModel]) +async def get_memories(user=Depends(get_verified_user)): + return Memories.get_memories_by_user_id(user.id) + + +############################ +# AddMemory +############################ + + +class AddMemoryForm(BaseModel): + content: str + + +@router.post("/add", response_model=Optional[MemoryModel]) +async def add_memory( + request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user) +): + memory = Memories.insert_new_memory(user.id, form_data.content) + memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content) + + collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}") + collection.upsert( + documents=[memory.content], + ids=[memory.id], + embeddings=[memory_embedding], + metadatas=[{"created_at": memory.created_at}], + ) + + return memory + + +############################ +# QueryMemory +############################ + + +class QueryMemoryForm(BaseModel): + content: str + + +@router.post("/query", response_model=Optional[MemoryModel]) +async def add_memory( + request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user) +): + query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content) + collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}") + + results = collection.query( + query_embeddings=[query_embedding], + n_results=1, # how many results to return + ) + + return results + + +############################ +# ResetMemoryFromVectorDB +############################ +@router.get("/reset", response_model=bool) +async def reset_memory_from_vector_db( + request: Request, user=Depends(get_verified_user) +): + CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}") + collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}") + + memories = Memories.get_memories_by_user_id(user.id) + for memory in memories: + memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content) + collection.upsert( + documents=[memory.content], + ids=[memory.id], + embeddings=[memory_embedding], + ) + return True + + +############################ +# DeleteUserById +############################ + + +@router.delete("/{memory_id}", response_model=bool) +async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)): + return Memories.delete_memory_by_id_and_user_id(memory_id, user.id) diff --git a/backend/main.py b/backend/main.py index 209199591..4cf3243f7 100644 --- a/backend/main.py +++ b/backend/main.py @@ -238,9 +238,15 @@ async def check_url(request: Request, call_next): return response -app.mount("/api/v1", webui_app) -app.mount("/litellm/api", litellm_app) +@app.middleware("http") +async def update_embedding_function(request: Request, call_next): + response = await call_next(request) + if "/embedding/update" in request.url.path: + webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION + return response + +app.mount("/litellm/api", litellm_app) app.mount("/ollama", ollama_app) app.mount("/openai/api", openai_app) @@ -248,6 +254,10 @@ app.mount("/images/api/v1", images_app) app.mount("/audio/api/v1", audio_app) app.mount("/rag/api/v1", rag_app) +app.mount("/api/v1", webui_app) + +webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION + @app.get("/api/config") async def get_app_config():