implement review feedback, add auth, logging, and improve robustness

This commit is contained in:
Taylor Wilsdon 2025-04-20 19:44:52 -04:00
parent 08dff6aaeb
commit 565a563a77
3 changed files with 89 additions and 34 deletions

View File

@ -47,5 +47,9 @@ COPY . .
# Expose the port that the application listens on. # Expose the port that the application listens on.
EXPOSE 8000 EXPOSE 8000
# Run the application. # Add a healthcheck to verify the server is running
CMD uvicorn 'main:app' --host=0.0.0.0 --port=8000 HEALTHCHECK --interval=30s --timeout=5s --start-period=5s --retries=3 \
CMD curl --fail http://localhost:8000/ || exit 1
# Run the application using the JSON array form to avoid shell interpretation issues.
CMD ["uvicorn", "main:app", "--host=0.0.0.0", "--port=8000"]

View File

@ -1,12 +1,19 @@
import os import os
import httpx import httpx
import inspect import inspect
import logging
import json # For JSONDecodeError
from typing import Optional, List, Dict, Any, Type, Callable from typing import Optional, List, Dict, Any, Type, Callable
from fastapi import FastAPI, HTTPException, Body, Depends from fastapi import FastAPI, HTTPException, Body, Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from dotenv import load_dotenv from dotenv import load_dotenv
# --- Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Load environment variables from .env file # Load environment variables from .env file
load_dotenv() load_dotenv()
@ -14,10 +21,15 @@ load_dotenv()
SLACK_BOT_TOKEN = os.getenv("SLACK_BOT_TOKEN") SLACK_BOT_TOKEN = os.getenv("SLACK_BOT_TOKEN")
SLACK_TEAM_ID = os.getenv("SLACK_TEAM_ID") SLACK_TEAM_ID = os.getenv("SLACK_TEAM_ID")
SLACK_CHANNEL_IDS_STR = os.getenv("SLACK_CHANNEL_IDS") # Optional SLACK_CHANNEL_IDS_STR = os.getenv("SLACK_CHANNEL_IDS") # Optional
ALLOWED_ORIGINS_STR = os.getenv("ALLOWED_ORIGINS", "*") # Default to allow all
SERVER_API_KEY = os.getenv("SERVER_API_KEY") # Optional API key for security
if not SLACK_BOT_TOKEN: if not SLACK_BOT_TOKEN:
# Fail fast if essential config is missing
logger.critical("SLACK_BOT_TOKEN environment variable not set.")
raise ValueError("SLACK_BOT_TOKEN environment variable not set.") raise ValueError("SLACK_BOT_TOKEN environment variable not set.")
if not SLACK_TEAM_ID: if not SLACK_TEAM_ID:
logger.critical("SLACK_TEAM_ID environment variable not set.")
raise ValueError("SLACK_TEAM_ID environment variable not set.") raise ValueError("SLACK_TEAM_ID environment variable not set.")
PREDEFINED_CHANNEL_IDS = [ PREDEFINED_CHANNEL_IDS = [
@ -32,16 +44,38 @@ app = FastAPI(
description="FastAPI server providing Slack functionalities via specific, dynamically generated tool endpoints.", description="FastAPI server providing Slack functionalities via specific, dynamically generated tool endpoints.",
) )
origins = ["*"] # Configure CORS
allow_origins = [origin.strip() for origin in ALLOWED_ORIGINS_STR.split(',')]
if allow_origins == ["*"]:
logger.warning("CORS allow_origins is set to '*' which is insecure for production. Consider setting ALLOWED_ORIGINS environment variable.")
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=origins, allow_origins=allow_origins,
allow_credentials=True, allow_credentials=True, # Allow credentials if origins are specific, adjust if needed
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
# --- API Key Security ---
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) # auto_error=False to handle optional key
async def get_api_key(key: str = Security(api_key_header)):
if SERVER_API_KEY: # Only enforce key if it's set in the environment
if not key:
logger.warning("API Key required but not provided in X-API-Key header.")
raise HTTPException(status_code=401, detail="X-API-Key header required")
if key != SERVER_API_KEY:
logger.warning("Invalid API Key provided.")
raise HTTPException(status_code=401, detail="Invalid API Key")
# If key is valid and required, proceed
# If SERVER_API_KEY is not set, allow access without a key
# logger.info("API Key check passed (or not required).") # Optional: Log successful checks
return key # Return the key or None if not required/provided
if not SERVER_API_KEY:
logger.warning("SERVER_API_KEY environment variable is not set. Server will allow unauthenticated requests.")
# [Previous Pydantic models remain the same...] # [Previous Pydantic models remain the same...]
class ListChannelsArgs(BaseModel): class ListChannelsArgs(BaseModel):
limit: Optional[int] = Field(100, description="Maximum number of channels to return (default 100, max 200)") limit: Optional[int] = Field(100, description="Maximum number of channels to return (default 100, max 200)")
@ -98,18 +132,29 @@ class SlackClient:
data = response.json() data = response.json()
if not data.get("ok"): if not data.get("ok"):
error_msg = data.get("error", "Unknown Slack API error") error_msg = data.get("error", "Unknown Slack API error")
print(f"Slack API Error for {method} {endpoint}: {error_msg}") # Return the specific Slack error in the response
raise HTTPException(status_code=400, detail={"slack_error": error_msg, "message": f"Slack API Error: {error_msg}"}) logger.warning(f"Slack API Error for {method} {endpoint}: {error_msg}")
raise HTTPException(status_code=400, detail={"slack_error": error_msg, "message": f"Slack API returned an error: {error_msg}"})
return data return data
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
print(f"HTTP Error: {e.response.status_code} - {e.response.text}") # Handle specific HTTP errors like rate limiting (429)
raise HTTPException(status_code=e.response.status_code, detail=f"Slack API HTTP Error: {e.response.text}") if e.response.status_code == 429:
retry_after = e.response.headers.get("Retry-After")
detail = f"Slack API rate limit exceeded. Retry after {retry_after} seconds." if retry_after else "Slack API rate limit exceeded."
logger.warning(f"Rate limit hit for {method} {endpoint}. Retry-After: {retry_after}")
raise HTTPException(status_code=429, detail=detail, headers={"Retry-After": retry_after} if retry_after else {})
else:
logger.error(f"HTTP Error: {e.response.status_code} - {e.response.text}", exc_info=True)
raise HTTPException(status_code=e.response.status_code, detail=f"Slack API HTTP Error: Status {e.response.status_code}")
except httpx.RequestError as e: except httpx.RequestError as e:
print(f"Request Error: {e}") logger.error(f"Request Error connecting to Slack API: {e}", exc_info=True)
raise HTTPException(status_code=503, detail=f"Error connecting to Slack API: {e}") raise HTTPException(status_code=503, detail=f"Could not connect to Slack API: {e}")
except Exception as e: except json.JSONDecodeError as e:
print(f"Unexpected Error during Slack request: {e}") logger.error(f"Failed to decode JSON response from Slack API for {method} {endpoint}: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"An internal error occurred during the Slack request: {e}") raise HTTPException(status_code=502, detail="Invalid response received from Slack API.")
except Exception as e: # Catch other unexpected errors
logger.exception(f"Unexpected error during Slack request for {method} {endpoint}: {e}") # Use logger.exception to include traceback
raise HTTPException(status_code=500, detail=f"An internal server error occurred: {type(e).__name__}")
async def get_channel_history(self, args: GetChannelHistoryArgs) -> Dict[str, Any]: async def get_channel_history(self, args: GetChannelHistoryArgs) -> Dict[str, Any]:
params = {"channel": args.channel_id, "limit": args.limit} params = {"channel": args.channel_id, "limit": args.limit}
@ -134,15 +179,15 @@ class SlackClient:
"conversations.history", "conversations.history",
params={ params={
"channel": channel_id, "channel": channel_id,
"limit": 10 # Get last 10 messages by default "limit": 1 # Fetch minimal history by default to speed up get_channels. Consider asyncio.gather for concurrency.
} }
) )
# Add history to channel data # Add history to channel data
if history.get("ok"): if history.get("ok"):
channel_data["history"] = history.get("messages", []) channel_data["history"] = history.get("messages", [])
except Exception as e: except Exception as e: # Catch errors during history fetch but don't fail the whole channel list
print(f"Error fetching history for channel {channel_id}: {e}") logger.warning(f"Error fetching history for channel {channel_id}: {e}", exc_info=True)
channel_data["history"] = [] channel_data["history"] = [] # Ensure history key exists even if fetch fails
return channel_data return channel_data
@ -152,8 +197,8 @@ class SlackClient:
try: try:
if channel_data := await fetch_channel_with_history(channel_id): if channel_data := await fetch_channel_with_history(channel_id):
channels_info.append(channel_data) channels_info.append(channel_data)
except Exception as e: except Exception as e: # Catch errors fetching predefined channels
print(f"Could not fetch info for predefined channel {channel_id}: {e}") logger.warning(f"Could not fetch info for predefined channel {channel_id}: {e}", exc_info=True)
return { return {
"ok": True, "ok": True,
@ -182,9 +227,10 @@ class SlackClient:
try: try:
if channel_data := await fetch_channel_with_history(channel["id"]): if channel_data := await fetch_channel_with_history(channel["id"]):
channels_with_history.append(channel_data) channels_with_history.append(channel_data)
except Exception as e: except Exception as e: # Catch errors during history fetch but don't fail the whole channel list
print(f"Error fetching history for channel {channel['id']}: {e}") logger.warning(f"Error fetching history for channel {channel['id']}: {e}", exc_info=True)
channels_with_history.append(channel) # Fall back to channel info without history channel["history"] = [] # Add empty history on error
channels_with_history.append(channel)
return { return {
"ok": True, "ok": True,
@ -268,17 +314,20 @@ TOOL_MAPPING = {
}, },
} }
# Define a function factory to create endpoint handlers # Define a function factory to create endpoint handlers, including API key dependency
def create_endpoint_handler(tool_name: str, method: Callable, args_model: Type[BaseModel]): def create_endpoint_handler(tool_name: str, method: Callable, args_model: Type[BaseModel]):
async def endpoint_handler(args: args_model = Body(...)) -> ToolResponse: async def endpoint_handler(
args: args_model = Body(...),
api_key: str = Depends(get_api_key) # Add API key dependency here
) -> ToolResponse:
try: try:
result = await method(args=args) result = await method(args=args)
return {"content": result} return {"content": result}
except HTTPException as e: except HTTPException as e:
raise e raise e
except Exception as e: except Exception as e:
print(f"Error executing tool {tool_name}: {e}") logger.exception(f"Error executing tool {tool_name}: {e}") # Use logger.exception here too
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") raise HTTPException(status_code=500, detail=f"Internal server error: {type(e).__name__}")
return endpoint_handler return endpoint_handler
# Register endpoints for each tool # Register endpoints for each tool

View File

@ -1,6 +1,8 @@
fastapi fastapi>=0.110.0,<0.111.0
uvicorn[standard] uvicorn[standard]>=0.29.0,<0.30.0
pydantic pydantic>=2.6.0,<3.0.0
python-multipart httpx>=0.27.0,<0.28.0
httpx python-dotenv>=1.0.0,<2.0.0
python-dotenv # NOTE: Run 'pip freeze > requirements.txt' in a virtual environment
# to capture the exact versions of all transitive dependencies
# for truly reproducible builds.