feat: oauth2.1 mcp integration

This commit is contained in:
Timothy Jaeryang Baek
2025-09-25 01:49:16 -05:00
parent 972be4eda5
commit 77e971dd9f
10 changed files with 248 additions and 53 deletions

View File

@@ -473,7 +473,12 @@ from open_webui.utils.auth import (
get_verified_user,
)
from open_webui.utils.plugin import install_tool_and_function_dependencies
from open_webui.utils.oauth import OAuthManager
from open_webui.utils.oauth import (
OAuthManager,
OAuthClientManager,
decrypt_data,
OAuthClientInformationFull,
)
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.utils.redis import get_redis_connection
@@ -603,9 +608,14 @@ app = FastAPI(
lifespan=lifespan,
)
# For Open WebUI OIDC/OAuth2
oauth_manager = OAuthManager(app)
app.state.oauth_manager = oauth_manager
# For Integrations
oauth_client_manager = OAuthClientManager(app)
app.state.oauth_client_manager = oauth_client_manager
app.state.instance_id = None
app.state.config = AppConfig(
redis_url=REDIS_URL,
@@ -1881,6 +1891,24 @@ async def get_current_usage(user=Depends(get_verified_user)):
# OAuth Login & Callback
############################
# Initialize OAuth client manager with any MCP tool servers using OAuth 2.1
if len(app.state.config.TOOL_SERVER_CONNECTIONS) > 0:
for tool_server_connection in app.state.config.TOOL_SERVER_CONNECTIONS:
if tool_server_connection.get("type", "openapi") == "mcp":
server_id = tool_server_connection.get("info", {}).get("id")
auth_type = tool_server_connection.get("auth_type", "none")
if server_id and auth_type == "oauth_2.1":
oauth_client_info = tool_server_connection.get("info", {}).get(
"oauth_client_info"
)
oauth_client_info = decrypt_data(oauth_client_info)
app.state.oauth_client_manager.add_client(
f"mcp:{server_id}", OAuthClientInformationFull(**oauth_client_info)
)
# SessionMiddleware is used by authlib for oauth
if len(OAUTH_PROVIDERS) > 0:
try:
@@ -1913,6 +1941,31 @@ if len(OAUTH_PROVIDERS) > 0:
)
@app.get("/oauth/clients/{client_id}/authorize")
async def oauth_client_authorize(
client_id: str,
request: Request,
response: Response,
user=Depends(get_verified_user),
):
return await oauth_client_manager.handle_authorize(request, client_id=client_id)
@app.get("/oauth/clients/{client_id}/callback")
async def oauth_client_callback(
client_id: str,
request: Request,
response: Response,
user=Depends(get_verified_user),
):
return await oauth_client_manager.handle_callback(
request,
client_id=client_id,
user_id=user.id if user else None,
response=response,
)
@app.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
return await oauth_manager.handle_login(request, provider)
@@ -1924,8 +1977,9 @@ async def oauth_login(provider: str, request: Request):
# - This is considered insecure in general, as OAuth providers do not always verify email addresses
# 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user
# - Email addresses are considered unique, so we fail registration if the email address is already taken
@app.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request, response: Response):
@app.get("/oauth/{provider}/callback") # Legacy endpoint
@app.get("/oauth/{provider}/login/callback")
async def oauth_login_callback(provider: str, request: Request, response: Response):
return await oauth_manager.handle_callback(request, provider, response)

View File

@@ -176,6 +176,26 @@ class OAuthSessionTable:
log.error(f"Error getting OAuth session by ID: {e}")
return None
def get_session_by_provider_and_user_id(
self, provider: str, user_id: str
) -> Optional[OAuthSessionModel]:
"""Get OAuth session by provider and user ID"""
try:
with get_db() as db:
session = (
db.query(OAuthSession)
.filter_by(provider=provider, user_id=user_id)
.first()
)
if session:
session.token = self._decrypt_token(session.token)
return OAuthSessionModel.model_validate(session)
return None
except Exception as e:
log.error(f"Error getting OAuth session by provider and user ID: {e}")
return None
def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]:
"""Get all OAuth sessions for a user"""
try:

View File

