"""Knowledge Graph Server.

A FastAPI service that keeps a knowledge graph of entities and relations in a
JSON Lines file: one JSON object per line, tagged with ``"type": "entity"`` or
``"type": "relation"``. Every endpoint loads the whole file; mutating endpoints
apply their change in memory and then rewrite the whole file.
"""

import json
import os
import threading
from pathlib import Path
from typing import List, Literal, Union

from fastapi import Body, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

app = FastAPI(
    title="Knowledge Graph Server",
    version="1.0.0",
    description="A structured knowledge graph memory system that supports entity and relation storage, observation tracking, and manipulation.",
)

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True makes
# Starlette reflect the caller's Origin, so any website may send credentialed
# requests. Harmless while this server has no auth/cookies; restrict `origins`
# before adding any.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ----- Persistence Setup -----

# Relative paths are resolved against this file's directory, not the CWD.
MEMORY_FILE_PATH_ENV = os.getenv("MEMORY_FILE_PATH", "memory.json")
MEMORY_FILE_PATH = Path(
    MEMORY_FILE_PATH_ENV
    if Path(MEMORY_FILE_PATH_ENV).is_absolute()
    else Path(__file__).parent / MEMORY_FILE_PATH_ENV
)

# Guards every access to MEMORY_FILE_PATH. FastAPI runs plain `def` endpoints in
# a thread pool, so without this two concurrent mutations could both read the
# old file and the later save would silently drop the other's changes.
_GRAPH_LOCK = threading.Lock()


# ----- Data Models -----


class Entity(BaseModel):
    name: str = Field(..., description="The name of the entity")
    entityType: str = Field(..., description="The type of the entity")
    observations: List[str] = Field(
        ..., description="An array of observation contents associated with the entity"
    )


class Relation(BaseModel):
    # `from` is a Python keyword, so the field is `from_` with a JSON alias.
    from_: str = Field(
        ...,
        alias="from",
        description="The name of the entity where the relation starts",
    )
    to: str = Field(..., description="The name of the entity where the relation ends")
    relationType: str = Field(..., description="The type of the relation")


class KnowledgeGraph(BaseModel):
    entities: List[Entity]
    relations: List[Relation]


# Shapes of the tagged records stored on disk (not referenced by the code below).
class EntityWrapper(BaseModel):
    type: Literal["entity"]
    name: str
    entityType: str
    observations: List[str]


class RelationWrapper(BaseModel):
    type: Literal["relation"]
    from_: str = Field(..., alias="from")
    to: str
    relationType: str


# ----- I/O Handlers -----
# These helpers do no locking themselves; endpoints hold _GRAPH_LOCK around them.


def read_graph_file() -> KnowledgeGraph:
    """Load the knowledge graph from MEMORY_FILE_PATH.

    Returns an empty graph when the file does not exist. Blank lines and
    records whose ``"type"`` is neither ``"entity"`` nor ``"relation"`` are
    skipped.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record lacks ``"type"`` or a required entity field.
    """
    if not MEMORY_FILE_PATH.exists():
        return KnowledgeGraph(entities=[], relations=[])

    entities: List[Entity] = []
    relations: List[Relation] = []
    with open(MEMORY_FILE_PATH, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            item = json.loads(line)
            if item["type"] == "entity":
                entities.append(
                    Entity(
                        name=item["name"],
                        entityType=item["entityType"],
                        observations=item["observations"],
                    )
                )
            elif item["type"] == "relation":
                # Relation ignores the extra "type" key; "from" fills `from_` via its alias.
                relations.append(Relation(**item))
    return KnowledgeGraph(entities=entities, relations=relations)


def save_graph(graph: KnowledgeGraph):
    """Persist *graph* to MEMORY_FILE_PATH as JSON Lines, replacing the file.

    The data is written to a sibling temporary file, flushed to disk and then
    moved over the target with ``os.replace``. A crash mid-write therefore
    leaves the previous file intact instead of a truncated one.
    """
    lines = [json.dumps({"type": "entity", **e.dict()}) for e in graph.entities] + [
        json.dumps({"type": "relation", **r.dict(by_alias=True)})
        for r in graph.relations
    ]
    # Same directory as the target so os.replace stays on one filesystem; the
    # PID keeps separate worker processes from sharing a temp file.
    tmp_path = MEMORY_FILE_PATH.with_name(
        f".{MEMORY_FILE_PATH.name}.{os.getpid()}.tmp"
    )
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, MEMORY_FILE_PATH)
    except BaseException:
        # Don't leave a stray temp file behind; propagate the original error.
        tmp_path.unlink(missing_ok=True)
        raise


