mirror of
				https://github.com/open-webui/open-webui
				synced 2025-06-26 18:26:48 +00:00 
			
		
		
		
	refac: access control
This commit is contained in:
		
							parent
							
								
									a2dcbc41e5
								
							
						
					
					
						commit
						4a34ca35f0
					
				@ -43,6 +43,7 @@ from open_webui.utils.payload import (
 | 
			
		||||
    apply_model_system_prompt_to_body,
 | 
			
		||||
)
 | 
			
		||||
from open_webui.utils.utils import get_admin_user, get_verified_user
 | 
			
		||||
from open_webui.utils.access_control import has_access
 | 
			
		||||
 | 
			
		||||
log = logging.getLogger(__name__)
 | 
			
		||||
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
 | 
			
		||||
@ -316,22 +317,9 @@ async def get_all_models():
 | 
			
		||||
async def get_ollama_tags(
 | 
			
		||||
    url_idx: Optional[int] = None, user=Depends(get_verified_user)
 | 
			
		||||
):
 | 
			
		||||
    models = []
 | 
			
		||||
    if url_idx is None:
 | 
			
		||||
        models = await get_all_models()
 | 
			
		||||
 | 
			
		||||
        # TODO: Check User Group and Filter Models
 | 
			
		||||
        # if app.state.config.ENABLE_MODEL_FILTER:
 | 
			
		||||
        #     if user.role == "user":
 | 
			
		||||
        #         models["models"] = list(
 | 
			
		||||
        #             filter(
 | 
			
		||||
        #                 lambda model: model["name"]
 | 
			
		||||
        #                 in app.state.config.MODEL_FILTER_LIST,
 | 
			
		||||
        #                 models["models"],
 | 
			
		||||
        #             )
 | 
			
		||||
        #         )
 | 
			
		||||
        #         return models
 | 
			
		||||
 | 
			
		||||
        return models
 | 
			
		||||
    else:
 | 
			
		||||
        url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
 | 
			
		||||
