fix: memory server schema

This commit is contained in:
Timothy Jaeryang Baek
2025-03-28 01:30:03 -07:00
parent 3117338af3
commit 2d4b4756ba

View File

@@ -119,10 +119,35 @@ class CreateRelationsRequest(BaseModel):
) )
class ObservationItem(BaseModel):
entityName: str = Field(
..., description="The name of the entity to add the observations to"
)
contents: List[str] = Field(
..., description="An array of observation contents to add"
)
class DeletionItem(BaseModel):
entityName: str = Field(
..., description="The name of the entity containing the observations"
)
observations: List[str] = Field(
..., description="An array of observations to delete"
)
class AddObservationsRequest(BaseModel): class AddObservationsRequest(BaseModel):
observations: List[dict] = Field( observations: List[ObservationItem] = Field(
..., ...,
description="Each item includes an entity name and an array of observations to add", description="A list of observation additions, each specifying an entity and contents to add",
)
class DeleteObservationsRequest(BaseModel):
deletions: List[DeletionItem] = Field(
...,
description="A list of observation deletions, each specifying an entity and observations to remove",
) )
@@ -132,13 +157,6 @@ class DeleteEntitiesRequest(BaseModel):
) )
class DeleteObservationsRequest(BaseModel):
deletions: List[dict] = Field(
...,
description="Each item includes an entity name and an array of observations to delete",
)
class DeleteRelationsRequest(BaseModel): class DeleteRelationsRequest(BaseModel):
relations: List[Relation] = Field( relations: List[Relation] = Field(
..., description="An array of relations to delete" ..., description="An array of relations to delete"
@@ -183,15 +201,17 @@ def create_relations(req: CreateRelationsRequest):
def add_observations(req: AddObservationsRequest): def add_observations(req: AddObservationsRequest):
graph = read_graph_file() graph = read_graph_file()
results = [] results = []
for obs in req.observations: for obs in req.observations:
name = obs["entityName"] name = obs.entityName
contents = obs["contents"] contents = obs.contents
entity = next((e for e in graph.entities if e.name == name), None) entity = next((e for e in graph.entities if e.name == name), None)
if not entity: if not entity:
raise HTTPException(status_code=404, detail=f"Entity {name} not found") raise HTTPException(status_code=404, detail=f"Entity {name} not found")
added = [c for c in contents if c not in entity.observations] added = [c for c in contents if c not in entity.observations]
entity.observations.extend(added) entity.observations.extend(added)
results.append({"entityName": name, "addedObservations": added}) results.append({"entityName": name, "addedObservations": added})
save_graph(graph) save_graph(graph)
return results return results
@@ -212,14 +232,16 @@ def delete_entities(req: DeleteEntitiesRequest):
@app.post("/delete_observations", summary="Delete specific observations from entities") @app.post("/delete_observations", summary="Delete specific observations from entities")
def delete_observations(req: DeleteObservationsRequest): def delete_observations(req: DeleteObservationsRequest):
graph = read_graph_file() graph = read_graph_file()
for deletion in req.deletions: for deletion in req.deletions:
name = deletion["entityName"] name = deletion.entityName
obs_to_delete = deletion["observations"] to_delete = deletion.observations
entity = next((e for e in graph.entities if e.name == name), None) entity = next((e for e in graph.entities if e.name == name), None)
if entity: if entity:
entity.observations = [ entity.observations = [
o for o in entity.observations if o not in obs_to_delete obs for obs in entity.observations if obs not in to_delete
] ]
save_graph(graph) save_graph(graph)
return {"message": "Observations deleted successfully"} return {"message": "Observations deleted successfully"}