# Mirror of https://github.com/open-webui/openapi-servers
# (synced 2025-06-26 18:17:04 +00:00; 298 lines, 8.9 KiB, Python)
from fastapi import FastAPI, HTTPException, Body
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
|
from pydantic import BaseModel, Field
|
|
from typing import List, Literal, Union
|
|
from pathlib import Path
|
|
import json
|
|
import os
|
|
|
|
app = FastAPI(
    version="1.0.0",
    title="Knowledge Graph Server",
    description=(
        "A structured knowledge graph memory system that supports entity and "
        "relation storage, observation tracking, and manipulation."
    ),
)

# Browser clients from any origin may call this API.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended before exposing the server publicly.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=origins,
    allow_credentials=True,
)
|
|
|
|
|
|
# ----- Persistence Setup -----

# Location of the JSON-lines memory file. An absolute MEMORY_FILE_PATH
# environment value is used as-is; a relative one is resolved against this
# module's directory (not the process working directory).
MEMORY_FILE_PATH_ENV = os.getenv("MEMORY_FILE_PATH", "memory.json")
if Path(MEMORY_FILE_PATH_ENV).is_absolute():
    MEMORY_FILE_PATH = Path(MEMORY_FILE_PATH_ENV)
else:
    MEMORY_FILE_PATH = Path(__file__).parent / MEMORY_FILE_PATH_ENV
|
|
|
|
|
|
# ----- Data Models -----


class Entity(BaseModel):
    """A node in the knowledge graph, identified by its name."""

    name: str = Field(default=..., description="The name of the entity")
    entityType: str = Field(default=..., description="The type of the entity")
    observations: List[str] = Field(
        default=...,
        description="An array of observation contents associated with the entity",
    )
|
|
|
|
|
|
class Relation(BaseModel):
    """A directed, typed edge between two entities, referenced by name."""

    # "from" is a Python keyword, so the attribute is `from_` and the JSON
    # field name is supplied through the alias.
    from_: str = Field(
        default=...,
        alias="from",
        description="The name of the entity where the relation starts",
    )
    to: str = Field(
        default=..., description="The name of the entity where the relation ends"
    )
    relationType: str = Field(default=..., description="The type of the relation")
|
|
|
|
|
|
class KnowledgeGraph(BaseModel):
    """A collection of entities together with relations between them."""

    entities: List[Entity] = Field(default=...)
    relations: List[Relation] = Field(default=...)
|
|
|
|
|
|
class EntityWrapper(BaseModel):
    """Shape of one entity record in the memory file: the Entity fields plus a
    "type" discriminator fixed to "entity".

    NOTE(review): read_graph_file/save_graph do not use this model; confirm
    whether anything else still depends on it.
    """

    type: Literal["entity"] = Field(default=...)
    name: str = Field(default=...)
    entityType: str = Field(default=...)
    observations: List[str] = Field(default=...)
|
|
|
|
|
|
class RelationWrapper(BaseModel):
    """Shape of one relation record in the memory file: the Relation fields
    plus a "type" discriminator fixed to "relation".

    NOTE(review): read_graph_file/save_graph do not use this model; confirm
    whether anything else still depends on it.
    """

    type: Literal["relation"] = Field(default=...)
    # Stored under the JSON key "from" (a Python keyword), hence the alias.
    from_: str = Field(default=..., alias="from")
    to: str = Field(default=...)
    relationType: str = Field(default=...)
|
|
|
|
|
|
# ----- I/O Handlers -----


def read_graph_file() -> KnowledgeGraph:
    """Load the knowledge graph from the JSON-lines file at MEMORY_FILE_PATH.

    Each non-blank line must be a JSON object with a "type" key. Records of
    type "entity" become Entity objects. Records of type "relation" become
    Relation objects, with the "from" key mapped through the model alias.
    Records of any other type are ignored.

    Returns:
        The loaded graph, or an empty graph if the file does not exist yet.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record lacks "type" or, for entities, a required field.
        pydantic validation error: if a record's fields fail model validation.
    """
    if not MEMORY_FILE_PATH.exists():
        return KnowledgeGraph(entities=[], relations=[])

    entities: List[Entity] = []
    relations: List[Relation] = []
    with open(MEMORY_FILE_PATH, "r", encoding="utf-8") as f:
        # Stream line by line; no debug echo of file contents to stdout.
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines
            item = json.loads(line)
            if item["type"] == "entity":
                entities.append(
                    Entity(
                        name=item["name"],
                        entityType=item["entityType"],
                        observations=item["observations"],
                    )
                )
            elif item["type"] == "relation":
                relations.append(Relation(**item))

    return KnowledgeGraph(entities=entities, relations=relations)