@ -347,7 +335,7 @@ async def get_ollama_tags(
 | 
			
		||||
            r = requests.request(method="GET", url=f"{url}/api/tags", headers=headers)
 | 
			
		||||
            r.raise_for_status()
 | 
			
		||||
 | 
			
		||||
            return r.json()
 | 
			
		||||
            models = r.json()
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            log.exception(e)
 | 
			
		||||
            error_detail = "Open WebUI: Server Connection Error"
 | 
			
		||||
@ -363,6 +351,23 @@ async def get_ollama_tags(
 | 
			
		||||
                status_code=r.status_code if r else 500,
 | 
			
		||||
                detail=error_detail,
 | 
			
		||||
            )
 | 
			
		||||
        
 | 
			
		||||
    if user.role == "user":
 | 
			
		||||
        # Filter models based on user access control
 | 
			
		||||
        filtered_models = []
 | 
			
		||||
        for model in models.get("models", []):
 | 
			
		||||
            model_info = Models.get_model_by_id(model["model"])
 | 
			
		||||
            if model_info:
 | 
			
		||||
                if has_access(
 | 
			
		||||
                    user.id, type="read", access_control=model_info.access_control
 | 
			
		||||
                ):
 | 
			
		||||
                    filtered_models.append(model)
 | 
			
		||||
            else:
 | 
			
		||||
                filtered_models.append(model)
 | 
			
		||||
        models["models"] = filtered_models
 | 
			
		||||
    
 | 
			
		||||
        
 | 
			
		||||
    return models
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.get("/api/version")
 | 
			
		||||
@ -926,16 +931,9 @@ async def generate_chat_completion(
 | 
			
		||||
    if "metadata" in payload:
 | 
			
		||||
        del payload["metadata"]
 | 
			
		||||
 | 
			
		||||
    model_id = form_data.model
 | 
			
		||||
 | 
			
		||||
    # TODO: Check User Group and Filter Models
 | 
			
		||||
    # if not bypass_filter:
 | 
			
		||||
    #     if app.state.config.ENABLE_MODEL_FILTER:
 | 
			
		||||
    #         if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
 | 
			
		||||
    #             raise HTTPException(
 | 
			
		||||
    #                 status_code=403,
 | 
			
		||||
    #                 detail="Model not found",
 | 
			
		||||
    #             )
 | 
			
		||||
    model_id = payload["model"]
 | 
			
		||||
    if ":" not in model_id:
 | 
			
		||||
        model_id = f"{model_id}:latest"
 | 
			
		||||
 | 
			
		||||
    model_info = Models.get_model_by_id(model_id)
 | 
			
		||||
 | 
			
		||||
@ -954,9 +952,19 @@ async def generate_chat_completion(
 | 
			
		||||
            )
 | 
			
		||||
            payload = apply_model_system_prompt_to_body(params, payload, user)
 | 
			
		||||
 | 
			
		||||
        # Check if user has access to the model
 | 
			
		||||
        if not bypass_filter and user.role == "user" and not has_access(
 | 
			
		||||
            user.id, type="read", access_control=model_info.access_control
 | 
			
		||||
        ):
 | 
			
		||||
            raise HTTPException(
 | 
			
		||||
                status_code=403,
 | 
			
		||||
                detail="Model not found",
 | 
			
		||||
            )
 | 
			
		||||
    
 | 
			
		||||
    if ":" not in payload["model"]:
 | 
			
		||||
        payload["model"] = f"{payload['model']}:latest"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    url = await get_ollama_url(url_idx, payload["model"])
 | 
			
		||||
    log.info(f"url: {url}")
 | 
			
		||||
    log.debug(f"generate_chat_completion() - 2.payload = {payload}")
 | 
			
		||||
@ -1015,17 +1023,11 @@ async def generate_openai_chat_completion(
 | 
			
		||||
        del payload["metadata"]
 | 
			
		||||
 | 
			
		||||
    model_id = completion_form.model
 | 
			
		||||
    if ":" not in model_id:
 | 
			
		||||
        model_id = f"{model_id}:latest"
 | 
			
		||||
 | 
			
		||||
    # TODO: Check User Group and Filter Models
 | 
			
		||||
    # if app.state.config.ENABLE_MODEL_FILTER:
 | 
			
		||||
    #     if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
 | 
			
		||||
    #         raise HTTPException(
 | 
			
		||||
    #             status_code=403,
 | 
			
		||||
    #             detail="Model not found",
 | 
			
		||||
    #         )
 | 
			
		||||
 | 
			
		||||
    model_info = Models.get_model_by_id(model_id)
 | 
			
		||||
 | 
			
		||||
    if model_info:
 | 
			
		||||
        if model_info.base_model_id:
 | 
			
		||||
            payload["model"] = model_info.base_model_id
 | 
			
		||||
@ -1036,6 +1038,15 @@ async def generate_openai_chat_completion(
 | 
			
		||||
            payload = apply_model_params_to_body_openai(params, payload)
 | 
			
		||||
            payload = apply_model_system_prompt_to_body(params, payload, user)
 | 
			
		||||
 | 
			
		||||
        # Check if user has access to the model
 | 
			
		||||
        if user.role == "user" and not has_access(
 | 
			
		||||
            user.id, type="read", access_control=model_info.access_control
 | 
			
		||||
        ):
 | 
			
		||||
            raise HTTPException(
 | 
			
		||||
                status_code=403,
 | 
			
		||||
                detail="Model not found",
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    if ":" not in payload["model"]:
 | 
			
		||||
        payload["model"] = f"{payload['model']}:latest"
 | 
			
		||||
 | 
			
		||||
@ -1060,32 +1071,19 @@ async def get_openai_models(
 | 
			
		||||
    url_idx: Optional[int] = None,
 | 
			
		||||
    user=Depends(get_verified_user),
 | 
			
		||||
):
 | 
			
		||||
    
 | 
			
		||||
    models = []
 | 
			
		||||
    if url_idx is None:
 | 
			
		||||
        models = await get_all_models()
 | 
			
		||||
 | 
			
		||||
        # TODO: Check User Group and Filter Models
 | 
			
		||||
        # if app.state.config.ENABLE_MODEL_FILTER:
 | 
			
		||||
        #     if user.role == "user":
 | 
			
		||||
        #         models["models"] = list(
 | 
			
		||||
        #             filter(
 | 
			
		||||
        #                 lambda model: model["name"]
 | 
			
		||||
        #                 in app.state.config.MODEL_FILTER_LIST,
 | 
			
		||||
        #                 models["models"],
 | 
			
		||||
        #             )
 | 
			
		||||
        #         )
 | 
			
		||||
 | 
			
		||||
        return {
 | 
			
		||||
            "data": [
 | 
			
		||||
        model_list = await get_all_models()
 | 
			
		||||
        models = [
 | 
			
		||||
                {
 | 
			
		||||
                    "id": model["model"],
 | 
			
		||||
                    "object": "model",
 | 
			
		||||
                    "created": int(time.time()),
 | 
			
		||||
                    "owned_by": "openai",
 | 
			
		||||
                }
 | 
			
		||||
                for model in models["models"]
 | 
			
		||||
            ],
 | 
			
		||||
            "object": "list",
 | 
			
		||||
        }
 | 
			
		||||
                for model in model_list["models"]
 | 
			
		||||
            ]
 | 
			
		||||
 | 
			
		||||
    else:
 | 
			
		||||
        url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
@ -1093,10 +1091,9 @@ async def get_openai_models(
 | 
			
		||||
            r = requests.request(method="GET", url=f"{url}/api/tags")
 | 
			
		||||
            r.raise_for_status()
 | 
			
		||||
 | 
			
		||||
            models = r.json()
 | 
			
		||||
            model_list = r.json()
 | 
			
		||||
 | 
			
		||||
            return {
 | 
			
		||||
                "data": [
 | 
			
		||||
            models = [
 | 
			
		||||
                    {
 | 
			
		||||
                        "id": model["model"],
 | 
			
		||||
                        "object": "model",
 | 
			
		||||
@ -1104,10 +1101,7 @@ async def get_openai_models(
 | 
			
		||||
                        "owned_by": "openai",
 | 
			
		||||
                    }
 | 
			
		||||
                    for model in models["models"]
 | 
			
		||||
                ],
 | 
			
		||||
                "object": "list",
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
                ]
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            log.exception(e)
 | 
			
		||||
            error_detail = "Open WebUI: Server Connection Error"
 | 
			
		||||
@ -1123,6 +1117,27 @@ async def get_openai_models(
 | 
			
		||||
                status_code=r.status_code if r else 500,
 | 
			
		||||
                detail=error_detail,
 | 
			
		||||
            )
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
    if user.role == "user":
 | 
			
		||||
        # Filter models based on user access control
 | 
			
		||||
        filtered_models = []
 | 
			
		||||
        for model in models:
 | 
			
		||||
            model_info = Models.get_model_by_id(model["id"])
 | 
			
		||||
            if model_info:
 | 
			
		||||
                if has_access(
 | 
			
		||||
                    user.id, type="read", access_control=model_info.access_control
 | 
			
		||||
                ):
 | 
			
		||||
                    filtered_models.append(model)
 | 
			
		||||
            else:
 | 
			
		||||
                filtered_models.append(model)
 | 
			
		||||
        models = filtered_models
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
    return {
 | 
			
		||||
            "data": models,
 | 
			
		||||
            "object": "list",
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UrlForm(BaseModel):
 | 
			
		||||
 | 
			
		||||
@ -501,7 +501,7 @@ async def generate_chat_completion(
 | 
			
		||||
        payload = apply_model_system_prompt_to_body(params, payload, user)
 | 
			
		||||
 | 
			
		||||
        # Check if user has access to the model
 | 
			
		||||
        if user.role == "user" and not has_access(
 | 
			
		||||
        if not bypass_filter and user.role == "user" and not has_access(
 | 
			
		||||
            user.id, type="read", access_control=model_info.access_control
 | 
			
		||||
        ):
 | 
			
		||||
            raise HTTPException(
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user