From 2d4b4756baeca495ae4daebd6b9c96c764736ad0 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Fri, 28 Mar 2025 01:30:03 -0700 Subject: [PATCH] fix: memory server schema --- servers/memory/main.py | 50 ++++++++++++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 14 deletions(-) diff --git a/servers/memory/main.py b/servers/memory/main.py index 9e62b8d..5c54cca 100644 --- a/servers/memory/main.py +++ b/servers/memory/main.py @@ -119,10 +119,35 @@ class CreateRelationsRequest(BaseModel): ) +class ObservationItem(BaseModel): + entityName: str = Field( + ..., description="The name of the entity to add the observations to" + ) + contents: List[str] = Field( + ..., description="An array of observation contents to add" + ) + + +class DeletionItem(BaseModel): + entityName: str = Field( + ..., description="The name of the entity containing the observations" + ) + observations: List[str] = Field( + ..., description="An array of observations to delete" + ) + + class AddObservationsRequest(BaseModel): - observations: List[dict] = Field( + observations: List[ObservationItem] = Field( ..., - description="Each item includes an entity name and an array of observations to add", + description="A list of observation additions, each specifying an entity and contents to add", + ) + + +class DeleteObservationsRequest(BaseModel): + deletions: List[DeletionItem] = Field( + ..., + description="A list of observation deletions, each specifying an entity and observations to remove", ) @@ -132,13 +157,6 @@ class DeleteEntitiesRequest(BaseModel): ) -class DeleteObservationsRequest(BaseModel): - deletions: List[dict] = Field( - ..., - description="Each item includes an entity name and an array of observations to delete", - ) - - class DeleteRelationsRequest(BaseModel): relations: List[Relation] = Field( ..., description="An array of relations to delete" @@ -183,15 +201,17 @@ def create_relations(req: CreateRelationsRequest): def add_observations(req: AddObservationsRequest): graph = read_graph_file() results = [] + for obs in req.observations: - name = obs["entityName"] - contents = obs["contents"] + name = obs.entityName + contents = obs.contents entity = next((e for e in graph.entities if e.name == name), None) if not entity: raise HTTPException(status_code=404, detail=f"Entity {name} not found") added = [c for c in contents if c not in entity.observations] entity.observations.extend(added) results.append({"entityName": name, "addedObservations": added}) + save_graph(graph) return results @@ -212,14 +232,16 @@ def delete_entities(req: DeleteEntitiesRequest): @app.post("/delete_observations", summary="Delete specific observations from entities") def delete_observations(req: DeleteObservationsRequest): graph = read_graph_file() + for deletion in req.deletions: - name = deletion["entityName"] - obs_to_delete = deletion["observations"] + name = deletion.entityName + to_delete = deletion.observations entity = next((e for e in graph.entities if e.name == name), None) if entity: entity.observations = [ - o for o in entity.observations if o not in obs_to_delete + obs for obs in entity.observations if obs not in to_delete ] + save_graph(graph) return {"message": "Observations deleted successfully"}