def _load_graph() -> KnowledgeGraph:
    """Read the graph while holding _GRAPH_LOCK (for read-only endpoints)."""
    with _GRAPH_LOCK:
        return read_graph_file()


def _find_entity(graph: KnowledgeGraph, name: str) -> Union[Entity, None]:
    """Return the entity called *name* in *graph*, or ``None`` if absent.

    An exact match is preferred; otherwise the first case-insensitive match
    is returned, so callers need not know the stored capitalization.
    """
    exact = next((e for e in graph.entities if e.name == name), None)
    if exact is not None:
        return exact
    folded = name.lower()
    return next((e for e in graph.entities if e.name.lower() == folded), None)


def _relation_key(relation: Relation):
    """Identity of a relation: its (from, to, relationType) triple."""
    return (relation.from_, relation.to, relation.relationType)


# ----- Request Models -----


class CreateEntitiesRequest(BaseModel):
    entities: List[Entity] = Field(..., description="List of entities to create")


class CreateRelationsRequest(BaseModel):
    relations: List[Relation] = Field(
        ..., description="List of relations to create. All must be in active voice."
    )


class ObservationItem(BaseModel):
    entityName: str = Field(
        ..., description="The name of the entity to add the observations to"
    )
    contents: List[str] = Field(
        ..., description="An array of observation contents to add"
    )


class DeletionItem(BaseModel):
    entityName: str = Field(
        ..., description="The name of the entity containing the observations"
    )
    observations: List[str] = Field(
        ..., description="An array of observations to delete"
    )


class AddObservationsRequest(BaseModel):
    observations: List[ObservationItem] = Field(
        ...,
        description="A list of observation additions, each specifying an entity and contents to add",
    )


class DeleteObservationsRequest(BaseModel):
    deletions: List[DeletionItem] = Field(
        ...,
        description="A list of observation deletions, each specifying an entity and observations to remove",
    )


class DeleteEntitiesRequest(BaseModel):
    entityNames: List[str] = Field(
        ..., description="An array of entity names to delete"
    )


class DeleteRelationsRequest(BaseModel):
    relations: List[Relation] = Field(
        ..., description="An array of relations to delete"
    )


class SearchNodesRequest(BaseModel):
    query: str = Field(
        ...,
        description="The search query to match against entity names, types, and observation content",
    )


class OpenNodesRequest(BaseModel):
    names: List[str] = Field(..., description="An array of entity names to retrieve")


# ----- Endpoints -----


@app.post("/create_entities", summary="Create multiple entities in the graph")
def create_entities(req: CreateEntitiesRequest):
    """Create the given entities, skipping any whose name already exists.

    If a name appears more than once in the request, only its first
    occurrence is added. Returns the entities that were actually added.
    """
    with _GRAPH_LOCK:
        graph = read_graph_file()
        known_names = {e.name for e in graph.entities}
        new_entities = []
        for entity in req.entities:
            if entity.name in known_names:
                continue
            known_names.add(entity.name)
            new_entities.append(entity)
        graph.entities.extend(new_entities)
        save_graph(graph)
    return new_entities


@app.post("/create_relations", summary="Create multiple relations between entities")
def create_relations(req: CreateRelationsRequest):
    """Create the given relations, skipping exact duplicates.

    A relation is a duplicate when its (from, to, relationType) triple already
    exists in the graph or appeared earlier in the same request. Returns the
    relations that were actually added.
    """
    with _GRAPH_LOCK:
        graph = read_graph_file()
        known = {_relation_key(r) for r in graph.relations}
        new = []
        for relation in req.relations:
            key = _relation_key(relation)
            if key in known:
                continue
            known.add(key)
            new.append(relation)
        graph.relations.extend(new)
        save_graph(graph)
    return new