@@ -21,7 +21,9 @@ from open_webui.env import SRC_LOG_LEVELS
from open_webui.utils.oauth import (
get_discovery_urls,
get_oauth_client_info_with_dynamic_client_registration,
encrypt_token,
encrypt_data,
decrypt_data,
OAuthClientInformationFull,
)
from mcp.shared.auth import OAuthMetadata
@@ -103,17 +105,22 @@ class OAuthClientRegistrationForm(BaseModel):
async def register_oauth_client(
request: Request,
form_data: OAuthClientRegistrationForm,
type: Optional[str] = None,
user=Depends(get_admin_user),
):
try:
oauth_client_id = form_data.client_id
if type:
oauth_client_id = f"{type}:{form_data.client_id}"
oauth_client_info = (
await get_oauth_client_info_with_dynamic_client_registration(
request, form_data.url
request, oauth_client_id, form_data.url
)
)
return {
"status": True,
"oauth_client_info": encrypt_token(
"oauth_client_info": encrypt_data(
oauth_client_info.model_dump(mode="json")
),
}
@@ -161,8 +168,25 @@ async def set_tool_servers_config(
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
]
await set_tool_servers(request)
for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
server_type = connection.get("type", "openapi")
if server_type == "mcp":
server_id = connection.get("info", {}).get("id")
auth_type = connection.get("auth_type", "none")
if auth_type == "oauth_2.1" and server_id:
try:
oauth_client_info = decrypt_data(oauth_client_info)
await request.app.state.oauth_client_manager.add_client(
f"{server_type}:{server_id}",
OAuthClientInformationFull(**oauth_client_info),
)
except Exception as e:
log.debug(f"Failed to add OAuth client for MCP tool server: {e}")
continue
return {
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
}

View File

@@ -9,6 +9,7 @@ from pydantic import BaseModel, HttpUrl
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.models.oauth_sessions import OAuthSessions
from open_webui.models.tools import (
ToolForm,
ToolModel,
@@ -80,6 +81,24 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
# MCP Tool Servers
for server in request.app.state.config.TOOL_SERVER_CONNECTIONS:
if server.get("type", "openapi") == "mcp":
server_id = server.get("info", {}).get("id")
auth_type = server.get("auth_type", "none")
session_token = None
if auth_type == "oauth_2.1":
splits = server_id.split(":")
server_id = splits[-1] if len(splits) > 1 else server_id
session_token = (
await request.app.state.oauth_client_manager.get_oauth_token(
user.id, f"mcp:{server_id}"
)
)
print("User ID:", user.id)
print("Server ID:", server_id)
print("MCP Session Token:", session_token)
tools.append(
ToolUserResponse(
**{
@@ -96,6 +115,13 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
),
"updated_at": int(time.time()),
"created_at": int(time.time()),
**(
{
"authenticated": session_token is not None,
}
if auth_type == "oauth_2.1"
else {}
),
}
)
)

View File

@@ -24,6 +24,7 @@ from fastapi.responses import HTMLResponse
from starlette.responses import Response, StreamingResponse, JSONResponse
from open_webui.models.oauth_sessions import OAuthSessions
from open_webui.models.chats import Chats
from open_webui.models.folders import Folders
from open_webui.models.users import Users
@@ -1047,6 +1048,22 @@ async def process_chat_payload(request, form_data, user, metadata, model):
headers["Authorization"] = (
f"Bearer {oauth_token.get('access_token', '')}"
)
elif auth_type == "oauth_2.1":
try:
splits = server_id.split(":")
server_id = splits[-1] if len(splits) > 1 else server_id
oauth_token = await request.app.state.oauth_client_manager.get_oauth_token(
user.id, f"mcp:{server_id}"
)
if oauth_token:
headers["Authorization"] = (
f"Bearer {oauth_token.get('access_token', '')}"
)
except Exception as e:
log.error(f"Error getting OAuth token: {e}")
oauth_token = None
mcp_client = MCPClient()
await mcp_client.connect(

View File

@@ -126,24 +126,24 @@ except Exception as e:
raise
def encrypt_token(token) -> str:
"""Encrypt OAuth tokens for storage"""
def encrypt_data(data) -> str:
"""Encrypt data for storage"""
try:
token_json = json.dumps(token)
encrypted = FERNET.encrypt(token_json.encode()).decode()
data_json = json.dumps(data)
encrypted = FERNET.encrypt(data_json.encode()).decode()
return encrypted
except Exception as e:
log.error(f"Error encrypting tokens: {e}")
log.error(f"Error encrypting data: {e}")
raise
def decrypt_token(token: str):
"""Decrypt OAuth tokens from storage"""
def decrypt_data(data: str):
"""Decrypt data from storage"""
try:
decrypted = FERNET.decrypt(token.encode()).decode()
decrypted = FERNET.decrypt(data.encode()).decode()
return json.loads(decrypted)
except Exception as e:
log.error(f"Error decrypting tokens: {e}")
log.error(f"Error decrypting data: {e}")
raise
@@ -212,7 +212,10 @@ def get_discovery_urls(server_url) -> list[str]:
# TODO: Some OAuth providers require Initial Access Tokens (IATs) for dynamic client registration.
# This is not currently supported.
async def get_oauth_client_info_with_dynamic_client_registration(
request, oauth_server_url, oauth_server_key: Optional[str] = None
request,
client_id: str,
oauth_server_url: str,
oauth_server_key: Optional[str] = None,
) -> OAuthClientInformationFull:
try:
oauth_server_metadata = None
@@ -221,9 +224,10 @@ async def get_oauth_client_info_with_dynamic_client_registration(
redirect_base_url = (
str(request.app.state.config.WEBUI_URL or request.base_url)
).rstrip("/")
oauth_client_metadata = OAuthClientMetadata(
client_name="Open WebUI",
redirect_uris=[f"{redirect_base_url}/oauth/callback"],
redirect_uris=[f"{redirect_base_url}/oauth/clients/{client_id}/callback"],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
token_endpoint_auth_method="client_secret_post",
@@ -315,23 +319,22 @@ class OAuthClientManager:
self.clients = {}
def add_client(self, client_id, oauth_client_info: OAuthClientInformationFull):
if client_id not in self.clients:
self.clients[client_id] = {
"client": self.oauth.register(
name=client_id,
client_id=oauth_client_info.client_id,
client_secret=oauth_client_info.client_secret,
client_kwargs=(
{"scope": oauth_client_info.scope}
if oauth_client_info.scope
else {}
),
server_metadata_url=(
oauth_client_info.issuer if oauth_client_info.issuer else None
),
self.clients[client_id] = {
"client": self.oauth.register(
name=client_id,
client_id=oauth_client_info.client_id,
client_secret=oauth_client_info.client_secret,
client_kwargs=(
{"scope": oauth_client_info.scope}
if oauth_client_info.scope
else {}
),
"client_info": oauth_client_info,
}
server_metadata_url=(
oauth_client_info.issuer if oauth_client_info.issuer else None
),
),
"client_info": oauth_client_info,
}
return self.clients[client_id]
def remove_client(self, client_id):
@@ -359,7 +362,7 @@ class OAuthClientManager:
return None
async def get_oauth_token(
self, user_id: str, session_id: str, force_refresh: bool = False
self, user_id: str, client_id: str, force_refresh: bool = False
):
"""
Get a valid OAuth token for the user, automatically refreshing if needed.
@@ -374,10 +377,12 @@ class OAuthClientManager:
"""
try:
# Get the OAuth session
session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id)
session = OAuthSessions.get_session_by_provider_and_user_id(
client_id, user_id
)
if not session:
log.warning(
f"No OAuth session found for user {user_id}, session {session_id}"
f"No OAuth session found for user {user_id}, client_id {client_id}"
)
return None
@@ -392,8 +397,9 @@ class OAuthClientManager:
return refreshed_token
else:
log.warning(
f"Token refresh failed for user {user_id}, client_id {session.provider}"
f"Token refresh failed for user {user_id}, client_id {session.provider}, deleting session {session.id}"
)
OAuthSessions.delete_session_by_id(session.id)
return None
return session.token
@@ -533,7 +539,7 @@ class OAuthClientManager:
redirect_uri = (
client_info.redirect_uris[0] if client_info.redirect_uris else None
)
return await client.authorize_redirect(request, redirect_uri)
return await client.authorize_redirect(request, str(redirect_uri))
async def handle_callback(self, request, client_id: str, user_id: str, response):
client = self.get_client(client_id)
@@ -565,7 +571,6 @@ class OAuthClientManager:
provider=client_id,
token=token,
)
log.info(
f"Stored OAuth session server-side for user {user_id}, client_id {client_id}"
)
@@ -579,16 +584,17 @@ class OAuthClientManager:
error_message = "OAuth callback error"
log.warning(f"OAuth callback error: {e}")
redirect_base_url = (
redirect_url = (
str(request.app.state.config.WEBUI_URL or request.base_url)
).rstrip("/")
redirect_url = f"{redirect_base_url}/auth"
if error_message:
redirect_url = f"{redirect_url}?error={error_message}"
log.debug(error_message)
redirect_url = f"{redirect_url}/?error={error_message}"
return RedirectResponse(url=redirect_url, headers=response.headers)
response = RedirectResponse(url=redirect_url, headers=response.headers)
return response
class OAuthManager:
@@ -649,8 +655,10 @@ class OAuthManager:
return refreshed_token
else:
log.warning(
f"Token refresh failed for user {user_id}, provider {session.provider}"
f"Token refresh failed for user {user_id}, provider {session.provider}, deleting session {session.id}"
)
OAuthSessions.delete_session_by_id(session.id)
return None
return session.token