mirror of
https://github.com/open-webui/mcpo
synced 2025-06-26 18:26:58 +00:00
Support Custom basemode and output schema
This commit is contained in:
parent
4ad458f868
commit
e392df0763
@ -4,17 +4,15 @@ from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Depends
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from starlette.routing import Mount
|
||||
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from mcpo.utils.main import get_model_fields, get_tool_handler
|
||||
from mcpo.utils.auth import get_verify_api_key
|
||||
from mcpo.utils.main import get_model_fields, get_tool_handler
|
||||
|
||||
|
||||
async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
|
||||
@ -37,21 +35,32 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
|
||||
for tool in tools:
|
||||
endpoint_name = tool.name
|
||||
endpoint_description = tool.description
|
||||
schema = tool.inputSchema
|
||||
|
||||
required_fields = schema.get("required", [])
|
||||
properties = schema.get("properties", {})
|
||||
inputSchema = tool.inputSchema
|
||||
outputSchema = getattr(tool, "outputSchema", None)
|
||||
|
||||
custom_fileds = inputSchema.get("$defs", {})
|
||||
required_fields = inputSchema.get("required", [])
|
||||
properties = inputSchema.get("properties", {})
|
||||
form_model_name = f"{endpoint_name}_form_model"
|
||||
|
||||
model_fields = get_model_fields(
|
||||
form_model_name,
|
||||
properties,
|
||||
required_fields,
|
||||
custom_fileds,
|
||||
)
|
||||
if outputSchema:
|
||||
output_model_name = f"{endpoint_name}_output_model"
|
||||
output_model_fields = get_model_fields(
|
||||
output_model_name,
|
||||
outputSchema.get("properties", {}),
|
||||
outputSchema.get("required", []),
|
||||
outputSchema.get("$defs", {}),
|
||||
)
|
||||
else:
|
||||
output_model_fields = None
|
||||
|
||||
tool_handler = get_tool_handler(
|
||||
session, endpoint_name, form_model_name, model_fields
|
||||
session, endpoint_name, form_model_name, model_fields, output_model_fields
|
||||
)
|
||||
|
||||
app.post(
|
||||
@ -98,7 +107,10 @@ async def lifespan(app: FastAPI):
|
||||
await create_dynamic_endpoints(app, api_dependency=api_dependency)
|
||||
yield
|
||||
if server_type == "sse":
|
||||
async with sse_client(url=args[0], sse_read_timeout=None) as (reader, writer):
|
||||
async with sse_client(url=args[0], sse_read_timeout=None) as (
|
||||
reader,
|
||||
writer,
|
||||
):
|
||||
async with ClientSession(reader, writer) as session:
|
||||
app.state.session = session
|
||||
await create_dynamic_endpoints(app, api_dependency=api_dependency)
|
||||
@ -132,7 +144,6 @@ async def run(
|
||||
ssl_certfile = kwargs.get("ssl_certfile")
|
||||
ssl_keyfile = kwargs.get("ssl_keyfile")
|
||||
path_prefix = kwargs.get("path_prefix") or "/"
|
||||
|
||||
main_app = FastAPI(
|
||||
title=name,
|
||||
description=description,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user