mirror of
				https://github.com/open-webui/open-webui
				synced 2025-06-26 18:26:48 +00:00 
			
		
		
		
	Merge pull request #11120 from OrenZhang/fix_jupyter
fix(jupyter): fix kernel_id not set and optimize code
This commit is contained in:
		
						commit
						bb2bd7d721
					
				| @ -1,148 +1,189 @@ | ||||
| import asyncio | ||||
| import json | ||||
| import logging | ||||
| import uuid | ||||
| from typing import Optional | ||||
| 
 | ||||
| import aiohttp | ||||
| import websockets | ||||
| import requests | ||||
| from urllib.parse import urljoin | ||||
| from pydantic import BaseModel | ||||
| from websockets import ClientConnection | ||||
| 
 | ||||
| from open_webui.env import SRC_LOG_LEVELS | ||||
| 
 | ||||
# Module-level logger; its level is taken from the "MAIN" entry of SRC_LOG_LEVELS.
logger = logging.getLogger(__name__)
logger.setLevel(SRC_LOG_LEVELS["MAIN"])
| 
 | ||||
| 
 | ||||
class ResultModel(BaseModel):
    """
    Execute Code Result Model

    Container for the outcome of one code execution. Every field defaults
    to an empty string.
    """

    # Text collected from "stream" messages whose name is "stdout".
    stdout: Optional[str] = ""
    # Text collected from "stream" messages named "stderr", error tracebacks,
    # and the failure / timeout messages produced by the executer.
    stderr: Optional[str] = ""
    # Newline-joined "execute_result" / "display_data" payloads: either the
    # text/plain value or a "data:image/png;base64,..." URI.
    result: Optional[str] = ""
| 
 | ||||
| 
 | ||||
