mirror of
				https://github.com/open-webui/openapi-servers
				synced 2025-06-26 18:17:04 +00:00 
			
		
		
		
	refac: reuse shared httpx client and concurrent channel history fetch for major performance gain
This commit is contained in:
		
							parent
							
								
									565a563a77
								
							
						
					
					
						commit
						49900846d0
					
				| @ -1,276 +1,297 @@ | ||||
| """Slack MCP Server – high‑performance version | ||||
| ------------------------------------------------ | ||||
| Showcase‑level code quality and pythonic clarity. | ||||
| """ | ||||
| 
 | ||||
| import os | ||||
| import httpx | ||||
| import inspect | ||||
| import asyncio | ||||
| import logging | ||||
| import json  # For JSONDecodeError | ||||
| from typing import Optional, List, Dict, Any, Type, Callable | ||||
| from fastapi import FastAPI, HTTPException, Body, Depends, Security | ||||
| from fastapi.security import APIKeyHeader | ||||
| from fastapi.middleware.cors import CORSMiddleware | ||||
| from pydantic import BaseModel, Field | ||||
| from dotenv import load_dotenv | ||||
| 
 | ||||
| # --- Logging Setup --- | ||||
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | ||||
| import httpx | ||||
| from dotenv import load_dotenv | ||||
| from fastapi import FastAPI, HTTPException, Body, Depends, Security | ||||
| from fastapi.middleware.cors import CORSMiddleware | ||||
| from fastapi.security import APIKeyHeader | ||||
| from pydantic import BaseModel, Field | ||||
| 
 | ||||
| # --------------------------------------------------------------------------- | ||||
| # Logging | ||||
| # --------------------------------------------------------------------------- | ||||
| logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| # Load environment variables from .env file | ||||
| # --------------------------------------------------------------------------- | ||||
| # Environment variables | ||||
| # --------------------------------------------------------------------------- | ||||
| load_dotenv() | ||||
| 
 | ||||
| # --- Environment Variable Checks --- | ||||
| SLACK_BOT_TOKEN = os.getenv("SLACK_BOT_TOKEN") | ||||
| SLACK_TEAM_ID = os.getenv("SLACK_TEAM_ID") | ||||
| SLACK_CHANNEL_IDS_STR = os.getenv("SLACK_CHANNEL_IDS")  # Optional | ||||
| ALLOWED_ORIGINS_STR = os.getenv("ALLOWED_ORIGINS", "*") # Default to allow all | ||||
| ALLOWED_ORIGINS_STR = os.getenv("ALLOWED_ORIGINS", "*") | ||||
| SERVER_API_KEY = os.getenv("SERVER_API_KEY")  # Optional API key for security | ||||
| 
 | ||||
| if not SLACK_BOT_TOKEN: | ||||
|     # Fail fast if essential config is missing | ||||
|     logger.critical("SLACK_BOT_TOKEN environment variable not set.") | ||||
|     raise ValueError("SLACK_BOT_TOKEN environment variable not set.") | ||||
| if not SLACK_TEAM_ID: | ||||
|     logger.critical("SLACK_TEAM_ID environment variable not set.") | ||||
|     raise ValueError("SLACK_TEAM_ID environment variable not set.") | ||||
| 
 | ||||
| PREDEFINED_CHANNEL_IDS = [ | ||||
|     channel_id.strip() | ||||
|     for channel_id in SLACK_CHANNEL_IDS_STR.split(',') | ||||
| ] if SLACK_CHANNEL_IDS_STR else None | ||||
| PREDEFINED_CHANNEL_IDS: Optional[List[str]] = ( | ||||
|     [cid.strip() for cid in SLACK_CHANNEL_IDS_STR.split(",")] if SLACK_CHANNEL_IDS_STR else None | ||||
| ) | ||||
| 
 | ||||
| # --- FastAPI App Setup --- | ||||
| # --------------------------------------------------------------------------- | ||||
| # FastAPI app setup | ||||
| # --------------------------------------------------------------------------- | ||||
| app = FastAPI( | ||||
|     title="Slack API Server", | ||||
|     version="1.0.0", | ||||
|     description="FastAPI server providing Slack functionalities via specific, dynamically generated tool endpoints.", | ||||
| ) | ||||
| 
 | ||||
