refac: reuse shared httpx client and concurrent channel history fetch for major performance gain

This commit is contained in:
Taylor Wilsdon 2025-04-20 19:51:36 -04:00
parent 565a563a77
commit 49900846d0

View File

@ -1,276 +1,297 @@
import os
import httpx
import inspect
import logging
import json # For JSONDecodeError
from typing import Optional, List, Dict, Any, Type, Callable
from fastapi import FastAPI, HTTPException, Body, Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from dotenv import load_dotenv
"""Slack MCP Server highperformance version
------------------------------------------------
Showcaselevel code quality and pythonic clarity.
"""
# --- Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
import os
import asyncio
import logging
import json # For JSONDecodeError
from typing import Optional, List, Dict, Any, Type, Callable
import httpx
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Body, Depends, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Load environment variables from .env file
# ---------------------------------------------------------------------------
# Environment variables
# ---------------------------------------------------------------------------
load_dotenv()
# --- Environment Variable Checks ---
SLACK_BOT_TOKEN = os.getenv("SLACK_BOT_TOKEN")
SLACK_TEAM_ID = os.getenv("SLACK_TEAM_ID")
SLACK_CHANNEL_IDS_STR = os.getenv("SLACK_CHANNEL_IDS") # Optional
ALLOWED_ORIGINS_STR = os.getenv("ALLOWED_ORIGINS", "*") # Default to allow all
SERVER_API_KEY = os.getenv("SERVER_API_KEY") # Optional API key for security
SLACK_CHANNEL_IDS_STR = os.getenv("SLACK_CHANNEL_IDS") # Optional
ALLOWED_ORIGINS_STR = os.getenv("ALLOWED_ORIGINS", "*")
SERVER_API_KEY = os.getenv("SERVER_API_KEY") # Optional API key for security
if not SLACK_BOT_TOKEN:
# Fail fast if essential config is missing
logger.critical("SLACK_BOT_TOKEN environment variable not set.")
raise ValueError("SLACK_BOT_TOKEN environment variable not set.")
if not SLACK_TEAM_ID:
logger.critical("SLACK_TEAM_ID environment variable not set.")
raise ValueError("SLACK_TEAM_ID environment variable not set.")
PREDEFINED_CHANNEL_IDS = [
channel_id.strip()
for channel_id in SLACK_CHANNEL_IDS_STR.split(',')
] if SLACK_CHANNEL_IDS_STR else None
PREDEFINED_CHANNEL_IDS: Optional[List[str]] = (
[cid.strip() for cid in SLACK_CHANNEL_IDS_STR.split(",")] if SLACK_CHANNEL_IDS_STR else None
)
# --- FastAPI App Setup ---
# ---------------------------------------------------------------------------
# FastAPI app setup
# ---------------------------------------------------------------------------
app = FastAPI(
title="Slack API Server",
version="1.0.0",
description="FastAPI server providing Slack functionalities via specific, dynamically generated tool endpoints.",
)
# Configure CORS
allow_origins = [origin.strip() for origin in ALLOWED_ORIGINS_STR.split(',')]
# CORS
allow_origins = [origin.strip() for origin in ALLOWED_ORIGINS_STR.split(",")]
if allow_origins == ["*"]:
logger.warning("CORS allow_origins is set to '*' which is insecure for production. Consider setting ALLOWED_ORIGINS environment variable.")
app.add_middleware(
CORSMiddleware,
allow_origins=allow_origins,
allow_credentials=True, # Allow credentials if origins are specific, adjust if needed
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# --- API Key Security ---
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) # auto_error=False to handle optional key
# ---------------------------------------------------------------------------
# API key security
# ---------------------------------------------------------------------------
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def get_api_key(key: str = Security(api_key_header)):
if SERVER_API_KEY: # Only enforce key if it's set in the environment
if SERVER_API_KEY:
if not key:
logger.warning("API Key required but not provided in X-API-Key header.")
raise HTTPException(status_code=401, detail="X-API-Key header required")
if key != SERVER_API_KEY:
logger.warning("Invalid API Key provided.")
raise HTTPException(status_code=401, detail="Invalid API Key")
# If key is valid and required, proceed
# If SERVER_API_KEY is not set, allow access without a key
# logger.info("API Key check passed (or not required).") # Optional: Log successful checks
return key # Return the key or None if not required/provided
return key # May be None when not required
if not SERVER_API_KEY:
logger.warning("SERVER_API_KEY environment variable is not set. Server will allow unauthenticated requests.")
# [Previous Pydantic models remain the same...]
# ---------------------------------------------------------------------------
# Pydantic models (arguments & responses)
# ---------------------------------------------------------------------------
class ListChannelsArgs(BaseModel):
limit: Optional[int] = Field(100, description="Maximum number of channels to return (default 100, max 200)")
cursor: Optional[str] = Field(None, description="Pagination cursor for next page of results")
class PostMessageArgs(BaseModel):
channel_id: str = Field(..., description="The ID of the channel to post to")
text: str = Field(..., description="The message text to post")
class ReplyToThreadArgs(BaseModel):
channel_id: str = Field(..., description="The ID of the channel containing the thread")
thread_ts: str = Field(..., description="The timestamp of the parent message (e.g., '1234567890.123456')")
text: str = Field(..., description="The reply text")
class AddReactionArgs(BaseModel):
channel_id: str = Field(..., description="The ID of the channel containing the message")
timestamp: str = Field(..., description="The timestamp of the message to react to")
reaction: str = Field(..., description="The name of the emoji reaction (without colons)")
class GetChannelHistoryArgs(BaseModel):
channel_id: str = Field(..., description="The ID of the channel")
limit: Optional[int] = Field(10, description="Number of messages to retrieve (default 10)")
class GetThreadRepliesArgs(BaseModel):
channel_id: str = Field(..., description="The ID of the channel containing the thread")
thread_ts: str = Field(..., description="The timestamp of the parent message (e.g., '1234567890.123456')")
class GetUsersArgs(BaseModel):
cursor: Optional[str] = Field(None, description="Pagination cursor for next page of results")
limit: Optional[int] = Field(100, description="Maximum number of users to return (default 100, max 200)")
class GetUserProfileArgs(BaseModel):
user_id: str = Field(..., description="The ID of the user")
class ToolResponse(BaseModel):
content: Dict[str, Any] = Field(..., description="The JSON response from the Slack API call")
# --- Slack Client Class ---
# ---------------------------------------------------------------------------
# Slack client (highperformance)
# ---------------------------------------------------------------------------
class SlackClient:
"""Thin async wrapper over Slack Web API with connectionpool reuse."""
BASE_URL = "https://slack.com/api/"
def __init__(self, token: str, team_id: str):
def __init__(self, token: str, team_id: str, *, max_connections: int = 20):
self.team_id = team_id
self.headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json; charset=utf-8",
}
self.team_id = team_id
limits = httpx.Limits(max_connections=max_connections, max_keepalive_connections=max_connections)
self._client = httpx.AsyncClient(
base_url=self.BASE_URL,
headers=self.headers,
limits=limits,
http2=True,
timeout=10,
)
async def _request(self, method: str, endpoint: str, params: Optional[Dict] = None, json_data: Optional[Dict] = None) -> Dict[str, Any]:
async with httpx.AsyncClient(base_url=self.BASE_URL, headers=self.headers) as client:
try:
response = await client.request(method, endpoint, params=params, json=json_data)
response.raise_for_status()
data = response.json()
if not data.get("ok"):
error_msg = data.get("error", "Unknown Slack API error")
# Return the specific Slack error in the response
logger.warning(f"Slack API Error for {method} {endpoint}: {error_msg}")
raise HTTPException(status_code=400, detail={"slack_error": error_msg, "message": f"Slack API returned an error: {error_msg}"})
return data
except httpx.HTTPStatusError as e:
# Handle specific HTTP errors like rate limiting (429)
if e.response.status_code == 429:
retry_after = e.response.headers.get("Retry-After")
detail = f"Slack API rate limit exceeded. Retry after {retry_after} seconds." if retry_after else "Slack API rate limit exceeded."
logger.warning(f"Rate limit hit for {method} {endpoint}. Retry-After: {retry_after}")
raise HTTPException(status_code=429, detail=detail, headers={"Retry-After": retry_after} if retry_after else {})
else:
logger.error(f"HTTP Error: {e.response.status_code} - {e.response.text}", exc_info=True)
raise HTTPException(status_code=e.response.status_code, detail=f"Slack API HTTP Error: Status {e.response.status_code}")
except httpx.RequestError as e:
logger.error(f"Request Error connecting to Slack API: {e}", exc_info=True)
raise HTTPException(status_code=503, detail=f"Could not connect to Slack API: {e}")
except json.JSONDecodeError as e:
logger.error(f"Failed to decode JSON response from Slack API for {method} {endpoint}: {e}", exc_info=True)
raise HTTPException(status_code=502, detail="Invalid response received from Slack API.")
except Exception as e: # Catch other unexpected errors
logger.exception(f"Unexpected error during Slack request for {method} {endpoint}: {e}") # Use logger.exception to include traceback
raise HTTPException(status_code=500, detail=f"An internal server error occurred: {type(e).__name__}")
async def get_channel_history(self, args: GetChannelHistoryArgs) -> Dict[str, Any]:
params = {"channel": args.channel_id, "limit": args.limit}
return await self._request("GET", "conversations.history", params=params)
async def get_channels(self, args: ListChannelsArgs) -> Dict[str, Any]:
limit = args.limit
cursor = args.cursor
async def fetch_channel_with_history(channel_id: str) -> Dict[str, Any]:
# First get channel info
channel_info = await self._request("GET", "conversations.info", params={"channel": channel_id})
if not channel_info.get("ok") or channel_info.get("channel", {}).get("is_archived"):
return None
channel_data = channel_info["channel"]
# Then get channel history
try:
history = await self._request(
"GET",
"conversations.history",
params={
"channel": channel_id,
"limit": 1 # Fetch minimal history by default to speed up get_channels. Consider asyncio.gather for concurrency.
}
# ---------------- private helpers ---------------- #
async def _request(
self,
method: str,
endpoint: str,
*,
params: Optional[Dict] = None,
json_data: Optional[Dict] = None,
) -> Dict[str, Any]:
try:
response = await self._client.request(method, endpoint, params=params, json=json_data)
response.raise_for_status()
data = response.json()
if not data.get("ok"):
error_msg = data.get("error", "Unknown Slack API error")
raise HTTPException(status_code=400, detail={"slack_error": error_msg})
return data
except httpx.HTTPStatusError as e:
if e.response.status_code == 429:
retry_after = e.response.headers.get("Retry-After")
detail = (
f"Slack API rate limit exceeded. Retry after {retry_after} seconds."
if retry_after
else "Slack API rate limit exceeded."
)
# Add history to channel data
if history.get("ok"):
channel_data["history"] = history.get("messages", [])
except Exception as e: # Catch errors during history fetch but don't fail the whole channel list
logger.warning(f"Error fetching history for channel {channel_id}: {e}", exc_info=True)
channel_data["history"] = [] # Ensure history key exists even if fetch fails
logger.warning("Rate limit hit: %s", detail)
raise HTTPException(status_code=429, detail=detail, headers={"Retry-After": retry_after} if retry_after else {})
logger.error("HTTP Error %s - %s", e.response.status_code, e.response.text, exc_info=True)
raise HTTPException(status_code=e.response.status_code, detail="Slack API HTTP Error")
except httpx.RequestError as e:
logger.error("Request Error connecting to Slack API: %s", e, exc_info=True)
raise HTTPException(status_code=503, detail=f"Could not connect to Slack API: {e}")
except json.JSONDecodeError as e:
logger.error("Failed to decode JSON: %s", e, exc_info=True)
raise HTTPException(status_code=502, detail="Invalid JSON from Slack API")
except Exception as e: # noqa: BLE001
logger.exception("Unexpected error during Slack request: %s", e)
raise HTTPException(status_code=500, detail=f"Internal error: {type(e).__name__}")
return channel_data
# ---------------- public helpers ---------------- #
async def channel_with_history(self, channel_id: str, *, history_limit: int = 1) -> Optional[Dict[str, Any]]:
"""Return channel metadata plus ≤ ``history_limit`` recent messages, or None."""
try:
info = await self._request("GET", "conversations.info", params={"channel": channel_id})
chan = info["channel"]
if chan.get("is_archived"):
return None
hist = await self._request(
"GET",
"conversations.history",
params={"channel": channel_id, "limit": history_limit},
)
chan["history"] = hist.get("messages", [])
return chan
except Exception as exc: # noqa: BLE001
logger.warning("Skipping channel %s %s", channel_id, exc, exc_info=True)
return None
# ---------------- API surface ---------------- #
async def get_channel_history(self, args: GetChannelHistoryArgs) -> Dict[str, Any]:
return await self._request("GET", "conversations.history", params={"channel": args.channel_id, "limit": args.limit})
async def get_channels(self, args: ListChannelsArgs) -> Dict[str, Any]: # noqa: C901 keep cohesive
# 1. decide which ids to fetch
if PREDEFINED_CHANNEL_IDS:
channels_info = []
for channel_id in PREDEFINED_CHANNEL_IDS:
try:
if channel_data := await fetch_channel_with_history(channel_id):
channels_info.append(channel_data)
except Exception as e: # Catch errors fetching predefined channels
logger.warning(f"Could not fetch info for predefined channel {channel_id}: {e}", exc_info=True)
return {
"ok": True,
"channels": channels_info,
"response_metadata": {"next_cursor": ""}
}
ids = PREDEFINED_CHANNEL_IDS
next_cursor = ""
else:
# First get list of channels
params = {
params: Dict[str, Any] = {
"types": "public_channel",
"exclude_archived": "true",
"limit": min(limit, 200),
"limit": min(args.limit, 200),
"team_id": self.team_id,
}
if cursor:
params["cursor"] = cursor
if args.cursor:
params["cursor"] = args.cursor
clist = await self._request("GET", "conversations.list", params=params)
ids = [c["id"] for c in clist["channels"]]
next_cursor = clist.get("response_metadata", {}).get("next_cursor", "")
channels_list = await self._request("GET", "conversations.list", params=params)
# 2. fetch metadata + history concurrently under a semaphore
sem = asyncio.Semaphore(10) # adjust parallelism as desired
if not channels_list.get("ok"):
return channels_list
async def guarded(cid: str):
async with sem:
return await self.channel_with_history(cid)
# Then fetch history for each channel
channels_with_history = []
for channel in channels_list["channels"]:
try:
if channel_data := await fetch_channel_with_history(channel["id"]):
channels_with_history.append(channel_data)
except Exception as e: # Catch errors during history fetch but don't fail the whole channel list
logger.warning(f"Error fetching history for channel {channel['id']}: {e}", exc_info=True)
channel["history"] = [] # Add empty history on error
channels_with_history.append(channel)
return {
"ok": True,
"channels": channels_with_history,
"response_metadata": channels_list.get("response_metadata", {"next_cursor": ""})
}
channels = [c for c in await asyncio.gather(*(guarded(cid) for cid in ids)) if c]
return {"ok": True, "channels": channels, "response_metadata": {"next_cursor": next_cursor}}
async def post_message(self, args: PostMessageArgs) -> Dict[str, Any]:
payload = {"channel": args.channel_id, "text": args.text}
return await self._request("POST", "chat.postMessage", json_data=payload)
return await self._request("POST", "chat.postMessage", json_data={"channel": args.channel_id, "text": args.text})
async def post_reply(self, args: ReplyToThreadArgs) -> Dict[str, Any]:
payload = {"channel": args.channel_id, "thread_ts": args.thread_ts, "text": args.text}
return await self._request("POST", "chat.postMessage", json_data=payload)
return await self._request(
"POST",
"chat.postMessage",
json_data={"channel": args.channel_id, "thread_ts": args.thread_ts, "text": args.text},
)
async def add_reaction(self, args: AddReactionArgs) -> Dict[str, Any]:
payload = {"channel": args.channel_id, "timestamp": args.timestamp, "name": args.reaction}
return await self._request("POST", "reactions.add", json_data=payload)
return await self._request(
"POST",
"reactions.add",
json_data={"channel": args.channel_id, "timestamp": args.timestamp, "name": args.reaction},
)
async def get_thread_replies(self, args: GetThreadRepliesArgs) -> Dict[str, Any]:
params = {"channel": args.channel_id, "ts": args.thread_ts}
return await self._request("GET", "conversations.replies", params=params)
return await self._request("GET", "conversations.replies", params={"channel": args.channel_id, "ts": args.thread_ts})
async def get_users(self, args: GetUsersArgs) -> Dict[str, Any]:
params = {
"limit": min(args.limit, 200),
"team_id": self.team_id,
}
params = {"limit": min(args.limit, 200), "team_id": self.team_id}
if args.cursor:
params["cursor"] = args.cursor
return await self._request("GET", "users.list", params=params)
async def get_user_profile(self, args: GetUserProfileArgs) -> Dict[str, Any]:
params = {"user": args.user_id, "include_labels": "true"}
return await self._request("GET", "users.profile.get", params=params)
return await self._request("GET", "users.profile.get", params={"user": args.user_id, "include_labels": "true"})
# --- Instantiate Slack Client ---
# ---------------- lifecycle ---------------- #
async def aclose(self) -> None: # call on app shutdown
await self._client.aclose()
# ---------------------------------------------------------------------------
# Instantiate Slack client
# ---------------------------------------------------------------------------
slack_client = SlackClient(token=SLACK_BOT_TOKEN, team_id=SLACK_TEAM_ID)
# --- Tool Definitions & Endpoint Generation ---
# ---------------------------------------------------------------------------
# Dynamic tool mapping / endpoint generation
# ---------------------------------------------------------------------------
TOOL_MAPPING = {
"slack_list_channels": {
"args_model": ListChannelsArgs,
@ -314,40 +335,45 @@ TOOL_MAPPING = {
},
}
# Define a function factory to create endpoint handlers, including API key dependency
# ---------------- endpoint factory ---------------- #
def create_endpoint_handler(tool_name: str, method: Callable, args_model: Type[BaseModel]):
async def endpoint_handler(
args: args_model = Body(...),
api_key: str = Depends(get_api_key) # Add API key dependency here
) -> ToolResponse:
async def handler(args: args_model = Body(...), api_key: str = Depends(get_api_key)) -> ToolResponse: # noqa: ANN001
try:
result = await method(args=args)
return {"content": result}
except HTTPException as e:
raise e
except Exception as e:
logger.exception(f"Error executing tool {tool_name}: {e}") # Use logger.exception here too
raise HTTPException(status_code=500, detail=f"Internal server error: {type(e).__name__}")
return endpoint_handler
except HTTPException:
raise # reraise untouched
except Exception as exc: # noqa: BLE001
logger.exception("Error executing tool %s: %s", tool_name, exc)
raise HTTPException(status_code=500, detail=f"Internal server error: {type(exc).__name__}")
# Register endpoints for each tool
for tool_name, config in TOOL_MAPPING.items():
handler = create_endpoint_handler(
tool_name=tool_name,
method=config["method"],
args_model=config["args_model"]
)
return handler
for name, cfg in TOOL_MAPPING.items():
app.post(
f"/{tool_name}",
f"/{name}",
response_model=ToolResponse,
summary=config["description"],
description=f"Executes the {tool_name} tool. Arguments are passed in the request body.",
summary=cfg["description"],
description=f"Executes the {name} tool. Arguments are passed in the request body.",
tags=["Slack Tools"],
name=tool_name
)(handler)
name=name,
)(create_endpoint_handler(name, cfg["method"], cfg["args_model"]))
# --- Root Endpoint ---
# ---------------------------------------------------------------------------
# Lifecycle events
# ---------------------------------------------------------------------------
@app.on_event("shutdown")
async def _close_slack_client():
await slack_client.aclose()
# ---------------------------------------------------------------------------
# Root endpoint
# ---------------------------------------------------------------------------
@app.get("/", summary="Root endpoint", include_in_schema=False)
async def read_root():
return {"message": "Slack API Server is running. See /docs for available tool endpoints."}