| class JupyterCodeExecuter: | ||||
|     """ | ||||
|     Execute code in jupyter notebook | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60): | ||||
|         """ | ||||
|         :param base_url: Jupyter server URL (e.g., "http://localhost:8888") | ||||
|         :param code: Code to execute | ||||
|         :param token: Jupyter authentication token (optional) | ||||
|         :param password: Jupyter password (optional) | ||||
|         :param timeout: WebSocket timeout in seconds (default: 60s) | ||||
|         """ | ||||
|         self.base_url = base_url.rstrip("/") | ||||
|         self.code = code | ||||
|         self.token = token | ||||
|         self.password = password | ||||
|         self.timeout = timeout | ||||
|         self.kernel_id = "" | ||||
|         self.session = aiohttp.ClientSession(base_url=self.base_url) | ||||
|         self.params = {} | ||||
|         self.result = ResultModel() | ||||
| 
 | ||||
|     async def __aenter__(self): | ||||
|         return self | ||||
| 
 | ||||
|     async def __aexit__(self, exc_type, exc_val, exc_tb): | ||||
|         if self.kernel_id: | ||||
|             try: | ||||
|                 async with self.session.delete(f"/api/kernels/{self.kernel_id}", params=self.params) as response: | ||||
|                     response.raise_for_status() | ||||
|             except Exception as err: | ||||
|                 logger.exception("close kernel failed, %s", err) | ||||
|         await self.session.close() | ||||
| 
 | ||||
|     async def run(self) -> ResultModel: | ||||
|         try: | ||||
|             await self.sign_in() | ||||
|             await self.init_kernel() | ||||
|             await self.execute_code() | ||||
|         except Exception as err: | ||||
|             logger.exception("execute code failed, %s", err) | ||||
|             self.result.stderr = f"Error: {err}" | ||||
|         return self.result | ||||
| 
 | ||||
|     async def sign_in(self) -> None: | ||||
|         # password authentication | ||||
|         if self.password and not self.token: | ||||
|             async with self.session.get("/login") as response: | ||||
|                 response.raise_for_status() | ||||
|                 xsrf_token = response.cookies["_xsrf"].value | ||||
|                 if not xsrf_token: | ||||
|                     raise ValueError("_xsrf token not found") | ||||
|                 self.session.cookie_jar.update_cookies(response.cookies) | ||||
|                 self.session.headers.update({"X-XSRFToken": xsrf_token}) | ||||
|             async with self.session.post( | ||||
|                 "/login", data={"_xsrf": xsrf_token, "password": self.password}, allow_redirects=False | ||||
|             ) as response: | ||||
|                 response.raise_for_status() | ||||
|                 self.session.cookie_jar.update_cookies(response.cookies) | ||||
| 
 | ||||
|         # token authentication | ||||
|         if self.token: | ||||
|             self.params.update({"token": self.token}) | ||||
| 
 | ||||
|     async def init_kernel(self) -> None: | ||||
|         async with self.session.post(url="/api/kernels", params=self.params) as response: | ||||
|             response.raise_for_status() | ||||
|             kernel_data = await response.json() | ||||
|             self.kernel_id = kernel_data["id"] | ||||
| 
 | ||||
|     def init_ws(self) -> (str, dict): | ||||
|         ws_base = self.base_url.replace("http", "ws") | ||||
|         ws_params = "?" + "&".join([f"{key}={val}" for key, val in self.params.items()]) | ||||
|         websocket_url = f"{ws_base}/api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ''}" | ||||
|         ws_headers = {} | ||||
|         if self.password and not self.token: | ||||
|             ws_headers = { | ||||
|                 "Cookie": "; ".join([f"{cookie.key}={cookie.value}" for cookie in self.session.cookie_jar]), | ||||
|                 **self.session.headers, | ||||
|             } | ||||
|         return websocket_url, ws_headers | ||||
| 
 | ||||
|     async def execute_code(self) -> None: | ||||
|         # initialize ws | ||||
|         websocket_url, ws_headers = self.init_ws() | ||||
|         # execute | ||||
|         async with websockets.connect(websocket_url, additional_headers=ws_headers) as ws: | ||||
|             await self.execute_in_jupyter(ws) | ||||
| 
 | ||||
|     async def execute_in_jupyter(self, ws: ClientConnection) -> None: | ||||
|         # send message | ||||
|         msg_id = uuid.uuid4().hex | ||||
|         await ws.send( | ||||
|             json.dumps( | ||||
|                 { | ||||
|                     "header": { | ||||
|                         "msg_id": msg_id, | ||||
|                         "msg_type": "execute_request", | ||||
|                         "username": "user", | ||||
|                         "session": uuid.uuid4().hex, | ||||
|                         "date": "", | ||||
|                         "version": "5.3", | ||||
|                     }, | ||||
|                     "parent_header": {}, | ||||
|                     "metadata": {}, | ||||
|                     "content": { | ||||
|                         "code": self.code, | ||||
|                         "silent": False, | ||||
|                         "store_history": True, | ||||
|                         "user_expressions": {}, | ||||
|                         "allow_stdin": False, | ||||
|                         "stop_on_error": True, | ||||
|                     }, | ||||
|                     "channel": "shell", | ||||
|                 } | ||||
|             ) | ||||
|         ) | ||||
|         # parse message | ||||
|         stdout, stderr, result = "", "", [] | ||||
|         while True: | ||||
|             try: | ||||
|                 # wait for message | ||||
|                 message = await asyncio.wait_for(ws.recv(), self.timeout) | ||||
|                 message_data = json.loads(message) | ||||
|                 # msg id not match, skip | ||||
|                 if message_data.get("parent_header", {}).get("msg_id") != msg_id: | ||||
|                     continue | ||||
|                 # check message type | ||||
|                 msg_type = message_data.get("msg_type") | ||||
|                 match msg_type: | ||||
|                     case "stream": | ||||
|                         if message_data["content"]["name"] == "stdout": | ||||
|                             stdout += message_data["content"]["text"] | ||||
|                         elif message_data["content"]["name"] == "stderr": | ||||
|                             stderr += message_data["content"]["text"] | ||||
|                     case "execute_result" | "display_data": | ||||
|                         data = message_data["content"]["data"] | ||||
|                         if "image/png" in data: | ||||
|                             result.append(f"data:image/png;base64,{data['image/png']}") | ||||
|                         elif "text/plain" in data: | ||||
|                             result.append(data["text/plain"]) | ||||
|                     case "error": | ||||
|                         stderr += "\n".join(message_data["content"]["traceback"]) | ||||
|                     case "status": | ||||
|                         if message_data["content"]["execution_state"] == "idle": | ||||
|                             break | ||||
| 
 | ||||
|             except asyncio.TimeoutError: | ||||
|                 stderr += "\nExecution timed out." | ||||
|                 break | ||||
|         self.result.stdout = stdout.strip() | ||||
|         self.result.stderr = stderr.strip() | ||||
|         self.result.result = "\n".join(result).strip() if result else "" | ||||
| 
 | ||||
| 
 | ||||
async def execute_code_jupyter(
    base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
) -> dict:
    """
    Execute code in a Jupyter kernel and return the captured output.

    Delegates to ``JupyterCodeExecuter``; the executer's context manager
    takes care of deleting the kernel and closing the HTTP session.

    :param base_url: Jupyter server URL (e.g., "http://localhost:8888")
    :param code: Code to execute
    :param token: Jupyter authentication token (optional)
    :param password: Jupyter password (optional)
    :param timeout: WebSocket timeout in seconds (default: 60s)
    :return: Dictionary dump of the executer's result model
             (fields ``stdout``, ``stderr`` and ``result``)
    """
    # Normalise None to "" so callers that still pass None for an unused
    # credential behave the same as callers relying on the defaults.
    async with JupyterCodeExecuter(base_url, code, token or "", password or "", timeout) as executor:
        result = await executor.run()
        return result.model_dump()
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user