From f9c5819314e489eb2e1a87085ece71f14fe54942 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Mon, 10 Feb 2025 02:25:02 -0800 Subject: [PATCH] enh: code interpreter jupyter support --- backend/open_webui/config.py | 42 +++++ backend/open_webui/main.py | 25 +++ backend/open_webui/routers/configs.py | 55 +++++++ backend/open_webui/utils/code_interpreter.py | 153 ++++++++++++++++++ backend/open_webui/utils/middleware.py | 48 ++++-- src/lib/apis/configs/index.ts | 57 +++++++ src/lib/components/admin/Settings.svelte | 36 +++++ .../admin/Settings/CodeInterpreter.svelte | 145 +++++++++++++++++ 8 files changed, 552 insertions(+), 9 deletions(-) create mode 100644 backend/open_webui/utils/code_interpreter.py create mode 100644 src/lib/components/admin/Settings/CodeInterpreter.svelte diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index fc3629786..6993b2801 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1326,6 +1326,48 @@ Your task is to synthesize these responses into a single, high-quality response. Responses from models: {{responses}}""" +#################################### +# Code Interpreter +#################################### + +ENABLE_CODE_INTERPRETER = PersistentConfig( + "ENABLE_CODE_INTERPRETER", + "code_interpreter.enable", + os.environ.get("ENABLE_CODE_INTERPRETER", "True").lower() == "true", +) + +CODE_INTERPRETER_ENGINE = PersistentConfig( + "CODE_INTERPRETER_ENGINE", + "code_interpreter.engine", + os.environ.get("CODE_INTERPRETER_ENGINE", "pyodide"), +) + +CODE_INTERPRETER_JUPYTER_URL = PersistentConfig( + "CODE_INTERPRETER_JUPYTER_URL", + "code_interpreter.jupyter.url", + os.environ.get("CODE_INTERPRETER_JUPYTER_URL", ""), +) + +CODE_INTERPRETER_JUPYTER_AUTH = PersistentConfig( + "CODE_INTERPRETER_JUPYTER_AUTH", + "code_interpreter.jupyter.auth", + os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH", ""), +) + +CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = PersistentConfig( + "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN", + "code_interpreter.jupyter.auth_token", + os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH_TOKEN", ""), +) + + +CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = PersistentConfig( + "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD", + "code_interpreter.jupyter.auth_password", + os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD", ""), +) + + DEFAULT_CODE_INTERPRETER_PROMPT = """ #### Tools Available diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 119551f37..d9f1408ba 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -97,6 +97,13 @@ from open_webui.config import ( OPENAI_API_BASE_URLS, OPENAI_API_KEYS, OPENAI_API_CONFIGS, + # Code Interpreter + ENABLE_CODE_INTERPRETER, + CODE_INTERPRETER_ENGINE, + CODE_INTERPRETER_JUPYTER_URL, + CODE_INTERPRETER_JUPYTER_AUTH, + CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, + CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, # Image AUTOMATIC1111_API_AUTH, AUTOMATIC1111_BASE_URL, @@ -570,6 +577,23 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) +######################################## +# +# CODE INTERPRETER +# +######################################## + +app.state.config.ENABLE_CODE_INTERPRETER = ENABLE_CODE_INTERPRETER +app.state.config.CODE_INTERPRETER_ENGINE = CODE_INTERPRETER_ENGINE + +app.state.config.CODE_INTERPRETER_JUPYTER_URL = CODE_INTERPRETER_JUPYTER_URL +app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = CODE_INTERPRETER_JUPYTER_AUTH +app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = ( + CODE_INTERPRETER_JUPYTER_AUTH_TOKEN +) +app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = ( + CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD +) ######################################## # @@ -755,6 +779,7 @@ app.include_router(openai.router, prefix="/openai", tags=["openai"]) app.include_router(pipelines.router, prefix="/api/v1/pipelines", tags=["pipelines"]) app.include_router(tasks.router, prefix="/api/v1/tasks", tags=["tasks"]) app.include_router(images.router, prefix="/api/v1/images", tags=["images"]) + app.include_router(audio.router, prefix="/api/v1/audio", tags=["audio"]) app.include_router(retrieval.router, prefix="/api/v1/retrieval", tags=["retrieval"]) diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index ef6c4d8c1..4fe7dbf4d 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -36,6 +36,61 @@ async def export_config(user=Depends(get_admin_user)): return get_config() +############################ +# CodeInterpreterConfig +############################ +class CodeInterpreterConfigForm(BaseModel): + ENABLE_CODE_INTERPRETER: bool + CODE_INTERPRETER_ENGINE: str + CODE_INTERPRETER_JUPYTER_URL: Optional[str] + CODE_INTERPRETER_JUPYTER_AUTH: Optional[str] + CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str] + CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str] + + +@router.get("/code_interpreter", response_model=CodeInterpreterConfigForm) +async def get_code_interpreter_config(request: Request, user=Depends(get_admin_user)): + return { + "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER, + "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE, + "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL, + "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH, + "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, + "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, + } + + +@router.post("/code_interpreter", response_model=CodeInterpreterConfigForm) +async def set_code_interpreter_config( + request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user) +): + request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER + request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE + request.app.state.config.CODE_INTERPRETER_JUPYTER_URL = ( + form_data.CODE_INTERPRETER_JUPYTER_URL + ) + + request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = ( + form_data.CODE_INTERPRETER_JUPYTER_AUTH + ) + + request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = ( + form_data.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN + ) + request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = ( + form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD + ) + + return { + "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER, + "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE, + "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL, + "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH, + "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, + "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, + } + + ############################ # SetDefaultModels ############################ diff --git a/backend/open_webui/utils/code_interpreter.py b/backend/open_webui/utils/code_interpreter.py new file mode 100644 index 000000000..34daa71c9 --- /dev/null +++ b/backend/open_webui/utils/code_interpreter.py @@ -0,0 +1,153 @@ +import asyncio +import json +import uuid +import websockets +import requests +from urllib.parse import urljoin + + +async def execute_code_jupyter( + jupyter_url, code, token=None, password=None, timeout=10 +): + """ + Executes Python code in a Jupyter kernel. + Supports authentication with a token or password. + :param jupyter_url: Jupyter server URL (e.g., "http://localhost:8888") + :param code: Code to execute + :param token: Jupyter authentication token (optional) + :param password: Jupyter password (optional) + :param timeout: WebSocket timeout in seconds (default: 10s) + :return: Dictionary with stdout, stderr, and result + """ + session = requests.Session() # Maintain cookies + headers = {} # Headers for requests + + # Authenticate using password + if password and not token: + try: + login_url = urljoin(jupyter_url, "/login") + response = session.get(login_url) + response.raise_for_status() + + # Retrieve `_xsrf` token + xsrf_token = session.cookies.get("_xsrf") + if not xsrf_token: + raise ValueError("Failed to fetch _xsrf token") + + # Send login request + login_data = {"_xsrf": xsrf_token, "password": password} + login_response = session.post( + login_url, data=login_data, cookies=session.cookies + ) + login_response.raise_for_status() + + # Update headers with `_xsrf` + headers["X-XSRFToken"] = xsrf_token + except Exception as e: + return { + "stdout": "", + "stderr": f"Authentication Error: {str(e)}", + "result": "", + } + + # Construct API URLs with authentication token if provided + params = f"?token={token}" if token else "" + kernel_url = urljoin(jupyter_url, f"/api/kernels{params}") + + try: + # Include cookies if authenticating with password + response = session.post(kernel_url, headers=headers, cookies=session.cookies) + response.raise_for_status() + kernel_id = response.json()["id"] + + # Construct WebSocket URL + websocket_url = urljoin( + jupyter_url.replace("http", "ws"), + f"/api/kernels/{kernel_id}/channels{params}", + ) + + # **IMPORTANT:** Include authentication cookies for WebSockets + ws_headers = {} + if password and not token: + ws_headers["X-XSRFToken"] = session.cookies.get("_xsrf") + cookies = {name: value for name, value in session.cookies.items()} + ws_headers["Cookie"] = "; ".join( + [f"{name}={value}" for name, value in cookies.items()] + ) + + # Connect to the WebSocket + async with websockets.connect( + websocket_url, additional_headers=ws_headers + ) as ws: + msg_id = str(uuid.uuid4()) + + # Send execution request + execute_request = { + "header": { + "msg_id": msg_id, + "msg_type": "execute_request", + "username": "user", + "session": str(uuid.uuid4()), + "date": "", + "version": "5.3", + }, + "parent_header": {}, + "metadata": {}, + "content": { + "code": code, + "silent": False, + "store_history": True, + "user_expressions": {}, + "allow_stdin": False, + "stop_on_error": True, + }, + "channel": "shell", + } + await ws.send(json.dumps(execute_request)) + + # Collect execution results + stdout, stderr, result = "", "", None + while True: + try: + message = await asyncio.wait_for(ws.recv(), timeout) + message_data = json.loads(message) + if message_data.get("parent_header", {}).get("msg_id") == msg_id: + msg_type = message_data.get("msg_type") + if msg_type == "stream": + if message_data["content"]["name"] == "stdout": + stdout += message_data["content"]["text"] + elif message_data["content"]["name"] == "stderr": + stderr += message_data["content"]["text"] + elif msg_type in ("execute_result", "display_data"): + result = message_data["content"]["data"].get( + "text/plain", "" + ) + elif msg_type == "error": + stderr += "\n".join(message_data["content"]["traceback"]) + elif ( + msg_type == "status" + and message_data["content"]["execution_state"] == "idle" + ): + break + except asyncio.TimeoutError: + stderr += "\nExecution timed out." + break + except Exception as e: + return {"stdout": "", "stderr": f"Error: {str(e)}", "result": ""} + finally: + # Shutdown the kernel + if kernel_id: + requests.delete( + f"{kernel_url}/{kernel_id}", headers=headers, cookies=session.cookies + ) + + return { + "stdout": stdout.strip(), + "stderr": stderr.strip(), + "result": result.strip() if result else "", + } + + +# Example Usage +# asyncio.run(execute_code_jupyter("http://localhost:8888", "print('Hello, world!')", token="your-token")) +# asyncio.run(execute_code_jupyter("http://localhost:8888", "print('Hello, world!')", password="your-password")) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 29bfb2ba1..aba2a5e4f 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -72,7 +72,7 @@ from open_webui.utils.filter import ( get_sorted_filter_ids, process_filter_functions, ) - +from open_webui.utils.code_interpreter import execute_code_jupyter from open_webui.tasks import create_task @@ -1651,15 +1651,45 @@ async def process_chat_response( output = "" try: if content_blocks[-1]["attributes"].get("type") == "code": - output = await event_caller( - { - "type": "execute:python", - "data": { - "id": str(uuid4()), - "code": content_blocks[-1]["content"], - }, + code = content_blocks[-1]["content"] + + if ( + request.app.state.config.CODE_INTERPRETER_ENGINE + == "pyodide" + ): + output = await event_caller( + { + "type": "execute:python", + "data": { + "id": str(uuid4()), + "code": code, + }, + } + ) + elif ( + request.app.state.config.CODE_INTERPRETER_ENGINE + == "jupyter" + ): + output = await execute_code_jupyter( + request.app.state.config.CODE_INTERPRETER_JUPYTER_URL, + code, + ( + request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN + if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH + == "token" + else None + ), + ( + request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD + if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH + == "password" + else None + ), + ) + else: + output = { + "stdout": "Code interpreter engine not configured." } - ) if isinstance(output, dict): stdout = output.get("stdout", "") diff --git a/src/lib/apis/configs/index.ts b/src/lib/apis/configs/index.ts index e9faf346b..999842b26 100644 --- a/src/lib/apis/configs/index.ts +++ b/src/lib/apis/configs/index.ts @@ -58,6 +58,63 @@ export const exportConfig = async (token: string) => { return res; }; +export const getCodeInterpreterConfig = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/code_interpreter`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const setCodeInterpreterConfig = async (token: string, config: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/code_interpreter`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...config + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getModelsConfig = async (token: string) => { let error = null; diff --git a/src/lib/components/admin/Settings.svelte b/src/lib/components/admin/Settings.svelte index f0886ea5c..415e4377a 100644 --- a/src/lib/components/admin/Settings.svelte +++ b/src/lib/components/admin/Settings.svelte @@ -19,6 +19,7 @@ import ChartBar from '../icons/ChartBar.svelte'; import DocumentChartBar from '../icons/DocumentChartBar.svelte'; import Evaluations from './Settings/Evaluations.svelte'; + import CodeInterpreter from './Settings/CodeInterpreter.svelte'; const i18n = getContext('i18n'); @@ -188,6 +189,32 @@
{$i18n.t('Web Search')}
+ + + +