mirror of
https://github.com/open-webui/openapi-servers
synced 2025-06-26 18:17:04 +00:00
feat: mcp proxy
This commit is contained in:
parent
50f1a7eb45
commit
72681052cd
@ -0,0 +1,160 @@
|
||||
from fastapi import FastAPI, Body
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import create_model
|
||||
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters, types
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from typing import Dict, Any
|
||||
|
||||
import asyncio
|
||||
import uvicorn
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
async def create_dynamic_endpoints(app: FastAPI, session: ClientSession):
|
||||
tools_result = await session.list_tools()
|
||||
tools = tools_result.tools
|
||||
|
||||
for tool in tools:
|
||||
print(tool)
|
||||
endpoint_name = tool.name
|
||||
endpoint_description = tool.description
|
||||
schema = tool.inputSchema
|
||||
|
||||
# Dynamically creating a Pydantic model for validation and openAPI coverage
|
||||
model_fields = {}
|
||||
required_fields = schema.get("required", [])
|
||||
|
||||
for param_name, param_schema in schema["properties"].items():
|
||||
param_type = param_schema["type"]
|
||||
param_desc = param_schema.get("description", "")
|
||||
python_type = str # default
|
||||
|
||||
if param_type == "string":
|
||||
python_type = str
|
||||
elif param_type == "integer":
|
||||
python_type = int
|
||||
elif param_type == "boolean":
|
||||
python_type = bool
|
||||
elif param_type == "number":
|
||||
python_type = float
|
||||
elif param_type == "object":
|
||||
python_type = Dict[str, Any]
|
||||
elif param_type == "array":
|
||||
python_type = list
|
||||
# Expand as needed. PRs welcome!
|
||||
|
||||
default_value = ... if param_name in required_fields else None
|
||||
model_fields[param_name] = (
|
||||
python_type,
|
||||
Body(default_value, description=param_desc),
|
||||
)
|
||||
|
||||
FormModel = create_model(f"{endpoint_name}_form_model", **model_fields)
|
||||
|
||||
def make_endpoint_func(endpoint_name: str, FormModel):
|
||||
async def endpoint_func(form_data: FormModel):
|
||||
args = form_data.model_dump()
|
||||
print("Calling tool with arguments:", args)
|
||||
print("Tool name:", endpoint_name)
|
||||
result = await session.call_tool(endpoint_name, arguments=args)
|
||||
return result
|
||||
|
||||
return endpoint_func
|
||||
|
||||
endpoint_func = make_endpoint_func(endpoint_name, FormModel)
|
||||
|
||||
# Add endpoint to FastAPI with tool descriptions
|
||||
app.post(
|
||||
f"/{endpoint_name}",
|
||||
summary=endpoint_name.replace("_", " ").title(),
|
||||
description=endpoint_description,
|
||||
)(endpoint_func)
|
||||
|
||||
|
||||
async def run(host: str, port: int, server_cmd: list[str]):
|
||||
server_params = StdioServerParameters(
|
||||
command=server_cmd[0],
|
||||
args=server_cmd[1:],
|
||||
env={**os.environ},
|
||||
)
|
||||
|
||||
# Open connection to MCP first:
|
||||
async with stdio_client(server_params) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
result = await session.initialize()
|
||||
|
||||
server_name = (
|
||||
result.serverInfo.name
|
||||
if hasattr(result, "serverInfo") and hasattr(result.serverInfo, "name")
|
||||
else None
|
||||
)
|
||||
|
||||
server_description = (
|
||||
f"{server_name.capitalize()} MCP OpenAPI Proxy"
|
||||
if server_name
|
||||
else "Automatically generated API endpoints based on MCP tool schemas."
|
||||
)
|
||||
|
||||
server_version = (
|
||||
result.serverInfo.version
|
||||
if hasattr(result, "serverInfo")
|
||||
and hasattr(result.serverInfo, "version")
|
||||
else "1.0"
|
||||
)
|
||||
|
||||
app = FastAPI(
|
||||
title=server_name if server_name else "MCP OpenAPI Proxy",
|
||||
description=server_description,
|
||||
version=server_version,
|
||||
)
|
||||
|
||||
origins = ["*"]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Dynamic endpoint creation
|
||||
await create_dynamic_endpoints(app, session)
|
||||
|
||||
config = uvicorn.Config(app=app, host=host, port=port, log_level="info")
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
|
||||
def parse_args():
|
||||
# Separate user args before and after "--"
|
||||
if "--" not in sys.argv:
|
||||
print("Usage: python main.py --host 0.0.0.0 --port 8000 -- your_mcp_command")
|
||||
sys.exit(1)
|
||||
|
||||
split_index = sys.argv.index("--")
|
||||
proxy_args = sys.argv[1:split_index]
|
||||
mcp_args = sys.argv[split_index + 1 :]
|
||||
|
||||
parser = argparse.ArgumentParser(description="FastAPI MCP OpenAPI Proxy")
|
||||
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to listen on")
|
||||
parser.add_argument("--port", type=int, default=8000, help="Port to listen on")
|
||||
|
||||
args = parser.parse_args(proxy_args)
|
||||
|
||||
if not mcp_args:
|
||||
print("Error: You must specify the MCP server command after '--'")
|
||||
sys.exit(1)
|
||||
|
||||
return args.host, args.port, mcp_args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
host, port, server_cmd = parse_args()
|
||||
asyncio.run(run(host, port, server_cmd))
|
Loading…
Reference in New Issue
Block a user