@app.post("/add_observations", summary="Add new observations to existing entities")
def add_observations(req: AddObservationsRequest):
    """Append observations to existing entities, skipping ones already present.

    Entity names match exactly, or case-insensitively if there is no exact
    match. Returns one ``{"entityName", "addedObservations"}`` record per
    request item. If any entity is missing, responds 404 and nothing is saved.
    """
    with _GRAPH_LOCK:
        graph = read_graph_file()
        results = []
        for obs in req.observations:
            entity = _find_entity(graph, obs.entityName)
            if entity is None:
                raise HTTPException(
                    status_code=404, detail=f"Entity {obs.entityName} not found"
                )
            present = set(entity.observations)
            added = []
            for content in obs.contents:
                if content in present:
                    continue
                # Track additions too, so repeats within one request are added once.
                present.add(content)
                added.append(content)
            entity.observations.extend(added)
            results.append({"entityName": entity.name, "addedObservations": added})
        save_graph(graph)
    return results


@app.post("/delete_entities", summary="Delete entities and associated relations")
def delete_entities(req: DeleteEntitiesRequest):
    """Delete the named entities and every relation that starts or ends at them.

    Names are matched exactly; names that do not exist are ignored.
    """
    doomed = set(req.entityNames)
    with _GRAPH_LOCK:
        graph = read_graph_file()
        graph.entities = [e for e in graph.entities if e.name not in doomed]
        graph.relations = [
            r
            for r in graph.relations
            if r.from_ not in doomed and r.to not in doomed
        ]
        save_graph(graph)
    return {"message": "Entities deleted successfully"}


@app.post("/delete_observations", summary="Delete specific observations from entities")
def delete_observations(req: DeleteObservationsRequest):
    """Remove the listed observations from the named entities.

    Entity names match exactly, or case-insensitively if there is no exact
    match. Unknown entities and observations are silently ignored.
    """
    with _GRAPH_LOCK:
        graph = read_graph_file()
        for deletion in req.deletions:
            entity = _find_entity(graph, deletion.entityName)
            if entity is None:
                continue
            to_delete = set(deletion.observations)
            entity.observations = [
                obs for obs in entity.observations if obs not in to_delete
            ]
        save_graph(graph)
    return {"message": "Observations deleted successfully"}


@app.post("/delete_relations", summary="Delete relations from the graph")
def delete_relations(req: DeleteRelationsRequest):
    """Delete relations matching the given (from, to, relationType) triples."""
    del_set = {_relation_key(r) for r in req.relations}
    with _GRAPH_LOCK:
        graph = read_graph_file()
        graph.relations = [
            r for r in graph.relations if _relation_key(r) not in del_set
        ]
        save_graph(graph)
    return {"message": "Relations deleted successfully"}


@app.get(
    "/read_graph", response_model=KnowledgeGraph, summary="Read entire knowledge graph"
)
def read_graph():
    """Return every entity and relation in the graph."""
    return _load_graph()


@app.post(
    "/search_nodes",
    response_model=KnowledgeGraph,
    summary="Search for nodes by keyword",
)
def search_nodes(req: SearchNodesRequest):
    """Return entities whose name, type or any observation contains the query.

    Matching is a case-insensitive substring test. Only relations whose both
    endpoints are among the matched entities are included.
    """
    needle = req.query.lower()
    graph = _load_graph()
    entities = [
        e
        for e in graph.entities
        if needle in e.name.lower()
        or needle in e.entityType.lower()
        or any(needle in o.lower() for o in e.observations)
    ]
    names = {e.name for e in entities}
    relations = [r for r in graph.relations if r.from_ in names and r.to in names]
    return KnowledgeGraph(entities=entities, relations=relations)


@app.post(
    "/open_nodes", response_model=KnowledgeGraph, summary="Open specific nodes by name"
)
def open_nodes(req: OpenNodesRequest):
    """Return the named entities plus the relations among them.

    Names are matched exactly; unknown names are ignored.
    """
    wanted = set(req.names)
    graph = _load_graph()
    entities = [e for e in graph.entities if e.name in wanted]
    names = {e.name for e in entities}
    relations = [r for r in graph.relations if r.from_ in names and r.to in names]
    return KnowledgeGraph(entities=entities, relations=relations)