| # Configure CORS | ||||
| allow_origins = [origin.strip() for origin in ALLOWED_ORIGINS_STR.split(',')] | ||||
| # CORS | ||||
| allow_origins = [origin.strip() for origin in ALLOWED_ORIGINS_STR.split(",")] | ||||
| if allow_origins == ["*"]: | ||||
|     logger.warning("CORS allow_origins is set to '*' which is insecure for production. Consider setting ALLOWED_ORIGINS environment variable.") | ||||
| 
 | ||||
| app.add_middleware( | ||||
|     CORSMiddleware, | ||||
|     allow_origins=allow_origins, | ||||
|     allow_credentials=True, # Allow credentials if origins are specific, adjust if needed | ||||
|     allow_credentials=True, | ||||
|     allow_methods=["*"], | ||||
|     allow_headers=["*"], | ||||
| ) | ||||
| 
 | ||||
| # --- API Key Security --- | ||||
| api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) # auto_error=False to handle optional key | ||||
| # --------------------------------------------------------------------------- | ||||
| # API key security | ||||
| # --------------------------------------------------------------------------- | ||||
| api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) | ||||
| 
 | ||||
| 
 | ||||
| async def get_api_key(key: str = Security(api_key_header)): | ||||
|     if SERVER_API_KEY: # Only enforce key if it's set in the environment | ||||
|     if SERVER_API_KEY: | ||||
|         if not key: | ||||
|             logger.warning("API Key required but not provided in X-API-Key header.") | ||||
|             raise HTTPException(status_code=401, detail="X-API-Key header required") | ||||
|         if key != SERVER_API_KEY: | ||||
|             logger.warning("Invalid API Key provided.") | ||||
|             raise HTTPException(status_code=401, detail="Invalid API Key") | ||||
|         # If key is valid and required, proceed | ||||
|     # If SERVER_API_KEY is not set, allow access without a key | ||||
|     # logger.info("API Key check passed (or not required).") # Optional: Log successful checks | ||||
|     return key # Return the key or None if not required/provided | ||||
|     return key  # May be None when not required | ||||
| 
 | ||||
| 
 | ||||
| if not SERVER_API_KEY: | ||||
|     logger.warning("SERVER_API_KEY environment variable is not set. Server will allow unauthenticated requests.") | ||||
| 
 | ||||
| # [Previous Pydantic models remain the same...] | ||||
| # --------------------------------------------------------------------------- | ||||
| # Pydantic models (arguments & responses) | ||||
| # --------------------------------------------------------------------------- | ||||
| 
 | ||||
| class ListChannelsArgs(BaseModel): | ||||
|     limit: Optional[int] = Field(100, description="Maximum number of channels to return (default 100, max 200)") | ||||
|     cursor: Optional[str] = Field(None, description="Pagination cursor for next page of results") | ||||
| 
 | ||||
| 
 | ||||
| class PostMessageArgs(BaseModel): | ||||
|     channel_id: str = Field(..., description="The ID of the channel to post to") | ||||
|     text: str = Field(..., description="The message text to post") | ||||
| 
 | ||||
| 
 | ||||
| class ReplyToThreadArgs(BaseModel): | ||||
|     channel_id: str = Field(..., description="The ID of the channel containing the thread") | ||||
|     thread_ts: str = Field(..., description="The timestamp of the parent message (e.g., '1234567890.123456')") | ||||
|     text: str = Field(..., description="The reply text") | ||||
| 
 | ||||
| 
 | ||||
| class AddReactionArgs(BaseModel): | ||||
|     channel_id: str = Field(..., description="The ID of the channel containing the message") | ||||
|     timestamp: str = Field(..., description="The timestamp of the message to react to") | ||||
|     reaction: str = Field(..., description="The name of the emoji reaction (without colons)") | ||||
| 
 | ||||
| 
 | ||||
| class GetChannelHistoryArgs(BaseModel): | ||||
|     channel_id: str = Field(..., description="The ID of the channel") | ||||
|     limit: Optional[int] = Field(10, description="Number of messages to retrieve (default 10)") | ||||
| 
 | ||||
| 
 | ||||
| class GetThreadRepliesArgs(BaseModel): | ||||
|     channel_id: str = Field(..., description="The ID of the channel containing the thread") | ||||
|     thread_ts: str = Field(..., description="The timestamp of the parent message (e.g., '1234567890.123456')") | ||||
| 
 | ||||
| 
 | ||||
| class GetUsersArgs(BaseModel): | ||||
|     cursor: Optional[str] = Field(None, description="Pagination cursor for next page of results") | ||||
|     limit: Optional[int] = Field(100, description="Maximum number of users to return (default 100, max 200)") | ||||
| 
 | ||||
| 
 | ||||
| class GetUserProfileArgs(BaseModel): | ||||
|     user_id: str = Field(..., description="The ID of the user") | ||||
| 
 | ||||
| 
 | ||||
| class ToolResponse(BaseModel): | ||||
|     content: Dict[str, Any] = Field(..., description="The JSON response from the Slack API call") | ||||
| 
 | ||||
| # --- Slack Client Class --- | ||||
| 
 | ||||
| # --------------------------------------------------------------------------- | ||||
| # Slack client (high‑performance) | ||||
| # --------------------------------------------------------------------------- | ||||
| 
 | ||||
| class SlackClient: | ||||
|     """Thin async wrapper over Slack Web API with connection‑pool reuse.""" | ||||
| 
 | ||||
|     BASE_URL = "https://slack.com/api/" | ||||
| 
 | ||||
|     def __init__(self, token: str, team_id: str): | ||||
|     def __init__(self, token: str, team_id: str, *, max_connections: int = 20): | ||||
|         self.team_id = team_id | ||||
|         self.headers = { | ||||
|             "Authorization": f"Bearer {token}", | ||||
|             "Content-Type": "application/json; charset=utf-8", | ||||
|         } | ||||
|         self.team_id = team_id | ||||
|         limits = httpx.Limits(max_connections=max_connections, max_keepalive_connections=max_connections) | ||||
|         self._client = httpx.AsyncClient( | ||||
|             base_url=self.BASE_URL, | ||||
|             headers=self.headers, | ||||
|             limits=limits, | ||||
|             http2=True, | ||||
|             timeout=10, | ||||
|         ) | ||||
| 
 | ||||
