mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge pull request #11120 from OrenZhang/fix_jupyter
fix(jupyter): fix kernel_id not set and optimize code
This commit is contained in:
commit
bb2bd7d721
@ -1,90 +1,138 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
import websockets
|
import websockets
|
||||||
import requests
|
from pydantic import BaseModel
|
||||||
from urllib.parse import urljoin
|
from websockets import ClientConnection
|
||||||
|
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
|
|
||||||
|
|
||||||
async def execute_code_jupyter(
|
class ResultModel(BaseModel):
|
||||||
jupyter_url, code, token=None, password=None, timeout=10
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Executes Python code in a Jupyter kernel.
|
Execute Code Result Model
|
||||||
Supports authentication with a token or password.
|
"""
|
||||||
:param jupyter_url: Jupyter server URL (e.g., "http://localhost:8888")
|
|
||||||
|
stdout: Optional[str] = ""
|
||||||
|
stderr: Optional[str] = ""
|
||||||
|
result: Optional[str] = ""
|
||||||
|
|
||||||
|
|
||||||
|
class JupyterCodeExecuter:
|
||||||
|
"""
|
||||||
|
Execute code in jupyter notebook
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60):
|
||||||
|
"""
|
||||||
|
:param base_url: Jupyter server URL (e.g., "http://localhost:8888")
|
||||||
:param code: Code to execute
|
:param code: Code to execute
|
||||||
:param token: Jupyter authentication token (optional)
|
:param token: Jupyter authentication token (optional)
|
||||||
:param password: Jupyter password (optional)
|
:param password: Jupyter password (optional)
|
||||||
:param timeout: WebSocket timeout in seconds (default: 10s)
|
:param timeout: WebSocket timeout in seconds (default: 60s)
|
||||||
:return: Dictionary with stdout, stderr, and result
|
|
||||||
- Images are prefixed with "base64:image/png," and separated by newlines if multiple.
|
|
||||||
"""
|
"""
|
||||||
session = requests.Session() # Maintain cookies
|
self.base_url = base_url.rstrip("/")
|
||||||
headers = {} # Headers for requests
|
self.code = code
|
||||||
|
self.token = token
|
||||||
|
self.password = password
|
||||||
|
self.timeout = timeout
|
||||||
|
self.kernel_id = ""
|
||||||
|
self.session = aiohttp.ClientSession(base_url=self.base_url)
|
||||||
|
self.params = {}
|
||||||
|
self.result = ResultModel()
|
||||||
|
|
||||||
# Authenticate using password
|
async def __aenter__(self):
|
||||||
if password and not token:
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
if self.kernel_id:
|
||||||
try:
|
try:
|
||||||
login_url = urljoin(jupyter_url, "/login")
|
async with self.session.delete(f"/api/kernels/{self.kernel_id}", params=self.params) as response:
|
||||||
response = session.get(login_url)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
xsrf_token = session.cookies.get("_xsrf")
|
except Exception as err:
|
||||||
|
logger.exception("close kernel failed, %s", err)
|
||||||
|
await self.session.close()
|
||||||
|
|
||||||
|
async def run(self) -> ResultModel:
|
||||||
|
try:
|
||||||
|
await self.sign_in()
|
||||||
|
await self.init_kernel()
|
||||||
|
await self.execute_code()
|
||||||
|
except Exception as err:
|
||||||
|
logger.exception("execute code failed, %s", err)
|
||||||
|
self.result.stderr = f"Error: {err}"
|
||||||
|
return self.result
|
||||||
|
|
||||||
|
async def sign_in(self) -> None:
|
||||||
|
# password authentication
|
||||||
|
if self.password and not self.token:
|
||||||
|
async with self.session.get("/login") as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
xsrf_token = response.cookies["_xsrf"].value
|
||||||
if not xsrf_token:
|
if not xsrf_token:
|
||||||
raise ValueError("Failed to fetch _xsrf token")
|
raise ValueError("_xsrf token not found")
|
||||||
|
self.session.cookie_jar.update_cookies(response.cookies)
|
||||||
login_data = {"_xsrf": xsrf_token, "password": password}
|
self.session.headers.update({"X-XSRFToken": xsrf_token})
|
||||||
login_response = session.post(
|
async with self.session.post(
|
||||||
login_url, data=login_data, cookies=session.cookies
|
"/login", data={"_xsrf": xsrf_token, "password": self.password}, allow_redirects=False
|
||||||
)
|
) as response:
|
||||||
login_response.raise_for_status()
|
|
||||||
headers["X-XSRFToken"] = xsrf_token
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"stdout": "",
|
|
||||||
"stderr": f"Authentication Error: {str(e)}",
|
|
||||||
"result": "",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Construct API URLs with authentication token if provided
|
|
||||||
params = f"?token={token}" if token else ""
|
|
||||||
kernel_url = urljoin(jupyter_url, f"/api/kernels{params}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = session.post(kernel_url, headers=headers, cookies=session.cookies)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
kernel_id = response.json()["id"]
|
self.session.cookie_jar.update_cookies(response.cookies)
|
||||||
|
|
||||||
websocket_url = urljoin(
|
# token authentication
|
||||||
jupyter_url.replace("http", "ws"),
|
if self.token:
|
||||||
f"/api/kernels/{kernel_id}/channels{params}",
|
self.params.update({"token": self.token})
|
||||||
)
|
|
||||||
|
|
||||||
|
async def init_kernel(self) -> None:
|
||||||
|
async with self.session.post(url="/api/kernels", params=self.params) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
kernel_data = await response.json()
|
||||||
|
self.kernel_id = kernel_data["id"]
|
||||||
|
|
||||||
|
def init_ws(self) -> (str, dict):
|
||||||
|
ws_base = self.base_url.replace("http", "ws")
|
||||||
|
ws_params = "?" + "&".join([f"{key}={val}" for key, val in self.params.items()])
|
||||||
|
websocket_url = f"{ws_base}/api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ''}"
|
||||||
ws_headers = {}
|
ws_headers = {}
|
||||||
if password and not token:
|
if self.password and not self.token:
|
||||||
ws_headers["X-XSRFToken"] = session.cookies.get("_xsrf")
|
ws_headers = {
|
||||||
cookies = {name: value for name, value in session.cookies.items()}
|
"Cookie": "; ".join([f"{cookie.key}={cookie.value}" for cookie in self.session.cookie_jar]),
|
||||||
ws_headers["Cookie"] = "; ".join(
|
**self.session.headers,
|
||||||
[f"{name}={value}" for name, value in cookies.items()]
|
}
|
||||||
)
|
return websocket_url, ws_headers
|
||||||
|
|
||||||
async with websockets.connect(
|
async def execute_code(self) -> None:
|
||||||
websocket_url, additional_headers=ws_headers
|
# initialize ws
|
||||||
) as ws:
|
websocket_url, ws_headers = self.init_ws()
|
||||||
msg_id = str(uuid.uuid4())
|
# execute
|
||||||
execute_request = {
|
async with websockets.connect(websocket_url, additional_headers=ws_headers) as ws:
|
||||||
|
await self.execute_in_jupyter(ws)
|
||||||
|
|
||||||
|
async def execute_in_jupyter(self, ws: ClientConnection) -> None:
|
||||||
|
# send message
|
||||||
|
msg_id = uuid.uuid4().hex
|
||||||
|
await ws.send(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
"header": {
|
"header": {
|
||||||
"msg_id": msg_id,
|
"msg_id": msg_id,
|
||||||
"msg_type": "execute_request",
|
"msg_type": "execute_request",
|
||||||
"username": "user",
|
"username": "user",
|
||||||
"session": str(uuid.uuid4()),
|
"session": uuid.uuid4().hex,
|
||||||
"date": "",
|
"date": "",
|
||||||
"version": "5.3",
|
"version": "5.3",
|
||||||
},
|
},
|
||||||
"parent_header": {},
|
"parent_header": {},
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"content": {
|
"content": {
|
||||||
"code": code,
|
"code": self.code,
|
||||||
"silent": False,
|
"silent": False,
|
||||||
"store_history": True,
|
"store_history": True,
|
||||||
"user_expressions": {},
|
"user_expressions": {},
|
||||||
@ -93,56 +141,49 @@ async def execute_code_jupyter(
|
|||||||
},
|
},
|
||||||
"channel": "shell",
|
"channel": "shell",
|
||||||
}
|
}
|
||||||
await ws.send(json.dumps(execute_request))
|
)
|
||||||
|
)
|
||||||
|
# parse message
|
||||||
stdout, stderr, result = "", "", []
|
stdout, stderr, result = "", "", []
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
message = await asyncio.wait_for(ws.recv(), timeout)
|
# wait for message
|
||||||
|
message = await asyncio.wait_for(ws.recv(), self.timeout)
|
||||||
message_data = json.loads(message)
|
message_data = json.loads(message)
|
||||||
if message_data.get("parent_header", {}).get("msg_id") == msg_id:
|
# msg id not match, skip
|
||||||
|
if message_data.get("parent_header", {}).get("msg_id") != msg_id:
|
||||||
|
continue
|
||||||
|
# check message type
|
||||||
msg_type = message_data.get("msg_type")
|
msg_type = message_data.get("msg_type")
|
||||||
|
match msg_type:
|
||||||
if msg_type == "stream":
|
case "stream":
|
||||||
if message_data["content"]["name"] == "stdout":
|
if message_data["content"]["name"] == "stdout":
|
||||||
stdout += message_data["content"]["text"]
|
stdout += message_data["content"]["text"]
|
||||||
elif message_data["content"]["name"] == "stderr":
|
elif message_data["content"]["name"] == "stderr":
|
||||||
stderr += message_data["content"]["text"]
|
stderr += message_data["content"]["text"]
|
||||||
|
case "execute_result" | "display_data":
|
||||||
elif msg_type in ("execute_result", "display_data"):
|
|
||||||
data = message_data["content"]["data"]
|
data = message_data["content"]["data"]
|
||||||
if "image/png" in data:
|
if "image/png" in data:
|
||||||
result.append(
|
result.append(f"data:image/png;base64,{data['image/png']}")
|
||||||
f"data:image/png;base64,{data['image/png']}"
|
|
||||||
)
|
|
||||||
elif "text/plain" in data:
|
elif "text/plain" in data:
|
||||||
result.append(data["text/plain"])
|
result.append(data["text/plain"])
|
||||||
|
case "error":
|
||||||
elif msg_type == "error":
|
|
||||||
stderr += "\n".join(message_data["content"]["traceback"])
|
stderr += "\n".join(message_data["content"]["traceback"])
|
||||||
|
case "status":
|
||||||
elif (
|
if message_data["content"]["execution_state"] == "idle":
|
||||||
msg_type == "status"
|
|
||||||
and message_data["content"]["execution_state"] == "idle"
|
|
||||||
):
|
|
||||||
break
|
break
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
stderr += "\nExecution timed out."
|
stderr += "\nExecution timed out."
|
||||||
break
|
break
|
||||||
|
self.result.stdout = stdout.strip()
|
||||||
|
self.result.stderr = stderr.strip()
|
||||||
|
self.result.result = "\n".join(result).strip() if result else ""
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {"stdout": "", "stderr": f"Error: {str(e)}", "result": ""}
|
|
||||||
|
|
||||||
finally:
|
async def execute_code_jupyter(
|
||||||
if kernel_id:
|
base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
|
||||||
requests.delete(
|
) -> dict:
|
||||||
f"{kernel_url}/{kernel_id}", headers=headers, cookies=session.cookies
|
async with JupyterCodeExecuter(base_url, code, token, password, timeout) as executor:
|
||||||
)
|
result = await executor.run()
|
||||||
|
return result.model_dump()
|
||||||
return {
|
|
||||||
"stdout": stdout.strip(),
|
|
||||||
"stderr": stderr.strip(),
|
|
||||||
"result": "\n".join(result).strip() if result else "",
|
|
||||||
}
|
|
||||||
|
Loading…
Reference in New Issue
Block a user