diff --git a/servers/memory/main.py b/servers/memory/main.py new file mode 100644 index 0000000..9e62b8d --- /dev/null +++ b/servers/memory/main.py @@ -0,0 +1,275 @@ +from fastapi import FastAPI, HTTPException, Body +from fastapi.middleware.cors import CORSMiddleware + + +from pydantic import BaseModel, Field +from typing import List, Literal, Union +from pathlib import Path +import json +import os + +app = FastAPI( + title="Knowledge Graph Server", + version="1.0.0", + description="A structured knowledge graph memory system that supports entity and relation storage, observation tracking, and manipulation.", +) + +origins = ["*"] + +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# ----- Persistence Setup ----- +MEMORY_FILE_PATH_ENV = os.getenv("MEMORY_FILE_PATH", "memory.json") +MEMORY_FILE_PATH = Path( + MEMORY_FILE_PATH_ENV + if Path(MEMORY_FILE_PATH_ENV).is_absolute() + else Path(__file__).parent / MEMORY_FILE_PATH_ENV +) + + +# ----- Data Models ----- +class Entity(BaseModel): + name: str = Field(..., description="The name of the entity") + entityType: str = Field(..., description="The type of the entity") + observations: List[str] = Field( + ..., description="An array of observation contents associated with the entity" + ) + + +class Relation(BaseModel): + from_: str = Field( + ..., + alias="from", + description="The name of the entity where the relation starts", + ) + to: str = Field(..., description="The name of the entity where the relation ends") + relationType: str = Field(..., description="The type of the relation") + + +class KnowledgeGraph(BaseModel): + entities: List[Entity] + relations: List[Relation] + + +class EntityWrapper(BaseModel): + type: Literal["entity"] + name: str + entityType: str + observations: List[str] + + +class RelationWrapper(BaseModel): + type: Literal["relation"] + from_: str = Field(..., alias="from") + to: str + relationType: str + + +# ----- I/O Handlers ----- +def read_graph_file() -> KnowledgeGraph: + if not MEMORY_FILE_PATH.exists(): + return KnowledgeGraph(entities=[], relations=[]) + with open(MEMORY_FILE_PATH, "r", encoding="utf-8") as f: + lines = [line for line in f if line.strip()] + entities = [] + relations = [] + for line in lines: + print(line) + item = json.loads(line) + if item["type"] == "entity": + entities.append( + Entity( + name=item["name"], + entityType=item["entityType"], + observations=item["observations"], + ) + ) + elif item["type"] == "relation": + relations.append(Relation(**item)) + + return KnowledgeGraph(entities=entities, relations=relations) + + +def save_graph(graph: KnowledgeGraph): + lines = [json.dumps({"type": "entity", **e.dict()}) for e in graph.entities] + [ + json.dumps({"type": "relation", **r.dict(by_alias=True)}) + for r in graph.relations + ] + with open(MEMORY_FILE_PATH, "w", encoding="utf-8") as f: + f.write("\n".join(lines)) + + +# ----- Request Models ----- + + +class CreateEntitiesRequest(BaseModel): + entities: List[Entity] = Field(..., description="List of entities to create") + + +class CreateRelationsRequest(BaseModel): + relations: List[Relation] = Field( + ..., description="List of relations to create. All must be in active voice." + ) + + +class AddObservationsRequest(BaseModel): + observations: List[dict] = Field( + ..., + description="Each item includes an entity name and an array of observations to add", + ) + + +class DeleteEntitiesRequest(BaseModel): + entityNames: List[str] = Field( + ..., description="An array of entity names to delete" + ) + + +class DeleteObservationsRequest(BaseModel): + deletions: List[dict] = Field( + ..., + description="Each item includes an entity name and an array of observations to delete", + ) + + +class DeleteRelationsRequest(BaseModel): + relations: List[Relation] = Field( + ..., description="An array of relations to delete" + ) + + +class SearchNodesRequest(BaseModel): + query: str = Field( + ..., + description="The search query to match against entity names, types, and observation content", + ) + + +class OpenNodesRequest(BaseModel): + names: List[str] = Field(..., description="An array of entity names to retrieve") + + +# ----- Endpoints ----- + + +@app.post("/create_entities", summary="Create multiple entities in the graph") +def create_entities(req: CreateEntitiesRequest): + graph = read_graph_file() + existing_names = {e.name for e in graph.entities} + new_entities = [e for e in req.entities if e.name not in existing_names] + graph.entities.extend(new_entities) + save_graph(graph) + return new_entities + + +@app.post("/create_relations", summary="Create multiple relations between entities") +def create_relations(req: CreateRelationsRequest): + graph = read_graph_file() + existing = {(r.from_, r.to, r.relationType) for r in graph.relations} + new = [r for r in req.relations if (r.from_, r.to, r.relationType) not in existing] + graph.relations.extend(new) + save_graph(graph) + return new + + +@app.post("/add_observations", summary="Add new observations to existing entities") +def add_observations(req: AddObservationsRequest): + graph = read_graph_file() + results = [] + for obs in req.observations: + name = obs["entityName"] + contents = obs["contents"] + entity = next((e for e in graph.entities if e.name == name), None) + if not entity: + raise HTTPException(status_code=404, detail=f"Entity {name} not found") + added = [c for c in contents if c not in entity.observations] + entity.observations.extend(added) + results.append({"entityName": name, "addedObservations": added}) + save_graph(graph) + return results + + +@app.post("/delete_entities", summary="Delete entities and associated relations") +def delete_entities(req: DeleteEntitiesRequest): + graph = read_graph_file() + graph.entities = [e for e in graph.entities if e.name not in req.entityNames] + graph.relations = [ + r + for r in graph.relations + if r.from_ not in req.entityNames and r.to not in req.entityNames + ] + save_graph(graph) + return {"message": "Entities deleted successfully"} + + +@app.post("/delete_observations", summary="Delete specific observations from entities") +def delete_observations(req: DeleteObservationsRequest): + graph = read_graph_file() + for deletion in req.deletions: + name = deletion["entityName"] + obs_to_delete = deletion["observations"] + entity = next((e for e in graph.entities if e.name == name), None) + if entity: + entity.observations = [ + o for o in entity.observations if o not in obs_to_delete + ] + save_graph(graph) + return {"message": "Observations deleted successfully"} + + +@app.post("/delete_relations", summary="Delete relations from the graph") +def delete_relations(req: DeleteRelationsRequest): + graph = read_graph_file() + del_set = {(r.from_, r.to, r.relationType) for r in req.relations} + graph.relations = [ + r for r in graph.relations if (r.from_, r.to, r.relationType) not in del_set + ] + save_graph(graph) + return {"message": "Relations deleted successfully"} + + +@app.get( + "/read_graph", response_model=KnowledgeGraph, summary="Read entire knowledge graph" +) +def read_graph(): + return read_graph_file() + + +@app.post( + "/search_nodes", + response_model=KnowledgeGraph, + summary="Search for nodes by keyword", +) +def search_nodes(req: SearchNodesRequest): + graph = read_graph_file() + print(graph) + entities = [ + e + for e in graph.entities + if req.query.lower() in e.name.lower() + or req.query.lower() in e.entityType.lower() + or any(req.query.lower() in o.lower() for o in e.observations) + ] + names = {e.name for e in entities} + relations = [r for r in graph.relations if r.from_ in names and r.to in names] + + print(names, relations) + return KnowledgeGraph(entities=entities, relations=relations) + + +@app.post( + "/open_nodes", response_model=KnowledgeGraph, summary="Open specific nodes by name" +) +def open_nodes(req: OpenNodesRequest): + graph = read_graph_file() + entities = [e for e in graph.entities if e.name in req.names] + names = {e.name for e in entities} + relations = [r for r in graph.relations if r.from_ in names and r.to in names] + return KnowledgeGraph(entities=entities, relations=relations) diff --git a/servers/memory/requirements.txt b/servers/memory/requirements.txt new file mode 100644 index 0000000..b08d10d --- /dev/null +++ b/servers/memory/requirements.txt @@ -0,0 +1,7 @@ +fastapi +uvicorn[standard] +pydantic +python-multipart + +pytz +python-dateutil \ No newline at end of file