Support Custom basemode and output schema

This commit is contained in:
JinY0ung-Shin 2025-04-28 14:28:49 +09:00 committed by GitHub
parent 4ad458f868
commit e392df0763
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,17 +4,15 @@ from contextlib import AsyncExitStack, asynccontextmanager
from typing import Optional
import uvicorn
from fastapi import FastAPI, Depends
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from starlette.routing import Mount
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client
from mcpo.utils.main import get_model_fields, get_tool_handler
from mcpo.utils.auth import get_verify_api_key
from mcpo.utils.main import get_model_fields, get_tool_handler
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
@ -37,21 +35,32 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
for tool in tools:
endpoint_name = tool.name
endpoint_description = tool.description
schema = tool.inputSchema
required_fields = schema.get("required", [])
properties = schema.get("properties", {})
inputSchema = tool.inputSchema
outputSchema = getattr(tool, "outputSchema", None)
custom_fileds = inputSchema.get("$defs", {})
required_fields = inputSchema.get("required", [])
properties = inputSchema.get("properties", {})
form_model_name = f"{endpoint_name}_form_model"
model_fields = get_model_fields(
form_model_name,
properties,
required_fields,
custom_fileds,
)
if outputSchema:
output_model_name = f"{endpoint_name}_output_model"
output_model_fields = get_model_fields(
output_model_name,
outputSchema.get("properties", {}),
outputSchema.get("required", []),
outputSchema.get("$defs", {}),
)
else:
output_model_fields = None
tool_handler = get_tool_handler(
session, endpoint_name, form_model_name, model_fields
session, endpoint_name, form_model_name, model_fields, output_model_fields
)
app.post(
@ -98,7 +107,10 @@ async def lifespan(app: FastAPI):
await create_dynamic_endpoints(app, api_dependency=api_dependency)
yield
if server_type == "sse":
async with sse_client(url=args[0], sse_read_timeout=None) as (reader, writer):
async with sse_client(url=args[0], sse_read_timeout=None) as (
reader,
writer,
):
async with ClientSession(reader, writer) as session:
app.state.session = session
await create_dynamic_endpoints(app, api_dependency=api_dependency)
@ -132,7 +144,6 @@ async def run(
ssl_certfile = kwargs.get("ssl_certfile")
ssl_keyfile = kwargs.get("ssl_keyfile")
path_prefix = kwargs.get("path_prefix") or "/"
main_app = FastAPI(
title=name,
description=description,