|     async def _request(self, method: str, endpoint: str, params: Optional[Dict] = None, json_data: Optional[Dict] = None) -> Dict[str, Any]: | ||||
|         async with httpx.AsyncClient(base_url=self.BASE_URL, headers=self.headers) as client: | ||||
|     # ---------------- private helpers ---------------- # | ||||
|     async def _request( | ||||
|         self, | ||||
|         method: str, | ||||
|         endpoint: str, | ||||
|         *, | ||||
|         params: Optional[Dict] = None, | ||||
|         json_data: Optional[Dict] = None, | ||||
|     ) -> Dict[str, Any]: | ||||
|         try: | ||||
|                 response = await client.request(method, endpoint, params=params, json=json_data) | ||||
|             response = await self._client.request(method, endpoint, params=params, json=json_data) | ||||
|             response.raise_for_status() | ||||
|             data = response.json() | ||||
|             if not data.get("ok"): | ||||
|                 error_msg = data.get("error", "Unknown Slack API error") | ||||
|                     # Return the specific Slack error in the response | ||||
|                     logger.warning(f"Slack API Error for {method} {endpoint}: {error_msg}") | ||||
|                     raise HTTPException(status_code=400, detail={"slack_error": error_msg, "message": f"Slack API returned an error: {error_msg}"}) | ||||
|                 raise HTTPException(status_code=400, detail={"slack_error": error_msg}) | ||||
|             return data | ||||
|         except httpx.HTTPStatusError as e: | ||||
|                 # Handle specific HTTP errors like rate limiting (429) | ||||
|             if e.response.status_code == 429: | ||||
|                 retry_after = e.response.headers.get("Retry-After") | ||||
|                     detail = f"Slack API rate limit exceeded. Retry after {retry_after} seconds." if retry_after else "Slack API rate limit exceeded." | ||||
|                     logger.warning(f"Rate limit hit for {method} {endpoint}. Retry-After: {retry_after}") | ||||
|                 detail = ( | ||||
|                     f"Slack API rate limit exceeded. Retry after {retry_after} seconds." | ||||
|                     if retry_after | ||||
|                     else "Slack API rate limit exceeded." | ||||
|                 ) | ||||
|                 logger.warning("Rate limit hit: %s", detail) | ||||
|                 raise HTTPException(status_code=429, detail=detail, headers={"Retry-After": retry_after} if retry_after else {}) | ||||
|                 else: | ||||
|                     logger.error(f"HTTP Error: {e.response.status_code} - {e.response.text}", exc_info=True) | ||||
|                     raise HTTPException(status_code=e.response.status_code, detail=f"Slack API HTTP Error: Status {e.response.status_code}") | ||||
|             logger.error("HTTP Error %s - %s", e.response.status_code, e.response.text, exc_info=True) | ||||
|             raise HTTPException(status_code=e.response.status_code, detail="Slack API HTTP Error") | ||||
|         except httpx.RequestError as e: | ||||
|                 logger.error(f"Request Error connecting to Slack API: {e}", exc_info=True) | ||||
|             logger.error("Request Error connecting to Slack API: %s", e, exc_info=True) | ||||
|             raise HTTPException(status_code=503, detail=f"Could not connect to Slack API: {e}") | ||||
|         except json.JSONDecodeError as e: | ||||
|                  logger.error(f"Failed to decode JSON response from Slack API for {method} {endpoint}: {e}", exc_info=True) | ||||
|                  raise HTTPException(status_code=502, detail="Invalid response received from Slack API.") | ||||
|             except Exception as e: # Catch other unexpected errors | ||||
|                 logger.exception(f"Unexpected error during Slack request for {method} {endpoint}: {e}") # Use logger.exception to include traceback | ||||
|                 raise HTTPException(status_code=500, detail=f"An internal server error occurred: {type(e).__name__}") | ||||
|             logger.error("Failed to decode JSON: %s", e, exc_info=True) | ||||
|             raise HTTPException(status_code=502, detail="Invalid JSON from Slack API") | ||||
|         except Exception as e:  # noqa: BLE001 | ||||
|             logger.exception("Unexpected error during Slack request: %s", e) | ||||
|             raise HTTPException(status_code=500, detail=f"Internal error: {type(e).__name__}") | ||||
| 
 | ||||
|     async def get_channel_history(self, args: GetChannelHistoryArgs) -> Dict[str, Any]: | ||||
|         params = {"channel": args.channel_id, "limit": args.limit} | ||||
|         return await self._request("GET", "conversations.history", params=params) | ||||
| 
 | ||||
|     async def get_channels(self, args: ListChannelsArgs) -> Dict[str, Any]: | ||||
|         limit = args.limit | ||||
|         cursor = args.cursor | ||||
| 
 | ||||
|         async def fetch_channel_with_history(channel_id: str) -> Dict[str, Any]: | ||||
|             # First get channel info | ||||
|             channel_info = await self._request("GET", "conversations.info", params={"channel": channel_id}) | ||||
|             if not channel_info.get("ok") or channel_info.get("channel", {}).get("is_archived"): | ||||
|                 return None | ||||
| 
 | ||||
|             channel_data = channel_info["channel"] | ||||
| 
 | ||||
