diff --git a/backend/open_webui/utils/code_interpreter.py b/backend/open_webui/utils/code_interpreter.py index 34daa71c9..0a74da9c7 100644 --- a/backend/open_webui/utils/code_interpreter.py +++ b/backend/open_webui/utils/code_interpreter.py @@ -18,6 +18,7 @@ async def execute_code_jupyter( :param password: Jupyter password (optional) :param timeout: WebSocket timeout in seconds (default: 10s) :return: Dictionary with stdout, stderr, and result + - Images are prefixed with "base64:image/png," and separated by newlines if multiple. """ session = requests.Session() # Maintain cookies headers = {} # Headers for requests @@ -28,20 +29,15 @@ async def execute_code_jupyter( login_url = urljoin(jupyter_url, "/login") response = session.get(login_url) response.raise_for_status() - - # Retrieve `_xsrf` token xsrf_token = session.cookies.get("_xsrf") if not xsrf_token: raise ValueError("Failed to fetch _xsrf token") - # Send login request login_data = {"_xsrf": xsrf_token, "password": password} login_response = session.post( login_url, data=login_data, cookies=session.cookies ) login_response.raise_for_status() - - # Update headers with `_xsrf` headers["X-XSRFToken"] = xsrf_token except Exception as e: return { @@ -55,18 +51,15 @@ async def execute_code_jupyter( kernel_url = urljoin(jupyter_url, f"/api/kernels{params}") try: - # Include cookies if authenticating with password response = session.post(kernel_url, headers=headers, cookies=session.cookies) response.raise_for_status() kernel_id = response.json()["id"] - # Construct WebSocket URL websocket_url = urljoin( jupyter_url.replace("http", "ws"), f"/api/kernels/{kernel_id}/channels{params}", ) - # **IMPORTANT:** Include authentication cookies for WebSockets ws_headers = {} if password and not token: ws_headers["X-XSRFToken"] = session.cookies.get("_xsrf") @@ -75,13 +68,10 @@ async def execute_code_jupyter( [f"{name}={value}" for name, value in cookies.items()] ) - # Connect to the WebSocket async with websockets.connect( websocket_url, additional_headers=ws_headers ) as ws: msg_id = str(uuid.uuid4()) - - # Send execution request execute_request = { "header": { "msg_id": msg_id, @@ -105,37 +95,47 @@ async def execute_code_jupyter( } await ws.send(json.dumps(execute_request)) - # Collect execution results - stdout, stderr, result = "", "", None + stdout, stderr, result = "", "", [] + while True: try: message = await asyncio.wait_for(ws.recv(), timeout) message_data = json.loads(message) if message_data.get("parent_header", {}).get("msg_id") == msg_id: msg_type = message_data.get("msg_type") + if msg_type == "stream": if message_data["content"]["name"] == "stdout": stdout += message_data["content"]["text"] elif message_data["content"]["name"] == "stderr": stderr += message_data["content"]["text"] + elif msg_type in ("execute_result", "display_data"): - result = message_data["content"]["data"].get( - "text/plain", "" - ) + data = message_data["content"]["data"] + if "image/png" in data: + result.append( + f"data:image/png;base64,{data['image/png']}" + ) + elif "text/plain" in data: + result.append(data["text/plain"]) + elif msg_type == "error": stderr += "\n".join(message_data["content"]["traceback"]) + elif ( msg_type == "status" and message_data["content"]["execution_state"] == "idle" ): break + except asyncio.TimeoutError: stderr += "\nExecution timed out." break + except Exception as e: return {"stdout": "", "stderr": f"Error: {str(e)}", "result": ""} + finally: - # Shutdown the kernel if kernel_id: requests.delete( f"{kernel_url}/{kernel_id}", headers=headers, cookies=session.cookies @@ -144,10 +144,5 @@ async def execute_code_jupyter( return { "stdout": stdout.strip(), "stderr": stderr.strip(), - "result": result.strip() if result else "", + "result": "\n".join(result).strip() if result else "", } - - -# Example Usage -# asyncio.run(execute_code_jupyter("http://localhost:8888", "print('Hello, world!')", token="your-token")) -# asyncio.run(execute_code_jupyter("http://localhost:8888", "print('Hello, world!')", password="your-password")) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index aba2a5e4f..f855e6434 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -1723,6 +1723,38 @@ async def process_chat_response( ) output["stdout"] = "\n".join(stdoutLines) + + result = output.get("result", "") + + if result: + resultLines = result.split("\n") + for idx, line in enumerate(resultLines): + if "data:image/png;base64" in line: + id = str(uuid4()) + + # ensure the path exists + os.makedirs( + os.path.join(CACHE_DIR, "images"), + exist_ok=True, + ) + + image_path = os.path.join( + CACHE_DIR, + f"images/{id}.png", + ) + + with open(image_path, "wb") as f: + f.write( + base64.b64decode( + line.split(",")[1] + ) + ) + + resultLines[idx] = ( + f"![Output Image {idx}](/cache/images/{id}.png)" + ) + + output["result"] = "\n".join(resultLines) except Exception as e: output = str(e)