fix(jupyter): fix kernel_id not set and optimize code

This commit is contained in:
orenzhang 2025-03-03 22:05:50 +08:00
parent 3d6e48b05e
commit 744ffbb1fb
No known key found for this signature in database
GPG Key ID: 73D45F78147E506C

View File

@ -1,14 +1,15 @@
import asyncio import asyncio
import json import json
import uuid import uuid
from typing import Optional
import httpx
import websockets import websockets
import requests
from urllib.parse import urljoin
async def execute_code_jupyter( async def execute_code_jupyter(
jupyter_url, code, token=None, password=None, timeout=10 jupyter_url: str, code: str, token: str = None, password: str = None, timeout: int = 60
): ) -> Optional[dict]:
""" """
Executes Python code in a Jupyter kernel. Executes Python code in a Jupyter kernel.
Supports authentication with a token or password. Supports authentication with a token or password.
@ -20,80 +21,70 @@ async def execute_code_jupyter(
:return: Dictionary with stdout, stderr, and result :return: Dictionary with stdout, stderr, and result
- Images are prefixed with "base64:image/png," and separated by newlines if multiple. - Images are prefixed with "base64:image/png," and separated by newlines if multiple.
""" """
session = requests.Session() # Maintain cookies
headers = {} # Headers for requests
# Authenticate using password jupyter_url = jupyter_url.rstrip("/")
client = httpx.AsyncClient(base_url=jupyter_url, timeout=timeout, follow_redirects=True)
headers = {}
# password authentication
if password and not token: if password and not token:
try: try:
login_url = urljoin(jupyter_url, "/login") response = await client.get("/login")
response = session.get(login_url)
response.raise_for_status() response.raise_for_status()
xsrf_token = session.cookies.get("_xsrf") xsrf_token = response.cookies.get("_xsrf")
if not xsrf_token: if not xsrf_token:
raise ValueError("Failed to fetch _xsrf token") raise ValueError("_xsrf token not found")
response = await client.post("/login", data={"_xsrf": xsrf_token, "password": password})
login_data = {"_xsrf": xsrf_token, "password": password} response.raise_for_status()
login_response = session.post(
login_url, data=login_data, cookies=session.cookies
)
login_response.raise_for_status()
headers["X-XSRFToken"] = xsrf_token headers["X-XSRFToken"] = xsrf_token
except Exception as e: except Exception as e:
return { return {"stdout": "", "stderr": f"Authentication Error: {str(e)}", "result": ""}
"stdout": "",
"stderr": f"Authentication Error: {str(e)}",
"result": "",
}
# Construct API URLs with authentication token if provided # token authentication
params = f"?token={token}" if token else "" params = {"token": token} if token else {}
kernel_url = urljoin(jupyter_url, f"/api/kernels{params}")
kernel_id = ""
try: try:
response = session.post(kernel_url, headers=headers, cookies=session.cookies) response = await client.post(url="/api/kernels", params=params, headers=headers)
response.raise_for_status() response.raise_for_status()
kernel_id = response.json()["id"] kernel_id = response.json()["id"]
websocket_url = urljoin( ws_base = jupyter_url.replace("http", "ws")
jupyter_url.replace("http", "ws"), websocket_url = f"{ws_base}/api/kernels/{kernel_id}/channels" + (f"?token={token}" if token else "")
f"/api/kernels/{kernel_id}/channels{params}",
)
ws_headers = {} ws_headers = {}
if password and not token: if password and not token:
ws_headers["X-XSRFToken"] = session.cookies.get("_xsrf") ws_headers = {
cookies = {name: value for name, value in session.cookies.items()} "X-XSRFToken": client.cookies.get("_xsrf"),
ws_headers["Cookie"] = "; ".join( "Cookie": "; ".join([f"{name}={value}" for name, value in client.cookies.items()]),
[f"{name}={value}" for name, value in cookies.items()]
)
async with websockets.connect(
websocket_url, additional_headers=ws_headers
) as ws:
msg_id = str(uuid.uuid4())
execute_request = {
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "user",
"session": str(uuid.uuid4()),
"date": "",
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"channel": "shell",
} }
await ws.send(json.dumps(execute_request))
async with websockets.connect(websocket_url, additional_headers=ws_headers) as ws:
msg_id = str(uuid.uuid4())
await ws.send(
json.dumps(
{
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "user",
"session": str(uuid.uuid4()),
"date": "",
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"channel": "shell",
}
)
)
stdout, stderr, result = "", "", [] stdout, stderr, result = "", "", []
@ -101,32 +92,27 @@ async def execute_code_jupyter(
try: try:
message = await asyncio.wait_for(ws.recv(), timeout) message = await asyncio.wait_for(ws.recv(), timeout)
message_data = json.loads(message) message_data = json.loads(message)
if message_data.get("parent_header", {}).get("msg_id") == msg_id: if message_data.get("parent_header", {}).get("msg_id") != msg_id:
msg_type = message_data.get("msg_type") continue
if msg_type == "stream": msg_type = message_data.get("msg_type")
match msg_type:
case "stream":
if message_data["content"]["name"] == "stdout": if message_data["content"]["name"] == "stdout":
stdout += message_data["content"]["text"] stdout += message_data["content"]["text"]
elif message_data["content"]["name"] == "stderr": elif message_data["content"]["name"] == "stderr":
stderr += message_data["content"]["text"] stderr += message_data["content"]["text"]
case "execute_result" | "display_data":
elif msg_type in ("execute_result", "display_data"):
data = message_data["content"]["data"] data = message_data["content"]["data"]
if "image/png" in data: if "image/png" in data:
result.append( result.append(f"data:image/png;base64,{data['image/png']}")
f"data:image/png;base64,{data['image/png']}"
)
elif "text/plain" in data: elif "text/plain" in data:
result.append(data["text/plain"]) result.append(data["text/plain"])
case "error":
elif msg_type == "error":
stderr += "\n".join(message_data["content"]["traceback"]) stderr += "\n".join(message_data["content"]["traceback"])
case "status":
elif ( if message_data["content"]["execution_state"] == "idle":
msg_type == "status" break
and message_data["content"]["execution_state"] == "idle"
):
break
except asyncio.TimeoutError: except asyncio.TimeoutError:
stderr += "\nExecution timed out." stderr += "\nExecution timed out."
@ -137,12 +123,7 @@ async def execute_code_jupyter(
finally: finally:
if kernel_id: if kernel_id:
requests.delete( await client.delete(f"/api/kernels/{kernel_id}", headers=headers, params=params)
f"{kernel_url}/{kernel_id}", headers=headers, cookies=session.cookies await client.aclose()
)
return { return {"stdout": stdout.strip(), "stderr": stderr.strip(), "result": "\n".join(result).strip() if result else ""}
"stdout": stdout.strip(),
"stderr": stderr.strip(),
"result": "\n".join(result).strip() if result else "",
}