mirror of
https://github.com/open-webui/openapi-servers
synced 2025-06-26 18:17:04 +00:00
feat: memory server
This commit is contained in:
parent
ff8de67527
commit
3117338af3
275
servers/memory/main.py
Normal file
275
servers/memory/main.py
Normal file
@ -0,0 +1,275 @@
|
||||
from fastapi import FastAPI, HTTPException, Body
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Literal, Union
|
||||
from pathlib import Path
|
||||
import json
|
||||
import os
|
||||
|
||||
app = FastAPI(
|
||||
title="Knowledge Graph Server",
|
||||
version="1.0.0",
|
||||
description="A structured knowledge graph memory system that supports entity and relation storage, observation tracking, and manipulation.",
|
||||
)
|
||||
|
||||
origins = ["*"]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# ----- Persistence Setup -----
|
||||
MEMORY_FILE_PATH_ENV = os.getenv("MEMORY_FILE_PATH", "memory.json")
|
||||
MEMORY_FILE_PATH = Path(
|
||||
MEMORY_FILE_PATH_ENV
|
||||
if Path(MEMORY_FILE_PATH_ENV).is_absolute()
|
||||
else Path(__file__).parent / MEMORY_FILE_PATH_ENV
|
||||
)
|
||||
|
||||
|
||||
# ----- Data Models -----
|
||||
class Entity(BaseModel):
|
||||
name: str = Field(..., description="The name of the entity")
|
||||
entityType: str = Field(..., description="The type of the entity")
|
||||
observations: List[str] = Field(
|
||||
..., description="An array of observation contents associated with the entity"
|
||||
)
|
||||
|
||||
|
||||
class Relation(BaseModel):
|
||||
from_: str = Field(
|
||||
...,
|
||||
alias="from",
|
||||
description="The name of the entity where the relation starts",
|
||||
)
|
||||
to: str = Field(..., description="The name of the entity where the relation ends")
|
||||
relationType: str = Field(..., description="The type of the relation")
|
||||
|
||||
|
||||
class KnowledgeGraph(BaseModel):
|
||||
entities: List[Entity]
|
||||
relations: List[Relation]
|
||||
|
||||
|
||||
class EntityWrapper(BaseModel):
|
||||
type: Literal["entity"]
|
||||
name: str
|
||||
entityType: str
|
||||
observations: List[str]
|
||||
|
||||
|
||||
class RelationWrapper(BaseModel):
|
||||
type: Literal["relation"]
|
||||
from_: str = Field(..., alias="from")
|
||||
to: str
|
||||
relationType: str
|
||||
|
||||
|
||||
# ----- I/O Handlers -----
|
||||
def read_graph_file() -> KnowledgeGraph:
|
||||
if not MEMORY_FILE_PATH.exists():
|
||||
return KnowledgeGraph(entities=[], relations=[])
|
||||
with open(MEMORY_FILE_PATH, "r", encoding="utf-8") as f:
|
||||
lines = [line for line in f if line.strip()]
|
||||
entities = []
|
||||
relations = []
|
||||
for line in lines:
|
||||
print(line)
|
||||
item = json.loads(line)
|
||||
if item["type"] == "entity":
|
||||
entities.append(
|
||||
Entity(
|
||||
name=item["name"],
|
||||
entityType=item["entityType"],
|
||||
observations=item["observations"],
|
||||
)
|
||||
)
|
||||
elif item["type"] == "relation":
|
||||
relations.append(Relation(**item))
|
||||
|
||||
return KnowledgeGraph(entities=entities, relations=relations)
|
||||
|
||||
|
||||
def save_graph(graph: KnowledgeGraph):
|
||||
lines = [json.dumps({"type": "entity", **e.dict()}) for e in graph.entities] + [
|
||||
json.dumps({"type": "relation", **r.dict(by_alias=True)})
|
||||
for r in graph.relations
|
||||
]
|
||||
with open(MEMORY_FILE_PATH, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(lines))
|
||||
|
||||
|
||||
# ----- Request Models -----
|
||||
|
||||
|
||||
class CreateEntitiesRequest(BaseModel):
|
||||
entities: List[Entity] = Field(..., description="List of entities to create")
|
||||
|
||||
|
||||
class CreateRelationsRequest(BaseModel):
|
||||
relations: List[Relation] = Field(
|
||||
..., description="List of relations to create. All must be in active voice."
|
||||
)
|
||||
|
||||
|
||||
class AddObservationsRequest(BaseModel):
|
||||
observations: List[dict] = Field(
|
||||
...,
|
||||
description="Each item includes an entity name and an array of observations to add",
|
||||
)
|
||||
|
||||
|
||||
class DeleteEntitiesRequest(BaseModel):
|
||||
entityNames: List[str] = Field(
|
||||
..., description="An array of entity names to delete"
|
||||
)
|
||||
|
||||
|
||||
class DeleteObservationsRequest(BaseModel):
|
||||
deletions: List[dict] = Field(
|
||||
...,
|
||||
description="Each item includes an entity name and an array of observations to delete",
|
||||
)
|
||||
|
||||
|
||||
class DeleteRelationsRequest(BaseModel):
|
||||
relations: List[Relation] = Field(
|
||||
..., description="An array of relations to delete"
|
||||
)
|
||||
|
||||
|
||||
class SearchNodesRequest(BaseModel):
|
||||
query: str = Field(
|
||||
...,
|
||||
description="The search query to match against entity names, types, and observation content",
|
||||
)
|
||||
|
||||
|
||||
class OpenNodesRequest(BaseModel):
|
||||
names: List[str] = Field(..., description="An array of entity names to retrieve")
|
||||
|
||||
|
||||
# ----- Endpoints -----
|
||||
|
||||
|
||||
@app.post("/create_entities", summary="Create multiple entities in the graph")
|
||||
def create_entities(req: CreateEntitiesRequest):
|
||||
graph = read_graph_file()
|
||||
existing_names = {e.name for e in graph.entities}
|
||||
new_entities = [e for e in req.entities if e.name not in existing_names]
|
||||
graph.entities.extend(new_entities)
|
||||
save_graph(graph)
|
||||
return new_entities
|
||||
|
||||
|
||||
@app.post("/create_relations", summary="Create multiple relations between entities")
|
||||
def create_relations(req: CreateRelationsRequest):
|
||||
graph = read_graph_file()
|
||||
existing = {(r.from_, r.to, r.relationType) for r in graph.relations}
|
||||
new = [r for r in req.relations if (r.from_, r.to, r.relationType) not in existing]
|
||||
graph.relations.extend(new)
|
||||
save_graph(graph)
|
||||
return new
|
||||
|
||||
|
||||
@app.post("/add_observations", summary="Add new observations to existing entities")
|
||||
def add_observations(req: AddObservationsRequest):
|
||||
graph = read_graph_file()
|
||||
results = []
|
||||
for obs in req.observations:
|
||||
name = obs["entityName"]
|
||||
contents = obs["contents"]
|
||||
entity = next((e for e in graph.entities if e.name == name), None)
|
||||
if not entity:
|
||||
raise HTTPException(status_code=404, detail=f"Entity {name} not found")
|
||||
added = [c for c in contents if c not in entity.observations]
|
||||
entity.observations.extend(added)
|
||||
results.append({"entityName": name, "addedObservations": added})
|
||||
save_graph(graph)
|
||||
return results
|
||||
|
||||
|
||||
@app.post("/delete_entities", summary="Delete entities and associated relations")
|
||||
def delete_entities(req: DeleteEntitiesRequest):
|
||||
graph = read_graph_file()
|
||||
graph.entities = [e for e in graph.entities if e.name not in req.entityNames]
|
||||
graph.relations = [
|
||||
r
|
||||
for r in graph.relations
|
||||
if r.from_ not in req.entityNames and r.to not in req.entityNames
|
||||
]
|
||||
save_graph(graph)
|
||||
return {"message": "Entities deleted successfully"}
|
||||
|
||||
|
||||
@app.post("/delete_observations", summary="Delete specific observations from entities")
|
||||
def delete_observations(req: DeleteObservationsRequest):
|
||||
graph = read_graph_file()
|
||||
for deletion in req.deletions:
|
||||
name = deletion["entityName"]
|
||||
obs_to_delete = deletion["observations"]
|
||||
entity = next((e for e in graph.entities if e.name == name), None)
|
||||
if entity:
|
||||
entity.observations = [
|
||||
o for o in entity.observations if o not in obs_to_delete
|
||||
]
|
||||
save_graph(graph)
|
||||
return {"message": "Observations deleted successfully"}
|
||||
|
||||
|
||||
@app.post("/delete_relations", summary="Delete relations from the graph")
|
||||
def delete_relations(req: DeleteRelationsRequest):
|
||||
graph = read_graph_file()
|
||||
del_set = {(r.from_, r.to, r.relationType) for r in req.relations}
|
||||
graph.relations = [
|
||||
r for r in graph.relations if (r.from_, r.to, r.relationType) not in del_set
|
||||
]
|
||||
save_graph(graph)
|
||||
return {"message": "Relations deleted successfully"}
|
||||
|
||||
|
||||
@app.get(
|
||||
"/read_graph", response_model=KnowledgeGraph, summary="Read entire knowledge graph"
|
||||
)
|
||||
def read_graph():
|
||||
return read_graph_file()
|
||||
|
||||
|
||||
@app.post(
|
||||
"/search_nodes",
|
||||
response_model=KnowledgeGraph,
|
||||
summary="Search for nodes by keyword",
|
||||
)
|
||||
def search_nodes(req: SearchNodesRequest):
|
||||
graph = read_graph_file()
|
||||
print(graph)
|
||||
entities = [
|
||||
e
|
||||
for e in graph.entities
|
||||
if req.query.lower() in e.name.lower()
|
||||
or req.query.lower() in e.entityType.lower()
|
||||
or any(req.query.lower() in o.lower() for o in e.observations)
|
||||
]
|
||||
names = {e.name for e in entities}
|
||||
relations = [r for r in graph.relations if r.from_ in names and r.to in names]
|
||||
|
||||
print(names, relations)
|
||||
return KnowledgeGraph(entities=entities, relations=relations)
|
||||
|
||||
|
||||
@app.post(
|
||||
"/open_nodes", response_model=KnowledgeGraph, summary="Open specific nodes by name"
|
||||
)
|
||||
def open_nodes(req: OpenNodesRequest):
|
||||
graph = read_graph_file()
|
||||
entities = [e for e in graph.entities if e.name in req.names]
|
||||
names = {e.name for e in entities}
|
||||
relations = [r for r in graph.relations if r.from_ in names and r.to in names]
|
||||
return KnowledgeGraph(entities=entities, relations=relations)
|
7
servers/memory/requirements.txt
Normal file
7
servers/memory/requirements.txt
Normal file
@ -0,0 +1,7 @@
|
||||
fastapi
|
||||
uvicorn[standard]
|
||||
pydantic
|
||||
python-multipart
|
||||
|
||||
pytz
|
||||
python-dateutil
|
Loading…
Reference in New Issue
Block a user