fix: memory server schema

This commit is contained in:
Timothy Jaeryang Baek 2025-03-28 01:30:03 -07:00
parent 3117338af3
commit 2d4b4756ba

View File

@ -119,10 +119,35 @@ class CreateRelationsRequest(BaseModel):
)
class ObservationItem(BaseModel):
entityName: str = Field(
..., description="The name of the entity to add the observations to"
)
contents: List[str] = Field(
..., description="An array of observation contents to add"
)
class DeletionItem(BaseModel):
entityName: str = Field(
..., description="The name of the entity containing the observations"
)
observations: List[str] = Field(
..., description="An array of observations to delete"
)
class AddObservationsRequest(BaseModel):
observations: List[dict] = Field(
observations: List[ObservationItem] = Field(
...,
description="Each item includes an entity name and an array of observations to add",
description="A list of observation additions, each specifying an entity and contents to add",
)
class DeleteObservationsRequest(BaseModel):
deletions: List[DeletionItem] = Field(
...,
description="A list of observation deletions, each specifying an entity and observations to remove",
)
@ -132,13 +157,6 @@ class DeleteEntitiesRequest(BaseModel):
)
class DeleteObservationsRequest(BaseModel):
deletions: List[dict] = Field(
...,
description="Each item includes an entity name and an array of observations to delete",
)
class DeleteRelationsRequest(BaseModel):
relations: List[Relation] = Field(
..., description="An array of relations to delete"
@ -183,15 +201,17 @@ def create_relations(req: CreateRelationsRequest):
def add_observations(req: AddObservationsRequest):
graph = read_graph_file()
results = []
for obs in req.observations:
name = obs["entityName"]
contents = obs["contents"]
name = obs.entityName
contents = obs.contents
entity = next((e for e in graph.entities if e.name == name), None)
if not entity:
raise HTTPException(status_code=404, detail=f"Entity {name} not found")
added = [c for c in contents if c not in entity.observations]
entity.observations.extend(added)
results.append({"entityName": name, "addedObservations": added})
save_graph(graph)
return results
@ -212,14 +232,16 @@ def delete_entities(req: DeleteEntitiesRequest):
@app.post("/delete_observations", summary="Delete specific observations from entities")
def delete_observations(req: DeleteObservationsRequest):
graph = read_graph_file()
for deletion in req.deletions:
name = deletion["entityName"]
obs_to_delete = deletion["observations"]
name = deletion.entityName
to_delete = deletion.observations
entity = next((e for e in graph.entities if e.name == name), None)
if entity:
entity.observations = [
o for o in entity.observations if o not in obs_to_delete
obs for obs in entity.observations if obs not in to_delete
]
save_graph(graph)
return {"message": "Observations deleted successfully"}