refac: substandard code
Some checks are pending
Create and publish Docker images with specific build args / build-main-image (linux/amd64) (push) Waiting to run
Create and publish Docker images with specific build args / build-main-image (linux/arm64) (push) Waiting to run
Create and publish Docker images with specific build args / merge-main-images (push) Blocked by required conditions

This commit is contained in:
Timothy Jaeryang Baek 2025-04-30 13:55:20 +04:00
parent 1566a34c9c
commit fe95c6a6fa
2 changed files with 34 additions and 27 deletions

View File

@ -36,32 +36,34 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
for tool in tools: for tool in tools:
endpoint_name = tool.name endpoint_name = tool.name
endpoint_description = tool.description endpoint_description = tool.description
inputSchema = tool.inputSchema inputSchema = tool.inputSchema
outputSchema = getattr(tool, "outputSchema", None) outputSchema = getattr(tool, "outputSchema", None)
custom_fileds = inputSchema.get("$defs", {})
required_fields = inputSchema.get("required", [])
properties = inputSchema.get("properties", {})
form_model_name = f"{endpoint_name}_form_model" form_model_name = f"{endpoint_name}_form_model"
model_fields = get_model_fields( form_model_fields = get_model_fields(
form_model_name, form_model_name,
properties, inputSchema.get("properties", {}),
required_fields, inputSchema.get("required", []),
custom_fileds, inputSchema.get("$defs", {}),
) )
response_model_fields = None
if outputSchema: if outputSchema:
output_model_name = f"{endpoint_name}_output_model" response_model_name = f"{endpoint_name}_response_model"
output_model_fields = get_model_fields( response_model_fields = get_model_fields(
output_model_name, response_model_name,
outputSchema.get("properties", {}), outputSchema.get("properties", {}),
outputSchema.get("required", []), outputSchema.get("required", []),
outputSchema.get("$defs", {}), outputSchema.get("$defs", {}),
) )
else:
output_model_fields = None
tool_handler = get_tool_handler( tool_handler = get_tool_handler(
session, endpoint_name, form_model_name, model_fields, output_model_fields session,
endpoint_name,
form_model_name,
form_model_fields,
response_model_fields,
) )
app.post( app.post(

View File

@ -54,7 +54,7 @@ def _process_schema_property(
model_name_prefix: str, model_name_prefix: str,
prop_name: str, prop_name: str,
is_required: bool, is_required: bool,
custom_fields: Optional[Dict] = None, schema_defs: Optional[Dict] = None,
) -> tuple[Union[Type, List, ForwardRef, Any], FieldInfo]: ) -> tuple[Union[Type, List, ForwardRef, Any], FieldInfo]:
""" """
Recursively processes a schema property to determine its Python type hint Recursively processes a schema property to determine its Python type hint
@ -67,11 +67,12 @@ def _process_schema_property(
if "$ref" in prop_schema: if "$ref" in prop_schema:
ref = prop_schema["$ref"] ref = prop_schema["$ref"]
ref = ref.split("/")[-1] ref = ref.split("/")[-1]
assert ref in custom_fields, "Custom field not found" assert ref in schema_defs, "Custom field not found"
prop_schema = custom_fields[ref] prop_schema = schema_defs[ref]
prop_type = prop_schema.get("type") prop_type = prop_schema.get("type")
prop_desc = prop_schema.get("description", "") prop_desc = prop_schema.get("description", "")
default_value = ... if is_required else prop_schema.get("default", None) default_value = ... if is_required else prop_schema.get("default", None)
pydantic_field = Field(default=default_value, description=prop_desc) pydantic_field = Field(default=default_value, description=prop_desc)
@ -126,7 +127,7 @@ def _process_schema_property(
nested_model_name, nested_model_name,
name, name,
is_nested_required, is_nested_required,
custom_fields, schema_defs,
) )
nested_fields[name] = (nested_type_hint, nested_pydantic_field) nested_fields[name] = (nested_type_hint, nested_pydantic_field)
@ -152,7 +153,7 @@ def _process_schema_property(
f"{model_name_prefix}_{prop_name}", f"{model_name_prefix}_{prop_name}",
"item", "item",
False, # Items aren't required at this level, False, # Items aren't required at this level,
custom_fields, schema_defs,
) )
list_type_hint = List[item_type_hint] list_type_hint = List[item_type_hint]
return list_type_hint, pydantic_field return list_type_hint, pydantic_field
@ -171,7 +172,7 @@ def _process_schema_property(
return Any, pydantic_field return Any, pydantic_field
def get_model_fields(form_model_name, properties, required_fields, custom_fields=None): def get_model_fields(form_model_name, properties, required_fields, schema_defs=None):
model_fields = {} model_fields = {}
_model_cache: Dict[str, Type] = {} _model_cache: Dict[str, Type] = {}
@ -184,7 +185,7 @@ def get_model_fields(form_model_name, properties, required_fields, custom_fields
form_model_name, form_model_name,
param_name, param_name,
is_required, is_required,
custom_fields, schema_defs,
) )
# Use the generated type hint and Field info # Use the generated type hint and Field info
model_fields[param_name] = (python_type_hint, pydantic_field_info) model_fields[param_name] = (python_type_hint, pydantic_field_info)
@ -192,20 +193,24 @@ def get_model_fields(form_model_name, properties, required_fields, custom_fields
def get_tool_handler( def get_tool_handler(
session, endpoint_name, form_model_name, model_fields, output_model_fileds=None session,
endpoint_name,
form_model_name,
form_model_fields,
response_model_fields=None,
): ):
if model_fields: if form_model_fields:
FormModel = create_model(form_model_name, **model_fields) FormModel = create_model(form_model_name, **form_model_fields)
OutputModel = ( ResponseModel = (
create_model(f"{endpoint_name}_output_model", **output_model_fileds) create_model(f"{endpoint_name}_response_model", **response_model_fields)
if output_model_fileds if response_model_fields
else Any else Any
) )
def make_endpoint_func( def make_endpoint_func(
endpoint_name: str, FormModel, session: ClientSession endpoint_name: str, FormModel, session: ClientSession
): # Parameterized endpoint ): # Parameterized endpoint
async def tool(form_data: FormModel) -> OutputModel: async def tool(form_data: FormModel) -> ResponseModel:
args = form_data.model_dump(exclude_none=True) args = form_data.model_dump(exclude_none=True)
print(f"Calling endpoint: {endpoint_name}, with args: {args}") print(f"Calling endpoint: {endpoint_name}, with args: {args}")
try: try: