From 49900846d00ad5b7283d160cd41e099695518e01 Mon Sep 17 00:00:00 2001 From: Taylor Wilsdon Date: Sun, 20 Apr 2025 19:51:36 -0400 Subject: [PATCH] =?UTF-8?q?refac:=20reuse=20shared=20httpx=20client=20and?= =?UTF-8?q?=20concurrent=20channel=C2=A0history=20fetch=20for=20major=20pe?= =?UTF-8?q?rformance=20gain?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- servers/slack/main.py | 384 ++++++++++++++++++++++-------------------- 1 file changed, 205 insertions(+), 179 deletions(-) diff --git a/servers/slack/main.py b/servers/slack/main.py index cea6836..6b77cb2 100644 --- a/servers/slack/main.py +++ b/servers/slack/main.py @@ -1,276 +1,297 @@ -import os -import httpx -import inspect -import logging -import json # For JSONDecodeError -from typing import Optional, List, Dict, Any, Type, Callable -from fastapi import FastAPI, HTTPException, Body, Depends, Security -from fastapi.security import APIKeyHeader -from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel, Field -from dotenv import load_dotenv +"""Slack MCP Server – high‑performance version +------------------------------------------------ +Showcase‑level code quality and pythonic clarity. +""" -# --- Logging Setup --- -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +import os +import asyncio +import logging +import json # For JSONDecodeError +from typing import Optional, List, Dict, Any, Type, Callable + +import httpx +from dotenv import load_dotenv +from fastapi import FastAPI, HTTPException, Body, Depends, Security +from fastapi.middleware.cors import CORSMiddleware +from fastapi.security import APIKeyHeader +from pydantic import BaseModel, Field + +# --------------------------------------------------------------------------- +# Logging +# --------------------------------------------------------------------------- +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) -# Load environment variables from .env file +# --------------------------------------------------------------------------- +# Environment variables +# --------------------------------------------------------------------------- load_dotenv() -# --- Environment Variable Checks --- SLACK_BOT_TOKEN = os.getenv("SLACK_BOT_TOKEN") SLACK_TEAM_ID = os.getenv("SLACK_TEAM_ID") -SLACK_CHANNEL_IDS_STR = os.getenv("SLACK_CHANNEL_IDS") # Optional -ALLOWED_ORIGINS_STR = os.getenv("ALLOWED_ORIGINS", "*") # Default to allow all -SERVER_API_KEY = os.getenv("SERVER_API_KEY") # Optional API key for security +SLACK_CHANNEL_IDS_STR = os.getenv("SLACK_CHANNEL_IDS") # Optional +ALLOWED_ORIGINS_STR = os.getenv("ALLOWED_ORIGINS", "*") +SERVER_API_KEY = os.getenv("SERVER_API_KEY") # Optional API key for security if not SLACK_BOT_TOKEN: - # Fail fast if essential config is missing logger.critical("SLACK_BOT_TOKEN environment variable not set.") raise ValueError("SLACK_BOT_TOKEN environment variable not set.") if not SLACK_TEAM_ID: logger.critical("SLACK_TEAM_ID environment variable not set.") raise ValueError("SLACK_TEAM_ID environment variable not set.") -PREDEFINED_CHANNEL_IDS = [ - channel_id.strip() - for channel_id in SLACK_CHANNEL_IDS_STR.split(',') -] if SLACK_CHANNEL_IDS_STR else None +PREDEFINED_CHANNEL_IDS: Optional[List[str]] = ( + [cid.strip() for cid in SLACK_CHANNEL_IDS_STR.split(",")] if SLACK_CHANNEL_IDS_STR else None +) -# --- FastAPI App Setup --- +# --------------------------------------------------------------------------- +# FastAPI app setup +# --------------------------------------------------------------------------- app = FastAPI( title="Slack API Server", version="1.0.0", description="FastAPI server providing Slack functionalities via specific, dynamically generated tool endpoints.", ) -# Configure CORS -allow_origins = [origin.strip() for origin in ALLOWED_ORIGINS_STR.split(',')] +# CORS +allow_origins = [origin.strip() for origin in ALLOWED_ORIGINS_STR.split(",")] if allow_origins == ["*"]: logger.warning("CORS allow_origins is set to '*' which is insecure for production. Consider setting ALLOWED_ORIGINS environment variable.") app.add_middleware( CORSMiddleware, allow_origins=allow_origins, - allow_credentials=True, # Allow credentials if origins are specific, adjust if needed + allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) -# --- API Key Security --- -api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) # auto_error=False to handle optional key +# --------------------------------------------------------------------------- +# API key security +# --------------------------------------------------------------------------- +api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) + async def get_api_key(key: str = Security(api_key_header)): - if SERVER_API_KEY: # Only enforce key if it's set in the environment + if SERVER_API_KEY: if not key: logger.warning("API Key required but not provided in X-API-Key header.") raise HTTPException(status_code=401, detail="X-API-Key header required") if key != SERVER_API_KEY: logger.warning("Invalid API Key provided.") raise HTTPException(status_code=401, detail="Invalid API Key") - # If key is valid and required, proceed - # If SERVER_API_KEY is not set, allow access without a key - # logger.info("API Key check passed (or not required).") # Optional: Log successful checks - return key # Return the key or None if not required/provided + return key # May be None when not required + if not SERVER_API_KEY: logger.warning("SERVER_API_KEY environment variable is not set. Server will allow unauthenticated requests.") -# [Previous Pydantic models remain the same...] +# --------------------------------------------------------------------------- +# Pydantic models (arguments & responses) +# --------------------------------------------------------------------------- + class ListChannelsArgs(BaseModel): limit: Optional[int] = Field(100, description="Maximum number of channels to return (default 100, max 200)") cursor: Optional[str] = Field(None, description="Pagination cursor for next page of results") + class PostMessageArgs(BaseModel): channel_id: str = Field(..., description="The ID of the channel to post to") text: str = Field(..., description="The message text to post") + class ReplyToThreadArgs(BaseModel): channel_id: str = Field(..., description="The ID of the channel containing the thread") thread_ts: str = Field(..., description="The timestamp of the parent message (e.g., '1234567890.123456')") text: str = Field(..., description="The reply text") + class AddReactionArgs(BaseModel): channel_id: str = Field(..., description="The ID of the channel containing the message") timestamp: str = Field(..., description="The timestamp of the message to react to") reaction: str = Field(..., description="The name of the emoji reaction (without colons)") + class GetChannelHistoryArgs(BaseModel): channel_id: str = Field(..., description="The ID of the channel") limit: Optional[int] = Field(10, description="Number of messages to retrieve (default 10)") + class GetThreadRepliesArgs(BaseModel): channel_id: str = Field(..., description="The ID of the channel containing the thread") thread_ts: str = Field(..., description="The timestamp of the parent message (e.g., '1234567890.123456')") + class GetUsersArgs(BaseModel): cursor: Optional[str] = Field(None, description="Pagination cursor for next page of results") limit: Optional[int] = Field(100, description="Maximum number of users to return (default 100, max 200)") + class GetUserProfileArgs(BaseModel): user_id: str = Field(..., description="The ID of the user") + class ToolResponse(BaseModel): content: Dict[str, Any] = Field(..., description="The JSON response from the Slack API call") -# --- Slack Client Class --- + +# --------------------------------------------------------------------------- +# Slack client (high‑performance) +# --------------------------------------------------------------------------- + class SlackClient: + """Thin async wrapper over Slack Web API with connection‑pool reuse.""" + BASE_URL = "https://slack.com/api/" - def __init__(self, token: str, team_id: str): + def __init__(self, token: str, team_id: str, *, max_connections: int = 20): + self.team_id = team_id self.headers = { "Authorization": f"Bearer {token}", "Content-Type": "application/json; charset=utf-8", } - self.team_id = team_id + limits = httpx.Limits(max_connections=max_connections, max_keepalive_connections=max_connections) + self._client = httpx.AsyncClient( + base_url=self.BASE_URL, + headers=self.headers, + limits=limits, + http2=True, + timeout=10, + ) - async def _request(self, method: str, endpoint: str, params: Optional[Dict] = None, json_data: Optional[Dict] = None) -> Dict[str, Any]: - async with httpx.AsyncClient(base_url=self.BASE_URL, headers=self.headers) as client: - try: - response = await client.request(method, endpoint, params=params, json=json_data) - response.raise_for_status() - data = response.json() - if not data.get("ok"): - error_msg = data.get("error", "Unknown Slack API error") - # Return the specific Slack error in the response - logger.warning(f"Slack API Error for {method} {endpoint}: {error_msg}") - raise HTTPException(status_code=400, detail={"slack_error": error_msg, "message": f"Slack API returned an error: {error_msg}"}) - return data - except httpx.HTTPStatusError as e: - # Handle specific HTTP errors like rate limiting (429) - if e.response.status_code == 429: - retry_after = e.response.headers.get("Retry-After") - detail = f"Slack API rate limit exceeded. Retry after {retry_after} seconds." if retry_after else "Slack API rate limit exceeded." - logger.warning(f"Rate limit hit for {method} {endpoint}. Retry-After: {retry_after}") - raise HTTPException(status_code=429, detail=detail, headers={"Retry-After": retry_after} if retry_after else {}) - else: - logger.error(f"HTTP Error: {e.response.status_code} - {e.response.text}", exc_info=True) - raise HTTPException(status_code=e.response.status_code, detail=f"Slack API HTTP Error: Status {e.response.status_code}") - except httpx.RequestError as e: - logger.error(f"Request Error connecting to Slack API: {e}", exc_info=True) - raise HTTPException(status_code=503, detail=f"Could not connect to Slack API: {e}") - except json.JSONDecodeError as e: - logger.error(f"Failed to decode JSON response from Slack API for {method} {endpoint}: {e}", exc_info=True) - raise HTTPException(status_code=502, detail="Invalid response received from Slack API.") - except Exception as e: # Catch other unexpected errors - logger.exception(f"Unexpected error during Slack request for {method} {endpoint}: {e}") # Use logger.exception to include traceback - raise HTTPException(status_code=500, detail=f"An internal server error occurred: {type(e).__name__}") - - async def get_channel_history(self, args: GetChannelHistoryArgs) -> Dict[str, Any]: - params = {"channel": args.channel_id, "limit": args.limit} - return await self._request("GET", "conversations.history", params=params) - - async def get_channels(self, args: ListChannelsArgs) -> Dict[str, Any]: - limit = args.limit - cursor = args.cursor - - async def fetch_channel_with_history(channel_id: str) -> Dict[str, Any]: - # First get channel info - channel_info = await self._request("GET", "conversations.info", params={"channel": channel_id}) - if not channel_info.get("ok") or channel_info.get("channel", {}).get("is_archived"): - return None - - channel_data = channel_info["channel"] - - # Then get channel history - try: - history = await self._request( - "GET", - "conversations.history", - params={ - "channel": channel_id, - "limit": 1 # Fetch minimal history by default to speed up get_channels. Consider asyncio.gather for concurrency. - } + # ---------------- private helpers ---------------- # + async def _request( + self, + method: str, + endpoint: str, + *, + params: Optional[Dict] = None, + json_data: Optional[Dict] = None, + ) -> Dict[str, Any]: + try: + response = await self._client.request(method, endpoint, params=params, json=json_data) + response.raise_for_status() + data = response.json() + if not data.get("ok"): + error_msg = data.get("error", "Unknown Slack API error") + raise HTTPException(status_code=400, detail={"slack_error": error_msg}) + return data + except httpx.HTTPStatusError as e: + if e.response.status_code == 429: + retry_after = e.response.headers.get("Retry-After") + detail = ( + f"Slack API rate limit exceeded. Retry after {retry_after} seconds." + if retry_after + else "Slack API rate limit exceeded." ) - # Add history to channel data - if history.get("ok"): - channel_data["history"] = history.get("messages", []) - except Exception as e: # Catch errors during history fetch but don't fail the whole channel list - logger.warning(f"Error fetching history for channel {channel_id}: {e}", exc_info=True) - channel_data["history"] = [] # Ensure history key exists even if fetch fails + logger.warning("Rate limit hit: %s", detail) + raise HTTPException(status_code=429, detail=detail, headers={"Retry-After": retry_after} if retry_after else {}) + logger.error("HTTP Error %s - %s", e.response.status_code, e.response.text, exc_info=True) + raise HTTPException(status_code=e.response.status_code, detail="Slack API HTTP Error") + except httpx.RequestError as e: + logger.error("Request Error connecting to Slack API: %s", e, exc_info=True) + raise HTTPException(status_code=503, detail=f"Could not connect to Slack API: {e}") + except json.JSONDecodeError as e: + logger.error("Failed to decode JSON: %s", e, exc_info=True) + raise HTTPException(status_code=502, detail="Invalid JSON from Slack API") + except Exception as e: # noqa: BLE001 + logger.exception("Unexpected error during Slack request: %s", e) + raise HTTPException(status_code=500, detail=f"Internal error: {type(e).__name__}") - return channel_data + # ---------------- public helpers ---------------- # + async def channel_with_history(self, channel_id: str, *, history_limit: int = 1) -> Optional[Dict[str, Any]]: + """Return channel metadata plus ≤ ``history_limit`` recent messages, or None.""" + try: + info = await self._request("GET", "conversations.info", params={"channel": channel_id}) + chan = info["channel"] + if chan.get("is_archived"): + return None + hist = await self._request( + "GET", + "conversations.history", + params={"channel": channel_id, "limit": history_limit}, + ) + chan["history"] = hist.get("messages", []) + return chan + except Exception as exc: # noqa: BLE001 + logger.warning("Skipping channel %s – %s", channel_id, exc, exc_info=True) + return None + # ---------------- API surface ---------------- # + async def get_channel_history(self, args: GetChannelHistoryArgs) -> Dict[str, Any]: + return await self._request("GET", "conversations.history", params={"channel": args.channel_id, "limit": args.limit}) + + async def get_channels(self, args: ListChannelsArgs) -> Dict[str, Any]: # noqa: C901 – keep cohesive + # 1. decide which ids to fetch if PREDEFINED_CHANNEL_IDS: - channels_info = [] - for channel_id in PREDEFINED_CHANNEL_IDS: - try: - if channel_data := await fetch_channel_with_history(channel_id): - channels_info.append(channel_data) - except Exception as e: # Catch errors fetching predefined channels - logger.warning(f"Could not fetch info for predefined channel {channel_id}: {e}", exc_info=True) - - return { - "ok": True, - "channels": channels_info, - "response_metadata": {"next_cursor": ""} - } + ids = PREDEFINED_CHANNEL_IDS + next_cursor = "" else: - # First get list of channels - params = { + params: Dict[str, Any] = { "types": "public_channel", "exclude_archived": "true", - "limit": min(limit, 200), + "limit": min(args.limit, 200), "team_id": self.team_id, } - if cursor: - params["cursor"] = cursor + if args.cursor: + params["cursor"] = args.cursor + clist = await self._request("GET", "conversations.list", params=params) + ids = [c["id"] for c in clist["channels"]] + next_cursor = clist.get("response_metadata", {}).get("next_cursor", "") - channels_list = await self._request("GET", "conversations.list", params=params) + # 2. fetch metadata + history concurrently under a semaphore + sem = asyncio.Semaphore(10) # adjust parallelism as desired - if not channels_list.get("ok"): - return channels_list + async def guarded(cid: str): + async with sem: + return await self.channel_with_history(cid) - # Then fetch history for each channel - channels_with_history = [] - for channel in channels_list["channels"]: - try: - if channel_data := await fetch_channel_with_history(channel["id"]): - channels_with_history.append(channel_data) - except Exception as e: # Catch errors during history fetch but don't fail the whole channel list - logger.warning(f"Error fetching history for channel {channel['id']}: {e}", exc_info=True) - channel["history"] = [] # Add empty history on error - channels_with_history.append(channel) - - return { - "ok": True, - "channels": channels_with_history, - "response_metadata": channels_list.get("response_metadata", {"next_cursor": ""}) - } + channels = [c for c in await asyncio.gather(*(guarded(cid) for cid in ids)) if c] + return {"ok": True, "channels": channels, "response_metadata": {"next_cursor": next_cursor}} async def post_message(self, args: PostMessageArgs) -> Dict[str, Any]: - payload = {"channel": args.channel_id, "text": args.text} - return await self._request("POST", "chat.postMessage", json_data=payload) + return await self._request("POST", "chat.postMessage", json_data={"channel": args.channel_id, "text": args.text}) async def post_reply(self, args: ReplyToThreadArgs) -> Dict[str, Any]: - payload = {"channel": args.channel_id, "thread_ts": args.thread_ts, "text": args.text} - return await self._request("POST", "chat.postMessage", json_data=payload) + return await self._request( + "POST", + "chat.postMessage", + json_data={"channel": args.channel_id, "thread_ts": args.thread_ts, "text": args.text}, + ) async def add_reaction(self, args: AddReactionArgs) -> Dict[str, Any]: - payload = {"channel": args.channel_id, "timestamp": args.timestamp, "name": args.reaction} - return await self._request("POST", "reactions.add", json_data=payload) + return await self._request( + "POST", + "reactions.add", + json_data={"channel": args.channel_id, "timestamp": args.timestamp, "name": args.reaction}, + ) async def get_thread_replies(self, args: GetThreadRepliesArgs) -> Dict[str, Any]: - params = {"channel": args.channel_id, "ts": args.thread_ts} - return await self._request("GET", "conversations.replies", params=params) + return await self._request("GET", "conversations.replies", params={"channel": args.channel_id, "ts": args.thread_ts}) async def get_users(self, args: GetUsersArgs) -> Dict[str, Any]: - params = { - "limit": min(args.limit, 200), - "team_id": self.team_id, - } + params = {"limit": min(args.limit, 200), "team_id": self.team_id} if args.cursor: params["cursor"] = args.cursor return await self._request("GET", "users.list", params=params) async def get_user_profile(self, args: GetUserProfileArgs) -> Dict[str, Any]: - params = {"user": args.user_id, "include_labels": "true"} - return await self._request("GET", "users.profile.get", params=params) + return await self._request("GET", "users.profile.get", params={"user": args.user_id, "include_labels": "true"}) -# --- Instantiate Slack Client --- + # ---------------- lifecycle ---------------- # + async def aclose(self) -> None: # call on app shutdown + await self._client.aclose() + + +# --------------------------------------------------------------------------- +# Instantiate Slack client +# --------------------------------------------------------------------------- slack_client = SlackClient(token=SLACK_BOT_TOKEN, team_id=SLACK_TEAM_ID) -# --- Tool Definitions & Endpoint Generation --- + +# --------------------------------------------------------------------------- +# Dynamic tool mapping / endpoint generation +# --------------------------------------------------------------------------- TOOL_MAPPING = { "slack_list_channels": { "args_model": ListChannelsArgs, @@ -314,40 +335,45 @@ TOOL_MAPPING = { }, } -# Define a function factory to create endpoint handlers, including API key dependency + +# ---------------- endpoint factory ---------------- # + def create_endpoint_handler(tool_name: str, method: Callable, args_model: Type[BaseModel]): - async def endpoint_handler( - args: args_model = Body(...), - api_key: str = Depends(get_api_key) # Add API key dependency here - ) -> ToolResponse: + async def handler(args: args_model = Body(...), api_key: str = Depends(get_api_key)) -> ToolResponse: # noqa: ANN001 try: result = await method(args=args) return {"content": result} - except HTTPException as e: - raise e - except Exception as e: - logger.exception(f"Error executing tool {tool_name}: {e}") # Use logger.exception here too - raise HTTPException(status_code=500, detail=f"Internal server error: {type(e).__name__}") - return endpoint_handler + except HTTPException: + raise # re‑raise untouched + except Exception as exc: # noqa: BLE001 + logger.exception("Error executing tool %s: %s", tool_name, exc) + raise HTTPException(status_code=500, detail=f"Internal server error: {type(exc).__name__}") -# Register endpoints for each tool -for tool_name, config in TOOL_MAPPING.items(): - handler = create_endpoint_handler( - tool_name=tool_name, - method=config["method"], - args_model=config["args_model"] - ) + return handler + +for name, cfg in TOOL_MAPPING.items(): app.post( - f"/{tool_name}", + f"/{name}", response_model=ToolResponse, - summary=config["description"], - description=f"Executes the {tool_name} tool. Arguments are passed in the request body.", + summary=cfg["description"], + description=f"Executes the {name} tool. Arguments are passed in the request body.", tags=["Slack Tools"], - name=tool_name - )(handler) + name=name, + )(create_endpoint_handler(name, cfg["method"], cfg["args_model"])) -# --- Root Endpoint --- + +# --------------------------------------------------------------------------- +# Lifecycle events +# --------------------------------------------------------------------------- +@app.on_event("shutdown") +async def _close_slack_client(): + await slack_client.aclose() + + +# --------------------------------------------------------------------------- +# Root endpoint +# --------------------------------------------------------------------------- @app.get("/", summary="Root endpoint", include_in_schema=False) async def read_root(): return {"message": "Slack API Server is running. See /docs for available tool endpoints."}