mirror of
https://github.com/open-webui/openapi-servers
synced 2025-06-26 18:17:04 +00:00
refac: reuse shared httpx client and concurrent channel history fetch for major performance gain
This commit is contained in:
parent
565a563a77
commit
49900846d0
@ -1,276 +1,297 @@
|
|||||||
|
"""Slack MCP Server – high‑performance version
|
||||||
|
------------------------------------------------
|
||||||
|
Showcase‑level code quality and pythonic clarity.
|
||||||
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import httpx
|
import asyncio
|
||||||
import inspect
|
|
||||||
import logging
|
import logging
|
||||||
import json # For JSONDecodeError
|
import json # For JSONDecodeError
|
||||||
from typing import Optional, List, Dict, Any, Type, Callable
|
from typing import Optional, List, Dict, Any, Type, Callable
|
||||||
from fastapi import FastAPI, HTTPException, Body, Depends, Security
|
|
||||||
from fastapi.security import APIKeyHeader
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
# --- Logging Setup ---
|
import httpx
|
||||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
from dotenv import load_dotenv
|
||||||
|
from fastapi import FastAPI, HTTPException, Body, Depends, Security
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.security import APIKeyHeader
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Logging
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Load environment variables from .env file
|
# ---------------------------------------------------------------------------
|
||||||
|
# Environment variables
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# --- Environment Variable Checks ---
|
|
||||||
SLACK_BOT_TOKEN = os.getenv("SLACK_BOT_TOKEN")
|
SLACK_BOT_TOKEN = os.getenv("SLACK_BOT_TOKEN")
|
||||||
SLACK_TEAM_ID = os.getenv("SLACK_TEAM_ID")
|
SLACK_TEAM_ID = os.getenv("SLACK_TEAM_ID")
|
||||||
SLACK_CHANNEL_IDS_STR = os.getenv("SLACK_CHANNEL_IDS") # Optional
|
SLACK_CHANNEL_IDS_STR = os.getenv("SLACK_CHANNEL_IDS") # Optional
|
||||||
ALLOWED_ORIGINS_STR = os.getenv("ALLOWED_ORIGINS", "*") # Default to allow all
|
ALLOWED_ORIGINS_STR = os.getenv("ALLOWED_ORIGINS", "*")
|
||||||
SERVER_API_KEY = os.getenv("SERVER_API_KEY") # Optional API key for security
|
SERVER_API_KEY = os.getenv("SERVER_API_KEY") # Optional API key for security
|
||||||
|
|
||||||
if not SLACK_BOT_TOKEN:
|
if not SLACK_BOT_TOKEN:
|
||||||
# Fail fast if essential config is missing
|
|
||||||
logger.critical("SLACK_BOT_TOKEN environment variable not set.")
|
logger.critical("SLACK_BOT_TOKEN environment variable not set.")
|
||||||
raise ValueError("SLACK_BOT_TOKEN environment variable not set.")
|
raise ValueError("SLACK_BOT_TOKEN environment variable not set.")
|
||||||
if not SLACK_TEAM_ID:
|
if not SLACK_TEAM_ID:
|
||||||
logger.critical("SLACK_TEAM_ID environment variable not set.")
|
logger.critical("SLACK_TEAM_ID environment variable not set.")
|
||||||
raise ValueError("SLACK_TEAM_ID environment variable not set.")
|
raise ValueError("SLACK_TEAM_ID environment variable not set.")
|
||||||
|
|
||||||
PREDEFINED_CHANNEL_IDS = [
|
PREDEFINED_CHANNEL_IDS: Optional[List[str]] = (
|
||||||
channel_id.strip()
|
[cid.strip() for cid in SLACK_CHANNEL_IDS_STR.split(",")] if SLACK_CHANNEL_IDS_STR else None
|
||||||
for channel_id in SLACK_CHANNEL_IDS_STR.split(',')
|
)
|
||||||
] if SLACK_CHANNEL_IDS_STR else None
|
|
||||||
|
|
||||||
# --- FastAPI App Setup ---
|
# ---------------------------------------------------------------------------
|
||||||
|
# FastAPI app setup
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Slack API Server",
|
title="Slack API Server",
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
description="FastAPI server providing Slack functionalities via specific, dynamically generated tool endpoints.",
|
description="FastAPI server providing Slack functionalities via specific, dynamically generated tool endpoints.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Configure CORS
|
# CORS
|
||||||
allow_origins = [origin.strip() for origin in ALLOWED_ORIGINS_STR.split(',')]
|
allow_origins = [origin.strip() for origin in ALLOWED_ORIGINS_STR.split(",")]
|
||||||
if allow_origins == ["*"]:
|
if allow_origins == ["*"]:
|
||||||
logger.warning("CORS allow_origins is set to '*' which is insecure for production. Consider setting ALLOWED_ORIGINS environment variable.")
|
logger.warning("CORS allow_origins is set to '*' which is insecure for production. Consider setting ALLOWED_ORIGINS environment variable.")
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=allow_origins,
|
allow_origins=allow_origins,
|
||||||
allow_credentials=True, # Allow credentials if origins are specific, adjust if needed
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- API Key Security ---
|
# ---------------------------------------------------------------------------
|
||||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) # auto_error=False to handle optional key
|
# API key security
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
async def get_api_key(key: str = Security(api_key_header)):
|
async def get_api_key(key: str = Security(api_key_header)):
|
||||||
if SERVER_API_KEY: # Only enforce key if it's set in the environment
|
if SERVER_API_KEY:
|
||||||
if not key:
|
if not key:
|
||||||
logger.warning("API Key required but not provided in X-API-Key header.")
|
logger.warning("API Key required but not provided in X-API-Key header.")
|
||||||
raise HTTPException(status_code=401, detail="X-API-Key header required")
|
raise HTTPException(status_code=401, detail="X-API-Key header required")
|
||||||
if key != SERVER_API_KEY:
|
if key != SERVER_API_KEY:
|
||||||
logger.warning("Invalid API Key provided.")
|
logger.warning("Invalid API Key provided.")
|
||||||
raise HTTPException(status_code=401, detail="Invalid API Key")
|
raise HTTPException(status_code=401, detail="Invalid API Key")
|
||||||
# If key is valid and required, proceed
|
return key # May be None when not required
|
||||||
# If SERVER_API_KEY is not set, allow access without a key
|
|
||||||
# logger.info("API Key check passed (or not required).") # Optional: Log successful checks
|
|
||||||
return key # Return the key or None if not required/provided
|
|
||||||
|
|
||||||
if not SERVER_API_KEY:
|
if not SERVER_API_KEY:
|
||||||
logger.warning("SERVER_API_KEY environment variable is not set. Server will allow unauthenticated requests.")
|
logger.warning("SERVER_API_KEY environment variable is not set. Server will allow unauthenticated requests.")
|
||||||
|
|
||||||
# [Previous Pydantic models remain the same...]
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pydantic models (arguments & responses)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
class ListChannelsArgs(BaseModel):
|
class ListChannelsArgs(BaseModel):
|
||||||
limit: Optional[int] = Field(100, description="Maximum number of channels to return (default 100, max 200)")
|
limit: Optional[int] = Field(100, description="Maximum number of channels to return (default 100, max 200)")
|
||||||
cursor: Optional[str] = Field(None, description="Pagination cursor for next page of results")
|
cursor: Optional[str] = Field(None, description="Pagination cursor for next page of results")
|
||||||
|
|
||||||
|
|
||||||
class PostMessageArgs(BaseModel):
|
class PostMessageArgs(BaseModel):
|
||||||
channel_id: str = Field(..., description="The ID of the channel to post to")
|
channel_id: str = Field(..., description="The ID of the channel to post to")
|
||||||
text: str = Field(..., description="The message text to post")
|
text: str = Field(..., description="The message text to post")
|
||||||
|
|
||||||
|
|
||||||
class ReplyToThreadArgs(BaseModel):
|
class ReplyToThreadArgs(BaseModel):
|
||||||
channel_id: str = Field(..., description="The ID of the channel containing the thread")
|
channel_id: str = Field(..., description="The ID of the channel containing the thread")
|
||||||
thread_ts: str = Field(..., description="The timestamp of the parent message (e.g., '1234567890.123456')")
|
thread_ts: str = Field(..., description="The timestamp of the parent message (e.g., '1234567890.123456')")
|
||||||
text: str = Field(..., description="The reply text")
|
text: str = Field(..., description="The reply text")
|
||||||
|
|
||||||
|
|
||||||
class AddReactionArgs(BaseModel):
|
class AddReactionArgs(BaseModel):
|
||||||
channel_id: str = Field(..., description="The ID of the channel containing the message")
|
channel_id: str = Field(..., description="The ID of the channel containing the message")
|
||||||
timestamp: str = Field(..., description="The timestamp of the message to react to")
|
timestamp: str = Field(..., description="The timestamp of the message to react to")
|
||||||
reaction: str = Field(..., description="The name of the emoji reaction (without colons)")
|
reaction: str = Field(..., description="The name of the emoji reaction (without colons)")
|
||||||
|
|
||||||
|
|
||||||
class GetChannelHistoryArgs(BaseModel):
|
class GetChannelHistoryArgs(BaseModel):
|
||||||
channel_id: str = Field(..., description="The ID of the channel")
|
channel_id: str = Field(..., description="The ID of the channel")
|
||||||
limit: Optional[int] = Field(10, description="Number of messages to retrieve (default 10)")
|
limit: Optional[int] = Field(10, description="Number of messages to retrieve (default 10)")
|
||||||
|
|
||||||
|
|
||||||
class GetThreadRepliesArgs(BaseModel):
|
class GetThreadRepliesArgs(BaseModel):
|
||||||
channel_id: str = Field(..., description="The ID of the channel containing the thread")
|
channel_id: str = Field(..., description="The ID of the channel containing the thread")
|
||||||
thread_ts: str = Field(..., description="The timestamp of the parent message (e.g., '1234567890.123456')")
|
thread_ts: str = Field(..., description="The timestamp of the parent message (e.g., '1234567890.123456')")
|
||||||
|
|
||||||
|
|
||||||
class GetUsersArgs(BaseModel):
|
class GetUsersArgs(BaseModel):
|
||||||
cursor: Optional[str] = Field(None, description="Pagination cursor for next page of results")
|
cursor: Optional[str] = Field(None, description="Pagination cursor for next page of results")
|
||||||
limit: Optional[int] = Field(100, description="Maximum number of users to return (default 100, max 200)")
|
limit: Optional[int] = Field(100, description="Maximum number of users to return (default 100, max 200)")
|
||||||
|
|
||||||
|
|
||||||
class GetUserProfileArgs(BaseModel):
|
class GetUserProfileArgs(BaseModel):
|
||||||
user_id: str = Field(..., description="The ID of the user")
|
user_id: str = Field(..., description="The ID of the user")
|
||||||
|
|
||||||
|
|
||||||
class ToolResponse(BaseModel):
|
class ToolResponse(BaseModel):
|
||||||
content: Dict[str, Any] = Field(..., description="The JSON response from the Slack API call")
|
content: Dict[str, Any] = Field(..., description="The JSON response from the Slack API call")
|
||||||
|
|
||||||
# --- Slack Client Class ---
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Slack client (high‑performance)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
class SlackClient:
|
class SlackClient:
|
||||||
|
"""Thin async wrapper over Slack Web API with connection‑pool reuse."""
|
||||||
|
|
||||||
BASE_URL = "https://slack.com/api/"
|
BASE_URL = "https://slack.com/api/"
|
||||||
|
|
||||||
def __init__(self, token: str, team_id: str):
|
def __init__(self, token: str, team_id: str, *, max_connections: int = 20):
|
||||||
|
self.team_id = team_id
|
||||||
self.headers = {
|
self.headers = {
|
||||||
"Authorization": f"Bearer {token}",
|
"Authorization": f"Bearer {token}",
|
||||||
"Content-Type": "application/json; charset=utf-8",
|
"Content-Type": "application/json; charset=utf-8",
|
||||||
}
|
}
|
||||||
self.team_id = team_id
|
limits = httpx.Limits(max_connections=max_connections, max_keepalive_connections=max_connections)
|
||||||
|
self._client = httpx.AsyncClient(
|
||||||
|
base_url=self.BASE_URL,
|
||||||
|
headers=self.headers,
|
||||||
|
limits=limits,
|
||||||
|
http2=True,
|
||||||
|
timeout=10,
|
||||||
|
)
|
||||||
|
|
||||||
async def _request(self, method: str, endpoint: str, params: Optional[Dict] = None, json_data: Optional[Dict] = None) -> Dict[str, Any]:
|
# ---------------- private helpers ---------------- #
|
||||||
async with httpx.AsyncClient(base_url=self.BASE_URL, headers=self.headers) as client:
|
async def _request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
endpoint: str,
|
||||||
|
*,
|
||||||
|
params: Optional[Dict] = None,
|
||||||
|
json_data: Optional[Dict] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
try:
|
try:
|
||||||
response = await client.request(method, endpoint, params=params, json=json_data)
|
response = await self._client.request(method, endpoint, params=params, json=json_data)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
if not data.get("ok"):
|
if not data.get("ok"):
|
||||||
error_msg = data.get("error", "Unknown Slack API error")
|
error_msg = data.get("error", "Unknown Slack API error")
|
||||||
# Return the specific Slack error in the response
|
raise HTTPException(status_code=400, detail={"slack_error": error_msg})
|
||||||
logger.warning(f"Slack API Error for {method} {endpoint}: {error_msg}")
|
|
||||||
raise HTTPException(status_code=400, detail={"slack_error": error_msg, "message": f"Slack API returned an error: {error_msg}"})
|
|
||||||
return data
|
return data
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
# Handle specific HTTP errors like rate limiting (429)
|
|
||||||
if e.response.status_code == 429:
|
if e.response.status_code == 429:
|
||||||
retry_after = e.response.headers.get("Retry-After")
|
retry_after = e.response.headers.get("Retry-After")
|
||||||
detail = f"Slack API rate limit exceeded. Retry after {retry_after} seconds." if retry_after else "Slack API rate limit exceeded."
|
detail = (
|
||||||
logger.warning(f"Rate limit hit for {method} {endpoint}. Retry-After: {retry_after}")
|
f"Slack API rate limit exceeded. Retry after {retry_after} seconds."
|
||||||
|
if retry_after
|
||||||
|
else "Slack API rate limit exceeded."
|
||||||
|
)
|
||||||
|
logger.warning("Rate limit hit: %s", detail)
|
||||||
raise HTTPException(status_code=429, detail=detail, headers={"Retry-After": retry_after} if retry_after else {})
|
raise HTTPException(status_code=429, detail=detail, headers={"Retry-After": retry_after} if retry_after else {})
|
||||||
else:
|
logger.error("HTTP Error %s - %s", e.response.status_code, e.response.text, exc_info=True)
|
||||||
logger.error(f"HTTP Error: {e.response.status_code} - {e.response.text}", exc_info=True)
|
raise HTTPException(status_code=e.response.status_code, detail="Slack API HTTP Error")
|
||||||
raise HTTPException(status_code=e.response.status_code, detail=f"Slack API HTTP Error: Status {e.response.status_code}")
|
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
logger.error(f"Request Error connecting to Slack API: {e}", exc_info=True)
|
logger.error("Request Error connecting to Slack API: %s", e, exc_info=True)
|
||||||
raise HTTPException(status_code=503, detail=f"Could not connect to Slack API: {e}")
|
raise HTTPException(status_code=503, detail=f"Could not connect to Slack API: {e}")
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.error(f"Failed to decode JSON response from Slack API for {method} {endpoint}: {e}", exc_info=True)
|
logger.error("Failed to decode JSON: %s", e, exc_info=True)
|
||||||
raise HTTPException(status_code=502, detail="Invalid response received from Slack API.")
|
raise HTTPException(status_code=502, detail="Invalid JSON from Slack API")
|
||||||
except Exception as e: # Catch other unexpected errors
|
except Exception as e: # noqa: BLE001
|
||||||
logger.exception(f"Unexpected error during Slack request for {method} {endpoint}: {e}") # Use logger.exception to include traceback
|
logger.exception("Unexpected error during Slack request: %s", e)
|
||||||
raise HTTPException(status_code=500, detail=f"An internal server error occurred: {type(e).__name__}")
|
raise HTTPException(status_code=500, detail=f"Internal error: {type(e).__name__}")
|
||||||
|
|
||||||
async def get_channel_history(self, args: GetChannelHistoryArgs) -> Dict[str, Any]:
|
# ---------------- public helpers ---------------- #
|
||||||
params = {"channel": args.channel_id, "limit": args.limit}
|
async def channel_with_history(self, channel_id: str, *, history_limit: int = 1) -> Optional[Dict[str, Any]]:
|
||||||
return await self._request("GET", "conversations.history", params=params)
|
"""Return channel metadata plus ≤ ``history_limit`` recent messages, or None."""
|
||||||
|
|
||||||
async def get_channels(self, args: ListChannelsArgs) -> Dict[str, Any]:
|
|
||||||
limit = args.limit
|
|
||||||
cursor = args.cursor
|
|
||||||
|
|
||||||
async def fetch_channel_with_history(channel_id: str) -> Dict[str, Any]:
|
|
||||||
# First get channel info
|
|
||||||
channel_info = await self._request("GET", "conversations.info", params={"channel": channel_id})
|
|
||||||
if not channel_info.get("ok") or channel_info.get("channel", {}).get("is_archived"):
|
|
||||||
return None
|
|
||||||
|
|
||||||
channel_data = channel_info["channel"]
|
|
||||||
|
|
||||||
# Then get channel history
|
|
||||||
try:
|
try:
|
||||||
history = await self._request(
|
info = await self._request("GET", "conversations.info", params={"channel": channel_id})
|
||||||
|
chan = info["channel"]
|
||||||
|
if chan.get("is_archived"):
|
||||||
|
return None
|
||||||
|
hist = await self._request(
|
||||||
"GET",
|
"GET",
|
||||||
"conversations.history",
|
"conversations.history",
|
||||||
params={
|
params={"channel": channel_id, "limit": history_limit},
|
||||||
"channel": channel_id,
|
|
||||||
"limit": 1 # Fetch minimal history by default to speed up get_channels. Consider asyncio.gather for concurrency.
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
# Add history to channel data
|
chan["history"] = hist.get("messages", [])
|
||||||
if history.get("ok"):
|
return chan
|
||||||
channel_data["history"] = history.get("messages", [])
|
except Exception as exc: # noqa: BLE001
|
||||||
except Exception as e: # Catch errors during history fetch but don't fail the whole channel list
|
logger.warning("Skipping channel %s – %s", channel_id, exc, exc_info=True)
|
||||||
logger.warning(f"Error fetching history for channel {channel_id}: {e}", exc_info=True)
|
return None
|
||||||
channel_data["history"] = [] # Ensure history key exists even if fetch fails
|
|
||||||
|
|
||||||
return channel_data
|
# ---------------- API surface ---------------- #
|
||||||
|
async def get_channel_history(self, args: GetChannelHistoryArgs) -> Dict[str, Any]:
|
||||||
|
return await self._request("GET", "conversations.history", params={"channel": args.channel_id, "limit": args.limit})
|
||||||
|
|
||||||
|
async def get_channels(self, args: ListChannelsArgs) -> Dict[str, Any]: # noqa: C901 – keep cohesive
|
||||||
|
# 1. decide which ids to fetch
|
||||||
if PREDEFINED_CHANNEL_IDS:
|
if PREDEFINED_CHANNEL_IDS:
|
||||||
channels_info = []
|
ids = PREDEFINED_CHANNEL_IDS
|
||||||
for channel_id in PREDEFINED_CHANNEL_IDS:
|
next_cursor = ""
|
||||||
try:
|
|
||||||
if channel_data := await fetch_channel_with_history(channel_id):
|
|
||||||
channels_info.append(channel_data)
|
|
||||||
except Exception as e: # Catch errors fetching predefined channels
|
|
||||||
logger.warning(f"Could not fetch info for predefined channel {channel_id}: {e}", exc_info=True)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"ok": True,
|
|
||||||
"channels": channels_info,
|
|
||||||
"response_metadata": {"next_cursor": ""}
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
# First get list of channels
|
params: Dict[str, Any] = {
|
||||||
params = {
|
|
||||||
"types": "public_channel",
|
"types": "public_channel",
|
||||||
"exclude_archived": "true",
|
"exclude_archived": "true",
|
||||||
"limit": min(limit, 200),
|
|
||||||
"team_id": self.team_id,
|
|
||||||
}
|
|
||||||
if cursor:
|
|
||||||
params["cursor"] = cursor
|
|
||||||
|
|
||||||
channels_list = await self._request("GET", "conversations.list", params=params)
|
|
||||||
|
|
||||||
if not channels_list.get("ok"):
|
|
||||||
return channels_list
|
|
||||||
|
|
||||||
# Then fetch history for each channel
|
|
||||||
channels_with_history = []
|
|
||||||
for channel in channels_list["channels"]:
|
|
||||||
try:
|
|
||||||
if channel_data := await fetch_channel_with_history(channel["id"]):
|
|
||||||
channels_with_history.append(channel_data)
|
|
||||||
except Exception as e: # Catch errors during history fetch but don't fail the whole channel list
|
|
||||||
logger.warning(f"Error fetching history for channel {channel['id']}: {e}", exc_info=True)
|
|
||||||
channel["history"] = [] # Add empty history on error
|
|
||||||
channels_with_history.append(channel)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"ok": True,
|
|
||||||
"channels": channels_with_history,
|
|
||||||
"response_metadata": channels_list.get("response_metadata", {"next_cursor": ""})
|
|
||||||
}
|
|
||||||
|
|
||||||
async def post_message(self, args: PostMessageArgs) -> Dict[str, Any]:
|
|
||||||
payload = {"channel": args.channel_id, "text": args.text}
|
|
||||||
return await self._request("POST", "chat.postMessage", json_data=payload)
|
|
||||||
|
|
||||||
async def post_reply(self, args: ReplyToThreadArgs) -> Dict[str, Any]:
|
|
||||||
payload = {"channel": args.channel_id, "thread_ts": args.thread_ts, "text": args.text}
|
|
||||||
return await self._request("POST", "chat.postMessage", json_data=payload)
|
|
||||||
|
|
||||||
async def add_reaction(self, args: AddReactionArgs) -> Dict[str, Any]:
|
|
||||||
payload = {"channel": args.channel_id, "timestamp": args.timestamp, "name": args.reaction}
|
|
||||||
return await self._request("POST", "reactions.add", json_data=payload)
|
|
||||||
|
|
||||||
async def get_thread_replies(self, args: GetThreadRepliesArgs) -> Dict[str, Any]:
|
|
||||||
params = {"channel": args.channel_id, "ts": args.thread_ts}
|
|
||||||
return await self._request("GET", "conversations.replies", params=params)
|
|
||||||
|
|
||||||
async def get_users(self, args: GetUsersArgs) -> Dict[str, Any]:
|
|
||||||
params = {
|
|
||||||
"limit": min(args.limit, 200),
|
"limit": min(args.limit, 200),
|
||||||
"team_id": self.team_id,
|
"team_id": self.team_id,
|
||||||
}
|
}
|
||||||
|
if args.cursor:
|
||||||
|
params["cursor"] = args.cursor
|
||||||
|
clist = await self._request("GET", "conversations.list", params=params)
|
||||||
|
ids = [c["id"] for c in clist["channels"]]
|
||||||
|
next_cursor = clist.get("response_metadata", {}).get("next_cursor", "")
|
||||||
|
|
||||||
|
# 2. fetch metadata + history concurrently under a semaphore
|
||||||
|
sem = asyncio.Semaphore(10) # adjust parallelism as desired
|
||||||
|
|
||||||
|
async def guarded(cid: str):
|
||||||
|
async with sem:
|
||||||
|
return await self.channel_with_history(cid)
|
||||||
|
|
||||||
|
channels = [c for c in await asyncio.gather(*(guarded(cid) for cid in ids)) if c]
|
||||||
|
return {"ok": True, "channels": channels, "response_metadata": {"next_cursor": next_cursor}}
|
||||||
|
|
||||||
|
async def post_message(self, args: PostMessageArgs) -> Dict[str, Any]:
|
||||||
|
return await self._request("POST", "chat.postMessage", json_data={"channel": args.channel_id, "text": args.text})
|
||||||
|
|
||||||
|
async def post_reply(self, args: ReplyToThreadArgs) -> Dict[str, Any]:
|
||||||
|
return await self._request(
|
||||||
|
"POST",
|
||||||
|
"chat.postMessage",
|
||||||
|
json_data={"channel": args.channel_id, "thread_ts": args.thread_ts, "text": args.text},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def add_reaction(self, args: AddReactionArgs) -> Dict[str, Any]:
|
||||||
|
return await self._request(
|
||||||
|
"POST",
|
||||||
|
"reactions.add",
|
||||||
|
json_data={"channel": args.channel_id, "timestamp": args.timestamp, "name": args.reaction},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_thread_replies(self, args: GetThreadRepliesArgs) -> Dict[str, Any]:
|
||||||
|
return await self._request("GET", "conversations.replies", params={"channel": args.channel_id, "ts": args.thread_ts})
|
||||||
|
|
||||||
|
async def get_users(self, args: GetUsersArgs) -> Dict[str, Any]:
|
||||||
|
params = {"limit": min(args.limit, 200), "team_id": self.team_id}
|
||||||
if args.cursor:
|
if args.cursor:
|
||||||
params["cursor"] = args.cursor
|
params["cursor"] = args.cursor
|
||||||
return await self._request("GET", "users.list", params=params)
|
return await self._request("GET", "users.list", params=params)
|
||||||
|
|
||||||
async def get_user_profile(self, args: GetUserProfileArgs) -> Dict[str, Any]:
|
async def get_user_profile(self, args: GetUserProfileArgs) -> Dict[str, Any]:
|
||||||
params = {"user": args.user_id, "include_labels": "true"}
|
return await self._request("GET", "users.profile.get", params={"user": args.user_id, "include_labels": "true"})
|
||||||
return await self._request("GET", "users.profile.get", params=params)
|
|
||||||
|
|
||||||
# --- Instantiate Slack Client ---
|
# ---------------- lifecycle ---------------- #
|
||||||
|
async def aclose(self) -> None: # call on app shutdown
|
||||||
|
await self._client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Instantiate Slack client
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
slack_client = SlackClient(token=SLACK_BOT_TOKEN, team_id=SLACK_TEAM_ID)
|
slack_client = SlackClient(token=SLACK_BOT_TOKEN, team_id=SLACK_TEAM_ID)
|
||||||
|
|
||||||
# --- Tool Definitions & Endpoint Generation ---
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Dynamic tool mapping / endpoint generation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
TOOL_MAPPING = {
|
TOOL_MAPPING = {
|
||||||
"slack_list_channels": {
|
"slack_list_channels": {
|
||||||
"args_model": ListChannelsArgs,
|
"args_model": ListChannelsArgs,
|
||||||
@ -314,40 +335,45 @@ TOOL_MAPPING = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Define a function factory to create endpoint handlers, including API key dependency
|
|
||||||
|
# ---------------- endpoint factory ---------------- #
|
||||||
|
|
||||||
def create_endpoint_handler(tool_name: str, method: Callable, args_model: Type[BaseModel]):
|
def create_endpoint_handler(tool_name: str, method: Callable, args_model: Type[BaseModel]):
|
||||||
async def endpoint_handler(
|
async def handler(args: args_model = Body(...), api_key: str = Depends(get_api_key)) -> ToolResponse: # noqa: ANN001
|
||||||
args: args_model = Body(...),
|
|
||||||
api_key: str = Depends(get_api_key) # Add API key dependency here
|
|
||||||
) -> ToolResponse:
|
|
||||||
try:
|
try:
|
||||||
result = await method(args=args)
|
result = await method(args=args)
|
||||||
return {"content": result}
|
return {"content": result}
|
||||||
except HTTPException as e:
|
except HTTPException:
|
||||||
raise e
|
raise # re‑raise untouched
|
||||||
except Exception as e:
|
except Exception as exc: # noqa: BLE001
|
||||||
logger.exception(f"Error executing tool {tool_name}: {e}") # Use logger.exception here too
|
logger.exception("Error executing tool %s: %s", tool_name, exc)
|
||||||
raise HTTPException(status_code=500, detail=f"Internal server error: {type(e).__name__}")
|
raise HTTPException(status_code=500, detail=f"Internal server error: {type(exc).__name__}")
|
||||||
return endpoint_handler
|
|
||||||
|
|
||||||
# Register endpoints for each tool
|
return handler
|
||||||
for tool_name, config in TOOL_MAPPING.items():
|
|
||||||
handler = create_endpoint_handler(
|
|
||||||
tool_name=tool_name,
|
|
||||||
method=config["method"],
|
|
||||||
args_model=config["args_model"]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
for name, cfg in TOOL_MAPPING.items():
|
||||||
app.post(
|
app.post(
|
||||||
f"/{tool_name}",
|
f"/{name}",
|
||||||
response_model=ToolResponse,
|
response_model=ToolResponse,
|
||||||
summary=config["description"],
|
summary=cfg["description"],
|
||||||
description=f"Executes the {tool_name} tool. Arguments are passed in the request body.",
|
description=f"Executes the {name} tool. Arguments are passed in the request body.",
|
||||||
tags=["Slack Tools"],
|
tags=["Slack Tools"],
|
||||||
name=tool_name
|
name=name,
|
||||||
)(handler)
|
)(create_endpoint_handler(name, cfg["method"], cfg["args_model"]))
|
||||||
|
|
||||||
# --- Root Endpoint ---
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Lifecycle events
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
async def _close_slack_client():
|
||||||
|
await slack_client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Root endpoint
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
@app.get("/", summary="Root endpoint", include_in_schema=False)
|
@app.get("/", summary="Root endpoint", include_in_schema=False)
|
||||||
async def read_root():
|
async def read_root():
|
||||||
return {"message": "Slack API Server is running. See /docs for available tool endpoints."}
|
return {"message": "Slack API Server is running. See /docs for available tool endpoints."}
|
||||||
|
Loading…
Reference in New Issue
Block a user