diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 6510d7c99..9005ce033 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -3,6 +3,7 @@ import logging import sys import os import base64 +import textwrap import asyncio from aiocache import cached @@ -84,6 +85,7 @@ from open_webui.config import ( CACHE_DIR, DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, DEFAULT_CODE_INTERPRETER_PROMPT, + CODE_INTERPRETER_BLACKLISTED_MODULES, ) from open_webui.env import ( SRC_LOG_LEVELS, @@ -2207,6 +2209,25 @@ async def process_chat_response( try: if content_blocks[-1]["attributes"].get("type") == "code": code = content_blocks[-1]["content"] + if CODE_INTERPRETER_BLACKLISTED_MODULES: + blocking_code = textwrap.dedent(f""" + import builtins + + BLACKLISTED_MODULES = {CODE_INTERPRETER_BLACKLISTED_MODULES} + + _real_import = builtins.__import__ + def restricted_import(name, globals=None, locals=None, fromlist=(), level=0): + if name.split('.')[0] in BLACKLISTED_MODULES: + importer_name = globals.get('__name__') if globals else None + if importer_name == '__main__': + raise ImportError( + f"Direct import of module {{name}} is restricted." + ) + return _real_import(name, globals, locals, fromlist, level) + + builtins.__import__ = restricted_import + """) + code = blocking_code + "\n" + code if ( request.app.state.config.CODE_INTERPRETER_ENGINE