mirror of
https://github.com/open-webui/openapi-servers
synced 2025-06-26 18:17:04 +00:00
380 lines
16 KiB
Python
380 lines
16 KiB
Python
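"""Slack MCP Server – high-performance version
------------------------------------------------
Exposes Slack Web API operations as individual FastAPI POST tool endpoints
(browsable at /docs), backed by a single pooled async HTTP client.
Showcase-level code quality and pythonic clarity.
"""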
|
||
|
||
import asyncio
import json  # For JSONDecodeError
import logging
import os
import secrets  # constant-time API key comparison
from typing import Optional, List, Dict, Any, Type, Callable

import httpx
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Body, Depends, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, Field
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Logging
|
||
# ---------------------------------------------------------------------------
|
||
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"

# Configure the root logger once at import time, then take a module-scoped logger.
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Environment variables
|
||
# ---------------------------------------------------------------------------
|
||
load_dotenv()

SLACK_BOT_TOKEN = os.getenv("SLACK_BOT_TOKEN")
SLACK_TEAM_ID = os.getenv("SLACK_TEAM_ID")
SLACK_CHANNEL_IDS_STR = os.getenv("SLACK_CHANNEL_IDS")  # Optional
ALLOWED_ORIGINS_STR = os.getenv("ALLOWED_ORIGINS", "*")
SERVER_API_KEY = os.getenv("SERVER_API_KEY")  # Optional API key for security

# Fail fast at import time: both values are required for every Slack call.
if not SLACK_BOT_TOKEN:
    logger.critical("SLACK_BOT_TOKEN environment variable not set.")
    raise ValueError("SLACK_BOT_TOKEN environment variable not set.")
if not SLACK_TEAM_ID:
    logger.critical("SLACK_TEAM_ID environment variable not set.")
    raise ValueError("SLACK_TEAM_ID environment variable not set.")


def _split_csv(raw: Optional[str]) -> List[str]:
    """Split a comma-separated env value into stripped, non-empty items.

    ``None``, ``""`` and separator-only values such as ``" , "`` all yield ``[]``,
    so a trailing comma never produces an empty ID.
    """
    if not raw:
        return []
    return [item.strip() for item in raw.split(",") if item.strip()]


# Optional comma-separated list of channel IDs. When it contains at least one
# non-blank ID, slack_list_channels fetches exactly these channels instead of
# paging through conversations.list; ``None`` means "not configured".
PREDEFINED_CHANNEL_IDS: Optional[List[str]] = _split_csv(SLACK_CHANNEL_IDS_STR) or None
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# FastAPI app setup
|
||
# ---------------------------------------------------------------------------
|
||
app = FastAPI(
    title="Slack API Server",
    version="1.0.0",
    description="FastAPI server providing Slack functionalities via specific, dynamically generated tool endpoints.",
)

# CORS: ALLOWED_ORIGINS is a comma-separated origin list; the default is "*".
allow_origins = list(map(str.strip, ALLOWED_ORIGINS_STR.split(",")))
_allows_any_origin = allow_origins == ["*"]
if _allows_any_origin:
    logger.warning("CORS allow_origins is set to '*' which is insecure for production. Consider setting ALLOWED_ORIGINS environment variable.")

# NOTE(review): allow_credentials=True combined with the "*" default is very
# permissive for browser callers — confirm this is intended.
_cors_options = dict(
    allow_origins=allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# API key security
|
||
# ---------------------------------------------------------------------------
|
||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)


async def get_api_key(key: Optional[str] = Security(api_key_header)) -> Optional[str]:
    """FastAPI dependency enforcing the optional ``SERVER_API_KEY``.

    When ``SERVER_API_KEY`` is unset, every request is allowed and the header
    value (possibly ``None``) is returned as-is. Otherwise the ``X-API-Key``
    header must be present and must match.

    Raises:
        HTTPException: 401 when a key is required but missing or wrong.
    """
    if not SERVER_API_KEY:
        return key  # May be None when not required
    if not key:
        logger.warning("API Key required but not provided in X-API-Key header.")
        raise HTTPException(status_code=401, detail="X-API-Key header required")
    # Constant-time comparison so response timing does not reveal how much of a
    # guessed key was correct. Compare bytes: compare_digest rejects non-ASCII str.
    if not secrets.compare_digest(key.encode("utf-8"), SERVER_API_KEY.encode("utf-8")):
        logger.warning("Invalid API Key provided.")
        raise HTTPException(status_code=401, detail="Invalid API Key")
    return key


if not SERVER_API_KEY:
    logger.warning("SERVER_API_KEY environment variable is not set. Server will allow unauthenticated requests.")
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Pydantic models (arguments & responses)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class ListChannelsArgs(BaseModel):
    # Request body for the slack_list_channels tool.
    # ``limit`` must be a positive integer. An explicit null (or a value < 1) is
    # now rejected at validation time instead of reaching ``min(args.limit, 200)``
    # in the client, where null raised TypeError. Values above 200 are still
    # accepted and clamped by the client.
    limit: int = Field(100, ge=1, description="Maximum number of channels to return (default 100, max 200)")
    cursor: Optional[str] = Field(None, description="Pagination cursor for next page of results")
|
||
|
||
|
||
class PostMessageArgs(BaseModel):
    # Request body for slack_post_message; both fields are required.
    channel_id: str = Field(description="The ID of the channel to post to")
    text: str = Field(description="The message text to post")
|
||
|
||
|
||
class ReplyToThreadArgs(BaseModel):
    # Request body for slack_reply_to_thread; all fields are required.
    channel_id: str = Field(description="The ID of the channel containing the thread")
    # Slack message timestamp of the thread's parent, sent as ``thread_ts``.
    thread_ts: str = Field(description="The timestamp of the parent message (e.g., '1234567890.123456')")
    text: str = Field(description="The reply text")
|
||
|
||
|
||
class AddReactionArgs(BaseModel):
    # Request body for slack_add_reaction; all fields are required.
    channel_id: str = Field(description="The ID of the channel containing the message")
    timestamp: str = Field(description="The timestamp of the message to react to")
    # Forwarded to reactions.add as ``name``.
    reaction: str = Field(description="The name of the emoji reaction (without colons)")
|
||
|
||
|
||
class GetChannelHistoryArgs(BaseModel):
    # Request body for slack_get_channel_history.
    channel_id: str = Field(description="The ID of the channel")
    # Forwarded unchanged as the ``limit`` query parameter of conversations.history.
    limit: Optional[int] = Field(default=10, description="Number of messages to retrieve (default 10)")
|
||
|
||
|
||
class GetThreadRepliesArgs(BaseModel):
    # Request body for slack_get_thread_replies; both fields are required.
    channel_id: str = Field(description="The ID of the channel containing the thread")
    # Sent to conversations.replies as ``ts``.
    thread_ts: str = Field(description="The timestamp of the parent message (e.g., '1234567890.123456')")
|
||
|
||
|
||
class GetUsersArgs(BaseModel):
    # Request body for the slack_get_users tool.
    cursor: Optional[str] = Field(None, description="Pagination cursor for next page of results")
    # Must be a positive integer. An explicit null (or a value < 1) is rejected
    # at validation time instead of reaching ``min(args.limit, 200)`` in the
    # client, where null raised TypeError. Values above 200 are still clamped
    # by the client.
    limit: int = Field(100, ge=1, description="Maximum number of users to return (default 100, max 200)")
|
||
|
||
|
||
class GetUserProfileArgs(BaseModel):
    # Request body for slack_get_user_profile; sent to users.profile.get as ``user``.
    user_id: str = Field(description="The ID of the user")
|
||
|
||
|
||
class ToolResponse(BaseModel):
    # Envelope returned by every tool endpoint: the raw Slack payload under ``content``.
    content: Dict[str, Any] = Field(description="The JSON response from the Slack API call")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Slack client (high‑performance)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class SlackClient:
    """Thin async wrapper over the Slack Web API with connection-pool reuse.

    One ``httpx.AsyncClient`` is created per instance and shared by every call.
    Failures are reported by raising ``fastapi.HTTPException`` so route handlers
    can pass them straight through to the HTTP caller.
    """

    BASE_URL = "https://slack.com/api/"

    def __init__(self, token: str, team_id: str, *, max_connections: int = 20):
        """Create the pooled HTTP client.

        Args:
            token: Slack bot token, sent as ``Authorization: Bearer <token>``.
            team_id: Workspace ID passed to ``conversations.list`` and ``users.list``.
            max_connections: Pool size; also used as the keep-alive connection cap.
        """
        self.team_id = team_id
        self.headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json; charset=utf-8",
        }
        limits = httpx.Limits(max_connections=max_connections, max_keepalive_connections=max_connections)
        # NOTE(review): http2=True relies on the optional ``h2`` package being
        # installed alongside httpx — confirm it is in the deployment requirements.
        self._client = httpx.AsyncClient(
            base_url=self.BASE_URL,
            headers=self.headers,
            limits=limits,
            http2=True,
            timeout=10,
        )

    # ---------------- private helpers ---------------- #
    @staticmethod
    def _page_limit(limit: Optional[int], default: int = 100, maximum: int = 200) -> int:
        """Return ``limit`` capped at ``maximum``, using ``default`` when it is ``None``.

        Guards ``min()`` against an explicit null limit, which would otherwise
        raise ``TypeError``.
        """
        return min(default if limit is None else limit, maximum)

    async def _request(
        self,
        method: str,
        endpoint: str,
        *,
        params: Optional[Dict[str, Any]] = None,
        json_data: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Call a Slack Web API ``endpoint`` and return its JSON body when ``ok``.

        Raises:
            HTTPException:
                - 400 with ``{"slack_error": ...}`` when Slack answers ``ok: false``.
                - 429 (forwarding ``Retry-After`` when present) when rate limited.
                - The upstream status code for any other HTTP error.
                - 503 when Slack cannot be reached.
                - 502 when the body is not valid JSON.
                - 500 for anything unexpected.
        """
        try:
            response = await self._client.request(method, endpoint, params=params, json=json_data)
            response.raise_for_status()
            data = response.json()
            if not data.get("ok"):
                error_msg = data.get("error", "Unknown Slack API error")
                raise HTTPException(status_code=400, detail={"slack_error": error_msg})
            return data
        except HTTPException:
            # The Slack-error response raised just above: re-raise it untouched.
            # Without this clause the catch-all below turned it into a generic 500.
            raise
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 429:
                retry_after = e.response.headers.get("Retry-After")
                detail = (
                    f"Slack API rate limit exceeded. Retry after {retry_after} seconds."
                    if retry_after
                    else "Slack API rate limit exceeded."
                )
                logger.warning("Rate limit hit: %s", detail)
                raise HTTPException(status_code=429, detail=detail, headers={"Retry-After": retry_after} if retry_after else {})
            logger.error("HTTP Error %s - %s", e.response.status_code, e.response.text, exc_info=True)
            raise HTTPException(status_code=e.response.status_code, detail="Slack API HTTP Error")
        except httpx.RequestError as e:
            logger.error("Request Error connecting to Slack API: %s", e, exc_info=True)
            raise HTTPException(status_code=503, detail=f"Could not connect to Slack API: {e}")
        except json.JSONDecodeError as e:
            logger.error("Failed to decode JSON: %s", e, exc_info=True)
            raise HTTPException(status_code=502, detail="Invalid JSON from Slack API")
        except Exception as e:  # noqa: BLE001
            logger.exception("Unexpected error during Slack request: %s", e)
            raise HTTPException(status_code=500, detail=f"Internal error: {type(e).__name__}")

    # ---------------- public helpers ---------------- #
    async def channel_with_history(self, channel_id: str, *, history_limit: int = 1) -> Optional[Dict[str, Any]]:
        """Return channel metadata plus up to ``history_limit`` recent messages, or None.

        The messages are stored under the ``"history"`` key. ``None`` is
        returned for archived channels and for any lookup failure; failures are
        logged as warnings rather than raised (best-effort enrichment).
        """
        try:
            info = await self._request("GET", "conversations.info", params={"channel": channel_id})
            chan = info["channel"]
            if chan.get("is_archived"):
                return None
            hist = await self._request(
                "GET",
                "conversations.history",
                params={"channel": channel_id, "limit": history_limit},
            )
            chan["history"] = hist.get("messages", [])
            return chan
        except Exception as exc:  # noqa: BLE001
            logger.warning("Skipping channel %s – %s", channel_id, exc, exc_info=True)
            return None

    # ---------------- API surface ---------------- #
    async def get_channel_history(self, args: GetChannelHistoryArgs) -> Dict[str, Any]:
        """Fetch recent messages of ``args.channel_id`` via ``conversations.history``."""
        return await self._request("GET", "conversations.history", params={"channel": args.channel_id, "limit": args.limit})

    async def get_channels(self, args: ListChannelsArgs) -> Dict[str, Any]:  # noqa: C901 – keep cohesive
        """List channels, each enriched with up to one recent message.

        If ``PREDEFINED_CHANNEL_IDS`` is configured, those IDs are used and
        ``limit``/``cursor`` are ignored (the returned next cursor is ``""``).
        Otherwise one page of non-archived public channels is read from
        ``conversations.list`` (limit capped at 200). Channels that are archived
        or fail to load are dropped from the result.
        """
        # 1. decide which ids to fetch
        if PREDEFINED_CHANNEL_IDS:
            ids = PREDEFINED_CHANNEL_IDS
            next_cursor = ""
        else:
            params: Dict[str, Any] = {
                "types": "public_channel",
                "exclude_archived": "true",
                "limit": self._page_limit(args.limit),
                "team_id": self.team_id,
            }
            if args.cursor:
                params["cursor"] = args.cursor
            clist = await self._request("GET", "conversations.list", params=params)
            ids = [c["id"] for c in clist["channels"]]
            next_cursor = clist.get("response_metadata", {}).get("next_cursor", "")

        # 2. fetch metadata + history concurrently, at most 10 channels in flight
        sem = asyncio.Semaphore(10)  # adjust parallelism as desired

        async def guarded(cid: str):
            async with sem:
                return await self.channel_with_history(cid)

        channels = [c for c in await asyncio.gather(*(guarded(cid) for cid in ids)) if c]
        return {"ok": True, "channels": channels, "response_metadata": {"next_cursor": next_cursor}}

    async def post_message(self, args: PostMessageArgs) -> Dict[str, Any]:
        """Post ``args.text`` to ``args.channel_id`` via ``chat.postMessage``."""
        return await self._request("POST", "chat.postMessage", json_data={"channel": args.channel_id, "text": args.text})

    async def post_reply(self, args: ReplyToThreadArgs) -> Dict[str, Any]:
        """Post ``args.text`` as a reply in the thread rooted at ``args.thread_ts``."""
        return await self._request(
            "POST",
            "chat.postMessage",
            json_data={"channel": args.channel_id, "thread_ts": args.thread_ts, "text": args.text},
        )

    async def add_reaction(self, args: AddReactionArgs) -> Dict[str, Any]:
        """Add the emoji ``args.reaction`` to the message at ``args.timestamp``."""
        return await self._request(
            "POST",
            "reactions.add",
            json_data={"channel": args.channel_id, "timestamp": args.timestamp, "name": args.reaction},
        )

    async def get_thread_replies(self, args: GetThreadRepliesArgs) -> Dict[str, Any]:
        """Fetch the replies of the thread ``args.thread_ts`` via ``conversations.replies``."""
        return await self._request("GET", "conversations.replies", params={"channel": args.channel_id, "ts": args.thread_ts})

    async def get_users(self, args: GetUsersArgs) -> Dict[str, Any]:
        """Fetch one page of workspace users via ``users.list`` (limit capped at 200)."""
        params: Dict[str, Any] = {"limit": self._page_limit(args.limit), "team_id": self.team_id}
        if args.cursor:
            params["cursor"] = args.cursor
        return await self._request("GET", "users.list", params=params)

    async def get_user_profile(self, args: GetUserProfileArgs) -> Dict[str, Any]:
        """Fetch the profile of ``args.user_id`` via ``users.profile.get`` (labels included)."""
        return await self._request("GET", "users.profile.get", params={"user": args.user_id, "include_labels": "true"})

    # ---------------- lifecycle ---------------- #
    async def aclose(self) -> None:  # call on app shutdown
        """Close the pooled HTTP client."""
        await self._client.aclose()
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Instantiate Slack client
|
||
# ---------------------------------------------------------------------------
|
||
# One shared client (and therefore one connection pool) for every tool endpoint.
slack_client = SlackClient(
    token=SLACK_BOT_TOKEN,
    team_id=SLACK_TEAM_ID,
)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Dynamic tool mapping / endpoint generation
|
||
# ---------------------------------------------------------------------------
|
||
# (tool name, argument model, bound SlackClient coroutine, OpenAPI summary).
# Each tool name becomes one POST /<tool name> route; table order is route order.
_TOOL_TABLE = (
    ("slack_list_channels", ListChannelsArgs, slack_client.get_channels,
     "List public or pre-defined channels in the workspace with pagination"),
    ("slack_post_message", PostMessageArgs, slack_client.post_message,
     "Post a new message to a Slack channel"),
    ("slack_reply_to_thread", ReplyToThreadArgs, slack_client.post_reply,
     "Reply to a specific message thread in Slack"),
    ("slack_add_reaction", AddReactionArgs, slack_client.add_reaction,
     "Add a reaction emoji to a message"),
    ("slack_get_channel_history", GetChannelHistoryArgs, slack_client.get_channel_history,
     "Get recent messages from a channel"),
    ("slack_get_thread_replies", GetThreadRepliesArgs, slack_client.get_thread_replies,
     "Get all replies in a message thread"),
    ("slack_get_users", GetUsersArgs, slack_client.get_users,
     "Get a list of all users in the workspace with their basic profile information"),
    ("slack_get_user_profile", GetUserProfileArgs, slack_client.get_user_profile,
     "Get detailed profile information for a specific user"),
)

TOOL_MAPPING = {
    tool: {"args_model": model, "method": coroutine, "description": summary}
    for tool, model, coroutine, summary in _TOOL_TABLE
}
|
||
|
||
|
||
# ---------------- endpoint factory ---------------- #
|
||
|
||
def create_endpoint_handler(tool_name: str, method: Callable, args_model: Type[BaseModel]):
    """Build the FastAPI route coroutine for a single tool.

    The returned handler takes the tool arguments as the JSON request body
    (validated against ``args_model``) and depends on ``get_api_key``. It awaits
    ``method(args=...)`` and wraps the result as ``{"content": result}``.
    ``HTTPException`` propagates unchanged; any other error is logged and
    reported as a 500.
    """

    async def handler(args: args_model = Body(...), api_key: str = Depends(get_api_key)) -> ToolResponse:  # noqa: ANN001
        # ``args_model`` is captured from the enclosing call, so each generated
        # route gets its own request-body schema.
        try:
            payload = await method(args=args)
        except HTTPException:
            raise  # already carries the intended status code and detail
        except Exception as exc:  # noqa: BLE001
            logger.exception("Error executing tool %s: %s", tool_name, exc)
            raise HTTPException(status_code=500, detail=f"Internal server error: {type(exc).__name__}")
        return {"content": payload}

    return handler
|
||
|
||
|
||
# Register one POST route per tool, e.g. POST /slack_post_message.
for tool_name, spec in TOOL_MAPPING.items():
    app.add_api_route(
        f"/{tool_name}",
        create_endpoint_handler(tool_name, spec["method"], spec["args_model"]),
        methods=["POST"],
        response_model=ToolResponse,
        summary=spec["description"],
        description=f"Executes the {tool_name} tool. Arguments are passed in the request body.",
        tags=["Slack Tools"],
        name=tool_name,
    )
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Lifecycle events
|
||
# ---------------------------------------------------------------------------
|
||
async def _close_slack_client() -> None:
    """Release the shared Slack client's pooled HTTP connections on shutdown."""
    await slack_client.aclose()


# Same registration the ``@app.on_event("shutdown")`` decorator performs.
app.router.add_event_handler("shutdown", _close_slack_client)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Root endpoint
|
||
# ---------------------------------------------------------------------------
|
||
@app.get("/", summary="Root endpoint", include_in_schema=False)
async def read_root():
    """Liveness message pointing callers at the interactive docs."""
    banner = "Slack API Server is running. See /docs for available tool endpoints."
    return {"message": banner}
|