mirror of
https://github.com/open-webui/openapi-servers
synced 2025-06-26 18:17:04 +00:00
implement review feedback, add auth, logging, and improve robustness
This commit is contained in:
parent
08dff6aaeb
commit
565a563a77
@ -47,5 +47,9 @@ COPY . .
|
|||||||
# Expose the port that the application listens on.
|
# Expose the port that the application listens on.
|
||||||
EXPOSE 8000
|
EXPOSE 8000
|
||||||
|
|
||||||
# Run the application.
|
# Add a healthcheck to verify the server is running
|
||||||
CMD uvicorn 'main:app' --host=0.0.0.0 --port=8000
|
HEALTHCHECK --interval=30s --timeout=5s --start-period=5s --retries=3 \
|
||||||
|
CMD curl --fail http://localhost:8000/ || exit 1
|
||||||
|
|
||||||
|
# Run the application using the JSON array form to avoid shell interpretation issues.
|
||||||
|
CMD ["uvicorn", "main:app", "--host=0.0.0.0", "--port=8000"]
|
||||||
|
@ -1,12 +1,19 @@
|
|||||||
import os
|
import os
|
||||||
import httpx
|
import httpx
|
||||||
import inspect
|
import inspect
|
||||||
|
import logging
|
||||||
|
import json # For JSONDecodeError
|
||||||
from typing import Optional, List, Dict, Any, Type, Callable
|
from typing import Optional, List, Dict, Any, Type, Callable
|
||||||
from fastapi import FastAPI, HTTPException, Body, Depends
|
from fastapi import FastAPI, HTTPException, Body, Depends, Security
|
||||||
|
from fastapi.security import APIKeyHeader
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# --- Logging Setup ---
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Load environment variables from .env file
|
# Load environment variables from .env file
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
@ -14,10 +21,15 @@ load_dotenv()
|
|||||||
SLACK_BOT_TOKEN = os.getenv("SLACK_BOT_TOKEN")
|
SLACK_BOT_TOKEN = os.getenv("SLACK_BOT_TOKEN")
|
||||||
SLACK_TEAM_ID = os.getenv("SLACK_TEAM_ID")
|
SLACK_TEAM_ID = os.getenv("SLACK_TEAM_ID")
|
||||||
SLACK_CHANNEL_IDS_STR = os.getenv("SLACK_CHANNEL_IDS") # Optional
|
SLACK_CHANNEL_IDS_STR = os.getenv("SLACK_CHANNEL_IDS") # Optional
|
||||||
|
ALLOWED_ORIGINS_STR = os.getenv("ALLOWED_ORIGINS", "*") # Default to allow all
|
||||||
|
SERVER_API_KEY = os.getenv("SERVER_API_KEY") # Optional API key for security
|
||||||
|
|
||||||
if not SLACK_BOT_TOKEN:
|
if not SLACK_BOT_TOKEN:
|
||||||
|
# Fail fast if essential config is missing
|
||||||
|
logger.critical("SLACK_BOT_TOKEN environment variable not set.")
|
||||||
raise ValueError("SLACK_BOT_TOKEN environment variable not set.")
|
raise ValueError("SLACK_BOT_TOKEN environment variable not set.")
|
||||||
if not SLACK_TEAM_ID:
|
if not SLACK_TEAM_ID:
|
||||||
|
logger.critical("SLACK_TEAM_ID environment variable not set.")
|
||||||
raise ValueError("SLACK_TEAM_ID environment variable not set.")
|
raise ValueError("SLACK_TEAM_ID environment variable not set.")
|
||||||
|
|
||||||
PREDEFINED_CHANNEL_IDS = [
|
PREDEFINED_CHANNEL_IDS = [
|
||||||
@ -32,16 +44,38 @@ app = FastAPI(
|
|||||||
description="FastAPI server providing Slack functionalities via specific, dynamically generated tool endpoints.",
|
description="FastAPI server providing Slack functionalities via specific, dynamically generated tool endpoints.",
|
||||||
)
|
)
|
||||||
|
|
||||||
origins = ["*"]
|
# Configure CORS
|
||||||
|
allow_origins = [origin.strip() for origin in ALLOWED_ORIGINS_STR.split(',')]
|
||||||
|
if allow_origins == ["*"]:
|
||||||
|
logger.warning("CORS allow_origins is set to '*' which is insecure for production. Consider setting ALLOWED_ORIGINS environment variable.")
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=origins,
|
allow_origins=allow_origins,
|
||||||
allow_credentials=True,
|
allow_credentials=True, # Allow credentials if origins are specific, adjust if needed
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# --- API Key Security ---
|
||||||
|
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) # auto_error=False to handle optional key
|
||||||
|
|
||||||
|
async def get_api_key(key: str = Security(api_key_header)):
|
||||||
|
if SERVER_API_KEY: # Only enforce key if it's set in the environment
|
||||||
|
if not key:
|
||||||
|
logger.warning("API Key required but not provided in X-API-Key header.")
|
||||||
|
raise HTTPException(status_code=401, detail="X-API-Key header required")
|
||||||
|
if key != SERVER_API_KEY:
|
||||||
|
logger.warning("Invalid API Key provided.")
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid API Key")
|
||||||
|
# If key is valid and required, proceed
|
||||||
|
# If SERVER_API_KEY is not set, allow access without a key
|
||||||
|
# logger.info("API Key check passed (or not required).") # Optional: Log successful checks
|
||||||
|
return key # Return the key or None if not required/provided
|
||||||
|
|
||||||
|
if not SERVER_API_KEY:
|
||||||
|
logger.warning("SERVER_API_KEY environment variable is not set. Server will allow unauthenticated requests.")
|
||||||
|
|
||||||
# [Previous Pydantic models remain the same...]
|
# [Previous Pydantic models remain the same...]
|
||||||
class ListChannelsArgs(BaseModel):
|
class ListChannelsArgs(BaseModel):
|
||||||
limit: Optional[int] = Field(100, description="Maximum number of channels to return (default 100, max 200)")
|
limit: Optional[int] = Field(100, description="Maximum number of channels to return (default 100, max 200)")
|
||||||
@ -98,18 +132,29 @@ class SlackClient:
|
|||||||
data = response.json()
|
data = response.json()
|
||||||
if not data.get("ok"):
|
if not data.get("ok"):
|
||||||
error_msg = data.get("error", "Unknown Slack API error")
|
error_msg = data.get("error", "Unknown Slack API error")
|
||||||
print(f"Slack API Error for {method} {endpoint}: {error_msg}")
|
# Return the specific Slack error in the response
|
||||||
raise HTTPException(status_code=400, detail={"slack_error": error_msg, "message": f"Slack API Error: {error_msg}"})
|
logger.warning(f"Slack API Error for {method} {endpoint}: {error_msg}")
|
||||||
|
raise HTTPException(status_code=400, detail={"slack_error": error_msg, "message": f"Slack API returned an error: {error_msg}"})
|
||||||
return data
|
return data
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
print(f"HTTP Error: {e.response.status_code} - {e.response.text}")
|
# Handle specific HTTP errors like rate limiting (429)
|
||||||
raise HTTPException(status_code=e.response.status_code, detail=f"Slack API HTTP Error: {e.response.text}")
|
if e.response.status_code == 429:
|
||||||
|
retry_after = e.response.headers.get("Retry-After")
|
||||||
|
detail = f"Slack API rate limit exceeded. Retry after {retry_after} seconds." if retry_after else "Slack API rate limit exceeded."
|
||||||
|
logger.warning(f"Rate limit hit for {method} {endpoint}. Retry-After: {retry_after}")
|
||||||
|
raise HTTPException(status_code=429, detail=detail, headers={"Retry-After": retry_after} if retry_after else {})
|
||||||
|
else:
|
||||||
|
logger.error(f"HTTP Error: {e.response.status_code} - {e.response.text}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=e.response.status_code, detail=f"Slack API HTTP Error: Status {e.response.status_code}")
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
print(f"Request Error: {e}")
|
logger.error(f"Request Error connecting to Slack API: {e}", exc_info=True)
|
||||||
raise HTTPException(status_code=503, detail=f"Error connecting to Slack API: {e}")
|
raise HTTPException(status_code=503, detail=f"Could not connect to Slack API: {e}")
|
||||||
except Exception as e:
|
except json.JSONDecodeError as e:
|
||||||
print(f"Unexpected Error during Slack request: {e}")
|
logger.error(f"Failed to decode JSON response from Slack API for {method} {endpoint}: {e}", exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail=f"An internal error occurred during the Slack request: {e}")
|
raise HTTPException(status_code=502, detail="Invalid response received from Slack API.")
|
||||||
|
except Exception as e: # Catch other unexpected errors
|
||||||
|
logger.exception(f"Unexpected error during Slack request for {method} {endpoint}: {e}") # Use logger.exception to include traceback
|
||||||
|
raise HTTPException(status_code=500, detail=f"An internal server error occurred: {type(e).__name__}")
|
||||||
|
|
||||||
async def get_channel_history(self, args: GetChannelHistoryArgs) -> Dict[str, Any]:
|
async def get_channel_history(self, args: GetChannelHistoryArgs) -> Dict[str, Any]:
|
||||||
params = {"channel": args.channel_id, "limit": args.limit}
|
params = {"channel": args.channel_id, "limit": args.limit}
|
||||||
@ -134,15 +179,15 @@ class SlackClient:
|
|||||||
"conversations.history",
|
"conversations.history",
|
||||||
params={
|
params={
|
||||||
"channel": channel_id,
|
"channel": channel_id,
|
||||||
"limit": 10 # Get last 10 messages by default
|
"limit": 1 # Fetch minimal history by default to speed up get_channels. Consider asyncio.gather for concurrency.
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
# Add history to channel data
|
# Add history to channel data
|
||||||
if history.get("ok"):
|
if history.get("ok"):
|
||||||
channel_data["history"] = history.get("messages", [])
|
channel_data["history"] = history.get("messages", [])
|
||||||
except Exception as e:
|
except Exception as e: # Catch errors during history fetch but don't fail the whole channel list
|
||||||
print(f"Error fetching history for channel {channel_id}: {e}")
|
logger.warning(f"Error fetching history for channel {channel_id}: {e}", exc_info=True)
|
||||||
channel_data["history"] = []
|
channel_data["history"] = [] # Ensure history key exists even if fetch fails
|
||||||
|
|
||||||
return channel_data
|
return channel_data
|
||||||
|
|
||||||
@ -152,8 +197,8 @@ class SlackClient:
|
|||||||
try:
|
try:
|
||||||
if channel_data := await fetch_channel_with_history(channel_id):
|
if channel_data := await fetch_channel_with_history(channel_id):
|
||||||
channels_info.append(channel_data)
|
channels_info.append(channel_data)
|
||||||
except Exception as e:
|
except Exception as e: # Catch errors fetching predefined channels
|
||||||
print(f"Could not fetch info for predefined channel {channel_id}: {e}")
|
logger.warning(f"Could not fetch info for predefined channel {channel_id}: {e}", exc_info=True)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"ok": True,
|
"ok": True,
|
||||||
@ -182,9 +227,10 @@ class SlackClient:
|
|||||||
try:
|
try:
|
||||||
if channel_data := await fetch_channel_with_history(channel["id"]):
|
if channel_data := await fetch_channel_with_history(channel["id"]):
|
||||||
channels_with_history.append(channel_data)
|
channels_with_history.append(channel_data)
|
||||||
except Exception as e:
|
except Exception as e: # Catch errors during history fetch but don't fail the whole channel list
|
||||||
print(f"Error fetching history for channel {channel['id']}: {e}")
|
logger.warning(f"Error fetching history for channel {channel['id']}: {e}", exc_info=True)
|
||||||
channels_with_history.append(channel) # Fall back to channel info without history
|
channel["history"] = [] # Add empty history on error
|
||||||
|
channels_with_history.append(channel)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"ok": True,
|
"ok": True,
|
||||||
@ -268,17 +314,20 @@ TOOL_MAPPING = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Define a function factory to create endpoint handlers
|
# Define a function factory to create endpoint handlers, including API key dependency
|
||||||
def create_endpoint_handler(tool_name: str, method: Callable, args_model: Type[BaseModel]):
|
def create_endpoint_handler(tool_name: str, method: Callable, args_model: Type[BaseModel]):
|
||||||
async def endpoint_handler(args: args_model = Body(...)) -> ToolResponse:
|
async def endpoint_handler(
|
||||||
|
args: args_model = Body(...),
|
||||||
|
api_key: str = Depends(get_api_key) # Add API key dependency here
|
||||||
|
) -> ToolResponse:
|
||||||
try:
|
try:
|
||||||
result = await method(args=args)
|
result = await method(args=args)
|
||||||
return {"content": result}
|
return {"content": result}
|
||||||
except HTTPException as e:
|
except HTTPException as e:
|
||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error executing tool {tool_name}: {e}")
|
logger.exception(f"Error executing tool {tool_name}: {e}") # Use logger.exception here too
|
||||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Internal server error: {type(e).__name__}")
|
||||||
return endpoint_handler
|
return endpoint_handler
|
||||||
|
|
||||||
# Register endpoints for each tool
|
# Register endpoints for each tool
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
fastapi
|
fastapi>=0.110.0,<0.111.0
|
||||||
uvicorn[standard]
|
uvicorn[standard]>=0.29.0,<0.30.0
|
||||||
pydantic
|
pydantic>=2.6.0,<3.0.0
|
||||||
python-multipart
|
httpx>=0.27.0,<0.28.0
|
||||||
httpx
|
python-dotenv>=1.0.0,<2.0.0
|
||||||
python-dotenv
|
# NOTE: Run 'pip freeze > requirements.txt' in a virtual environment
|
||||||
|
# to capture the exact versions of all transitive dependencies
|
||||||
|
# for truly reproducible builds.
|
||||||
|
Loading…
Reference in New Issue
Block a user