This commit is contained in:
Timothy J. Baek
2024-06-15 03:35:44 -06:00
parent a6ee7415d8
commit 3c599e24e5
5 changed files with 183 additions and 63 deletions

View File

@@ -43,9 +43,11 @@ async def get_memories(user=Depends(get_verified_user)):
class AddMemoryForm(BaseModel):
content: str
class MemoryUpdateModel(BaseModel):
content: Optional[str] = None
@router.post("/add", response_model=Optional[MemoryModel])
async def add_memory(
request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user)
@@ -64,9 +66,12 @@ async def add_memory(
return memory
@router.post("/{memory_id}", response_model=Optional[MemoryModel])
@router.post("/{memory_id}/update", response_model=Optional[MemoryModel])
async def update_memory_by_id(
memory_id: str, request: Request, form_data: MemoryUpdateModel, user=Depends(get_verified_user)
memory_id: str,
request: Request,
form_data: MemoryUpdateModel,
user=Depends(get_verified_user),
):
memory = Memories.update_memory_by_id(memory_id, form_data.content)
if memory is None:
@@ -74,12 +79,16 @@ async def update_memory_by_id(
if form_data.content is not None:
memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
collection = CHROMA_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
collection.upsert(
documents=[form_data.content],
ids=[memory.id],
embeddings=[memory_embedding],
metadatas=[{"created_at": memory.created_at, "updated_at": memory.updated_at}],
metadatas=[
{"created_at": memory.created_at, "updated_at": memory.updated_at}
],
)
return memory