|
|
|
|
|
|
def save_graph(graph: KnowledgeGraph):
    """Persist graph to MEMORY_FILE_PATH in the JSON-lines format that
    read_graph_file reads back.

    Entities are written first, then relations. Each record carries a "type"
    key ("entity" or "relation"), and relations are serialized with the "from"
    alias. The content goes to a temporary file in the same directory and is
    then swapped in with os.replace. This means a crash or error mid-write
    cannot leave a truncated memory file behind. The parent directory is
    created if it does not exist.
    """
    import threading  # local import: only needed to build a unique temp name

    lines = [json.dumps({"type": "entity", **e.dict()}) for e in graph.entities]
    lines += [
        json.dumps({"type": "relation", **r.dict(by_alias=True)})
        for r in graph.relations
    ]

    # Resolve symlinks so a linked memory file is updated at its real location
    # instead of being replaced by a regular file.
    target = MEMORY_FILE_PATH.resolve()
    target.parent.mkdir(parents=True, exist_ok=True)
    # Unique per process/thread so concurrent saves never share a temp file.
    tmp_path = target.with_name(
        f".{target.name}.{os.getpid()}.{threading.get_ident()}.tmp"
    )
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))
        # Same-directory rename; atomic on POSIX per the os.replace docs.
        os.replace(tmp_path, target)
    finally:
        # Only still present if something failed before the replace.
        if tmp_path.exists():
            tmp_path.unlink()
|
|
|
|
|
|
# ----- Request Models -----


class CreateEntitiesRequest(BaseModel):
    """Request body for /create_entities."""

    entities: List[Entity] = Field(
        default=...,
        description="List of entities to create",
    )
|
|
|
|
|
|
class CreateRelationsRequest(BaseModel):
    """Request body for /create_relations."""

    relations: List[Relation] = Field(
        default=...,
        description="List of relations to create. All must be in active voice.",
    )
|
|
|
|
|
|
class ObservationItem(BaseModel):
    """One entity plus the observation strings to append to it."""

    entityName: str = Field(
        default=...,
        description="The name of the entity to add the observations to",
    )
    contents: List[str] = Field(
        default=...,
        description="An array of observation contents to add",
    )
|
|
|
|
|
|
class DeletionItem(BaseModel):
    """One entity plus the observation strings to remove from it."""

    entityName: str = Field(
        default=...,
        description="The name of the entity containing the observations",
    )
    observations: List[str] = Field(
        default=...,
        description="An array of observations to delete",
    )
|
|
|
|
|
|
class AddObservationsRequest(BaseModel):
    """Request body for /add_observations."""

    observations: List[ObservationItem] = Field(
        default=...,
        description="A list of observation additions, each specifying an entity and contents to add",
    )
|
|
|
|
|
|
class DeleteObservationsRequest(BaseModel):
    """Request body for /delete_observations."""

    deletions: List[DeletionItem] = Field(
        default=...,
        description="A list of observation deletions, each specifying an entity and observations to remove",
    )
|
|
|
|
|
|
class DeleteEntitiesRequest(BaseModel):
    """Request body for /delete_entities."""

    entityNames: List[str] = Field(
        default=...,
        description="An array of entity names to delete",
    )
|
|
|
|
|
|
class DeleteRelationsRequest(BaseModel):
    """Request body for /delete_relations."""

    relations: List[Relation] = Field(
        default=...,
        description="An array of relations to delete",
    )
|
|
|
|
|
|
class SearchNodesRequest(BaseModel):
    """Request body for /search_nodes."""

    query: str = Field(
        default=...,
        description="The search query to match against entity names, types, and observation content",
    )
|
|
|
|
|
|
class OpenNodesRequest(BaseModel):
    """Request body for /open_nodes."""

    names: List[str] = Field(
        default=...,
        description="An array of entity names to retrieve",
    )
|
|
|
|
|
|
# ----- Endpoints -----


@app.post("/create_entities", summary="Create multiple entities in the graph")
def create_entities(req: CreateEntitiesRequest):
    """Add entities whose names are not already present in the graph.

    An entity is skipped if its name already exists, either in the stored
    graph or earlier in the same request. Returns the entities that were
    actually added.
    """
    graph = read_graph_file()
    # Track names as we go so duplicates inside one request are also skipped.
    seen_names = {e.name for e in graph.entities}
    new_entities = []
    for entity in req.entities:
        if entity.name in seen_names:
            continue
        seen_names.add(entity.name)
        new_entities.append(entity)

    graph.entities.extend(new_entities)
    save_graph(graph)
    return new_entities
|
|
|
|
|
|
@app.post("/create_relations", summary="Create multiple relations between entities")
def create_relations(req: CreateRelationsRequest):
    """Add relations that are not already stored.

    A relation is identified by its (from, to, relationType) triple. A triple
    is skipped if it is already in the graph or repeated within the same
    request. The named entities are not required to exist. Returns the
    relations that were actually added.
    """
    graph = read_graph_file()
    # Track keys as we go so duplicates inside one request are also skipped.
    seen = {(r.from_, r.to, r.relationType) for r in graph.relations}
    added = []
    for rel in req.relations:
        key = (rel.from_, rel.to, rel.relationType)
        if key in seen:
            continue
        seen.add(key)
        added.append(rel)

    graph.relations.extend(added)
    save_graph(graph)
    return added