|             # Then get channel history | ||||
|     # ---------------- public helpers ---------------- # | ||||
|     async def channel_with_history(self, channel_id: str, *, history_limit: int = 1) -> Optional[Dict[str, Any]]: | ||||
|         """Return channel metadata plus ≤ ``history_limit`` recent messages, or None.""" | ||||
|         try: | ||||
|                 history = await self._request( | ||||
|             info = await self._request("GET", "conversations.info", params={"channel": channel_id}) | ||||
|             chan = info["channel"] | ||||
|             if chan.get("is_archived"): | ||||
|                 return None | ||||
|             hist = await self._request( | ||||
|                 "GET", | ||||
|                 "conversations.history", | ||||
|                     params={ | ||||
|                         "channel": channel_id, | ||||
|                         "limit": 1 # Fetch minimal history by default to speed up get_channels. Consider asyncio.gather for concurrency. | ||||
|                     } | ||||
|                 params={"channel": channel_id, "limit": history_limit}, | ||||
|             ) | ||||
|                 # Add history to channel data | ||||
|                 if history.get("ok"): | ||||
|                     channel_data["history"] = history.get("messages", []) | ||||
|             except Exception as e: # Catch errors during history fetch but don't fail the whole channel list | ||||
|                 logger.warning(f"Error fetching history for channel {channel_id}: {e}", exc_info=True) | ||||
|                 channel_data["history"] = [] # Ensure history key exists even if fetch fails | ||||
|             chan["history"] = hist.get("messages", []) | ||||
|             return chan | ||||
|         except Exception as exc:  # noqa: BLE001 | ||||
|             logger.warning("Skipping channel %s – %s", channel_id, exc, exc_info=True) | ||||
|             return None | ||||
| 
 | ||||
|             return channel_data | ||||
|     # ---------------- API surface ---------------- # | ||||
|     async def get_channel_history(self, args: GetChannelHistoryArgs) -> Dict[str, Any]: | ||||
|         return await self._request("GET", "conversations.history", params={"channel": args.channel_id, "limit": args.limit}) | ||||
| 
 | ||||
|     async def get_channels(self, args: ListChannelsArgs) -> Dict[str, Any]:  # noqa: C901 – keep cohesive | ||||
|         # 1. decide which ids to fetch | ||||
|         if PREDEFINED_CHANNEL_IDS: | ||||
|             channels_info = [] | ||||
|             for channel_id in PREDEFINED_CHANNEL_IDS: | ||||
|                 try: | ||||
|                     if channel_data := await fetch_channel_with_history(channel_id): | ||||
|                         channels_info.append(channel_data) | ||||
|                 except Exception as e: # Catch errors fetching predefined channels | ||||
|                     logger.warning(f"Could not fetch info for predefined channel {channel_id}: {e}", exc_info=True) | ||||
| 
 | ||||
|             return { | ||||
|                 "ok": True, | ||||
|                 "channels": channels_info, | ||||
|                 "response_metadata": {"next_cursor": ""} | ||||
|             } | ||||
|             ids = PREDEFINED_CHANNEL_IDS | ||||
|             next_cursor = "" | ||||
|         else: | ||||
|             # First get list of channels | ||||
|             params = { | ||||
|             params: Dict[str, Any] = { | ||||
|                 "types": "public_channel", | ||||
|                 "exclude_archived": "true", | ||||
|                 "limit": min(limit, 200), | ||||
|                 "team_id": self.team_id, | ||||
|             } | ||||
|             if cursor: | ||||
|                 params["cursor"] = cursor | ||||
| 
 | ||||
|             channels_list = await self._request("GET", "conversations.list", params=params) | ||||
| 
 | ||||
|             if not channels_list.get("ok"): | ||||
|                 return channels_list | ||||
| 
 | ||||
|             # Then fetch history for each channel | ||||
|             channels_with_history = [] | ||||
|             for channel in channels_list["channels"]: | ||||
|                 try: | ||||
|                     if channel_data := await fetch_channel_with_history(channel["id"]): | ||||
|                         channels_with_history.append(channel_data) | ||||
|                 except Exception as e: # Catch errors during history fetch but don't fail the whole channel list | ||||
|                     logger.warning(f"Error fetching history for channel {channel['id']}: {e}", exc_info=True) | ||||
|                     channel["history"] = [] # Add empty history on error | ||||
|                     channels_with_history.append(channel) | ||||
| 
 | ||||
|             return { | ||||
|                 "ok": True, | ||||
|                 "channels": channels_with_history, | ||||
|                 "response_metadata": channels_list.get("response_metadata", {"next_cursor": ""}) | ||||
|             } | ||||
| 
 | ||||
|     async def post_message(self, args: PostMessageArgs) -> Dict[str, Any]: | ||||
|         payload = {"channel": args.channel_id, "text": args.text} | ||||
|         return await self._request("POST", "chat.postMessage", json_data=payload) | ||||
| 
 | ||||
|     async def post_reply(self, args: ReplyToThreadArgs) -> Dict[str, Any]: | ||||
|         payload = {"channel": args.channel_id, "thread_ts": args.thread_ts, "text": args.text} | ||||
|         return await self._request("POST", "chat.postMessage", json_data=payload) | ||||
| 
 | ||||
|     async def add_reaction(self, args: AddReactionArgs) -> Dict[str, Any]: | ||||
|         payload = {"channel": args.channel_id, "timestamp": args.timestamp, "name": args.reaction} | ||||
|         return await self._request("POST", "reactions.add", json_data=payload) | ||||
| 
 | ||||
|     async def get_thread_replies(self, args: GetThreadRepliesArgs) -> Dict[str, Any]: | ||||
|         params = {"channel": args.channel_id, "ts": args.thread_ts} | ||||
|         return await self._request("GET", "conversations.replies", params=params) | ||||
| 
 | ||||
|     async def get_users(self, args: GetUsersArgs) -> Dict[str, Any]: | ||||
|         params = { | ||||
|                 "limit": min(args.limit, 200), | ||||
|                 "team_id": self.team_id, | ||||
|             } | ||||
|             if args.cursor: | ||||
|                 params["cursor"] = args.cursor | ||||
|             clist = await self._request("GET", "conversations.list", params=params) | ||||
|             ids = [c["id"] for c in clist["channels"]] | ||||
|             next_cursor = clist.get("response_metadata", {}).get("next_cursor", "") | ||||
| 
 | ||||
|         # 2. fetch metadata + history concurrently under a semaphore | ||||
|         sem = asyncio.Semaphore(10)  # adjust parallelism as desired | ||||
| 
 | ||||
|         async def guarded(cid: str): | ||||
|             async with sem: | ||||
|                 return await self.channel_with_history(cid) | ||||
| 
 | ||||
|         channels = [c for c in await asyncio.gather(*(guarded(cid) for cid in ids)) if c] | ||||
|         return {"ok": True, "channels": channels, "response_metadata": {"next_cursor": next_cursor}} | ||||
| 
 | ||||
|     async def post_message(self, args: PostMessageArgs) -> Dict[str, Any]: | ||||
|         return await self._request("POST", "chat.postMessage", json_data={"channel": args.channel_id, "text": args.text}) | ||||
| 
 | ||||
|     async def post_reply(self, args: ReplyToThreadArgs) -> Dict[str, Any]: | ||||
|         return await self._request( | ||||
|             "POST", | ||||
|             "chat.postMessage", | ||||
|             json_data={"channel": args.channel_id, "thread_ts": args.thread_ts, "text": args.text}, | ||||
|         ) | ||||
| 
 | ||||
|     async def add_reaction(self, args: AddReactionArgs) -> Dict[str, Any]: | ||||
|         return await self._request( | ||||
|             "POST", | ||||
|             "reactions.add", | ||||
|             json_data={"channel": args.channel_id, "timestamp": args.timestamp, "name": args.reaction}, | ||||
|         ) | ||||
| 
 | ||||
|     async def get_thread_replies(self, args: GetThreadRepliesArgs) -> Dict[str, Any]: | ||||
|         return await self._request("GET", "conversations.replies", params={"channel": args.channel_id, "ts": args.thread_ts}) | ||||
| 
 | ||||
|     async def get_users(self, args: GetUsersArgs) -> Dict[str, Any]: | ||||
|         params = {"limit": min(args.limit, 200), "team_id": self.team_id} | ||||
|         if args.cursor: | ||||
|             params["cursor"] = args.cursor | ||||
|         return await self._request("GET", "users.list", params=params) | ||||
| 
 | ||||
|     async def get_user_profile(self, args: GetUserProfileArgs) -> Dict[str, Any]: | ||||
|         params = {"user": args.user_id, "include_labels": "true"} | ||||
|         return await self._request("GET", "users.profile.get", params=params) | ||||
|         return await self._request("GET", "users.profile.get", params={"user": args.user_id, "include_labels": "true"}) | ||||
| 
 | ||||
| # --- Instantiate Slack Client --- | ||||
|     # ---------------- lifecycle ---------------- # | ||||
|     async def aclose(self) -> None:  # call on app shutdown | ||||
|         await self._client.aclose() | ||||
| 
 | ||||
| 
 | ||||
| # --------------------------------------------------------------------------- | ||||
| # Instantiate Slack client | ||||
| # --------------------------------------------------------------------------- | ||||
| slack_client = SlackClient(token=SLACK_BOT_TOKEN, team_id=SLACK_TEAM_ID) | ||||
| 
 | ||||
| # --- Tool Definitions & Endpoint Generation --- | ||||
| 
 | ||||
| # --------------------------------------------------------------------------- | ||||
| # Dynamic tool mapping / endpoint generation | ||||
| # --------------------------------------------------------------------------- | ||||
| TOOL_MAPPING = { | ||||
|     "slack_list_channels": { | ||||
|         "args_model": ListChannelsArgs, | ||||
| @ -314,40 +335,45 @@ TOOL_MAPPING = { | ||||
|     }, | ||||
| } | ||||
| 
 | ||||
| # Define a function factory to create endpoint handlers, including API key dependency | ||||
| 
 | ||||
| # ---------------- endpoint factory ---------------- # | ||||
| 
 | ||||
| def create_endpoint_handler(tool_name: str, method: Callable, args_model: Type[BaseModel]): | ||||
|     async def endpoint_handler( | ||||
|         args: args_model = Body(...), | ||||
|         api_key: str = Depends(get_api_key) # Add API key dependency here | ||||
|     ) -> ToolResponse: | ||||
|     async def handler(args: args_model = Body(...), api_key: str = Depends(get_api_key)) -> ToolResponse:  # noqa: ANN001 | ||||
|         try: | ||||
|             result = await method(args=args) | ||||
|             return {"content": result} | ||||
|         except HTTPException as e: | ||||
|             raise e | ||||
|         except Exception as e: | ||||
|             logger.exception(f"Error executing tool {tool_name}: {e}") # Use logger.exception here too | ||||
|             raise HTTPException(status_code=500, detail=f"Internal server error: {type(e).__name__}") | ||||
|     return endpoint_handler | ||||
|         except HTTPException: | ||||
|             raise  # re‑raise untouched | ||||
|         except Exception as exc:  # noqa: BLE001 | ||||
|             logger.exception("Error executing tool %s: %s", tool_name, exc) | ||||
|             raise HTTPException(status_code=500, detail=f"Internal server error: {type(exc).__name__}") | ||||
| 
 | ||||
| # Register endpoints for each tool | ||||
| for tool_name, config in TOOL_MAPPING.items(): | ||||
|     handler = create_endpoint_handler( | ||||
|         tool_name=tool_name, | ||||
|         method=config["method"], | ||||
|         args_model=config["args_model"] | ||||
|     ) | ||||
|     return handler | ||||
| 
 | ||||
| 
 | ||||
| for name, cfg in TOOL_MAPPING.items(): | ||||
|     app.post( | ||||
|         f"/{tool_name}", | ||||
|         f"/{name}", | ||||
|         response_model=ToolResponse, | ||||
|         summary=config["description"], | ||||
|         description=f"Executes the {tool_name} tool. Arguments are passed in the request body.", | ||||
|         summary=cfg["description"], | ||||
|         description=f"Executes the {name} tool. Arguments are passed in the request body.", | ||||
|         tags=["Slack Tools"], | ||||
|         name=tool_name | ||||
|     )(handler) | ||||
|         name=name, | ||||
|     )(create_endpoint_handler(name, cfg["method"], cfg["args_model"])) | ||||
| 
 | ||||
| # --- Root Endpoint --- | ||||
| 
 | ||||
| # --------------------------------------------------------------------------- | ||||
| # Lifecycle events | ||||
| # --------------------------------------------------------------------------- | ||||
| @app.on_event("shutdown") | ||||
| async def _close_slack_client(): | ||||
|     await slack_client.aclose() | ||||
| 
 | ||||
| 
 | ||||
| # --------------------------------------------------------------------------- | ||||
| # Root endpoint | ||||
| # --------------------------------------------------------------------------- | ||||
| @app.get("/", summary="Root endpoint", include_in_schema=False) | ||||
| async def read_root(): | ||||
|     return {"message": "Slack API Server is running. See /docs for available tool endpoints."} | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user