|
|
|
|
|
|
@app.post("/add_observations", summary="Add new observations to existing entities")
def add_observations(req: AddObservationsRequest):
    """Append observations to existing entities.

    Entity names are matched exactly first, then case-insensitively, so
    callers that send lower-cased names keep working. An observation is
    skipped if the entity already has it or it repeats within the request.
    If any named entity does not exist, a 404 is returned and nothing is
    saved. Returns, per item, the entity's stored name and the observations
    that were actually added.
    """
    graph = read_graph_file()
    results = []

    for obs in req.observations:
        requested = obs.entityName
        # The old code lower-cased only the request side, so entities stored
        # with any uppercase letter could never be found.
        entity = next((e for e in graph.entities if e.name == requested), None)
        if entity is None:
            folded = requested.lower()
            entity = next(
                (e for e in graph.entities if e.name.lower() == folded), None
            )
        if entity is None:
            raise HTTPException(status_code=404, detail=f"Entity {requested} not found")

        existing = set(entity.observations)
        added = []
        for content in obs.contents:
            if content not in existing:
                existing.add(content)
                added.append(content)
        entity.observations.extend(added)
        results.append({"entityName": entity.name, "addedObservations": added})

    # Saved only after every item succeeded; a 404 above discards all changes.
    save_graph(graph)
    return results
|
|
|
|
|
|
@app.post("/delete_entities", summary="Delete entities and associated relations")
def delete_entities(req: DeleteEntitiesRequest):
    """Remove the named entities and every relation that touches them.

    Matching is exact, and names with no matching entity are ignored.
    """
    # Set membership instead of scanning the request list for every item.
    doomed = set(req.entityNames)
    graph = read_graph_file()
    graph.entities = [e for e in graph.entities if e.name not in doomed]
    # Drop relations with either end deleted so no dangling edges remain.
    graph.relations = [
        r for r in graph.relations if r.from_ not in doomed and r.to not in doomed
    ]
    save_graph(graph)
    return {"message": "Entities deleted successfully"}
|
|
|
|
|
|
@app.post("/delete_observations", summary="Delete specific observations from entities")
def delete_observations(req: DeleteObservationsRequest):
    """Remove the given observation strings from the named entities.

    Entity names are matched exactly first, then case-insensitively. Entities
    that cannot be found, and observations the entity does not have, are
    silently ignored.
    """
    graph = read_graph_file()

    for deletion in req.deletions:
        requested = deletion.entityName
        # The old code lower-cased only the request side, so deletions on
        # entities stored with any uppercase letter silently did nothing.
        entity = next((e for e in graph.entities if e.name == requested), None)
        if entity is None:
            folded = requested.lower()
            entity = next(
                (e for e in graph.entities if e.name.lower() == folded), None
            )
        if entity is None:
            continue

        to_delete = set(deletion.observations)
        entity.observations = [o for o in entity.observations if o not in to_delete]

    save_graph(graph)
    return {"message": "Observations deleted successfully"}
|
|
|
|
|
|
@app.post("/delete_relations", summary="Delete relations from the graph")
def delete_relations(req: DeleteRelationsRequest):
    """Remove every stored relation that matches one of the given relations.

    Relations are compared by their (from, to, relationType) triple, and
    requested relations that are not stored are ignored.
    """
    graph = read_graph_file()

    def triple(rel):
        return rel.from_, rel.to, rel.relationType

    unwanted = set(map(triple, req.relations))
    graph.relations = [rel for rel in graph.relations if triple(rel) not in unwanted]
    save_graph(graph)
    return {"message": "Relations deleted successfully"}
|
|
|
|
|
|
@app.get("/read_graph", response_model=KnowledgeGraph, summary="Read entire knowledge graph")
def read_graph():
    """Return every stored entity and relation."""
    graph = read_graph_file()
    return graph
|
|
|
|
|
|
@app.post(
    "/search_nodes",
    response_model=KnowledgeGraph,
    summary="Search for nodes by keyword",
)
def search_nodes(req: SearchNodesRequest):
    """Find entities whose name, type, or any observation contains the query.

    Matching is a case-insensitive substring test, so an empty query matches
    every entity. The result holds the matching entities and only those
    relations whose two ends are both among them.
    """
    graph = read_graph_file()
    needle = req.query.lower()  # fold once rather than per comparison

    entities = [
        e
        for e in graph.entities
        if needle in e.name.lower()
        or needle in e.entityType.lower()
        or any(needle in o.lower() for o in e.observations)
    ]
    names = {e.name for e in entities}
    relations = [r for r in graph.relations if r.from_ in names and r.to in names]
    return KnowledgeGraph(entities=entities, relations=relations)
|
|
|
|
|
|
@app.post("/open_nodes", response_model=KnowledgeGraph, summary="Open specific nodes by name")
def open_nodes(req: OpenNodesRequest):
    """Return the named entities plus the relations that connect them.

    Names are matched exactly, and unknown names are ignored. A relation is
    included only when both of its ends are among the returned entities.
    """
    graph = read_graph_file()
    wanted = set(req.names)
    selected = [entity for entity in graph.entities if entity.name in wanted]
    selected_names = {entity.name for entity in selected}
    links = [
        rel
        for rel in graph.relations
        if rel.from_ in selected_names and rel.to in selected_names
    ]
    return KnowledgeGraph(entities=selected, relations=links)
|