mirror of
https://github.com/clearml/clearml-agent
synced 2025-02-26 05:59:24 +00:00
104 lines
3.0 KiB
Python
104 lines
3.0 KiB
Python
![]() |
import json
|
||
|
|
||
|
from .algorithms import get_default_algorithms
|
||
|
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError
|
||
|
|
||
|
|
||
|
class PyJWK:
    """A single JSON Web Key (RFC 7517) bound to a signing algorithm.

    Attributes:
        Algorithm: the algorithm object looked up in ``get_default_algorithms()``
            under the resolved algorithm name.
        key: the value returned by ``Algorithm.from_jwk(jwk_data)``.
    """

    # JWK members that carry private or secret key material (RFC 7518 §6.2.2,
    # §6.3.2 and §6.4.1; RFC 8037 for OKP "d").  They are masked whenever the
    # JWK is echoed into an exception message so secrets do not end up in logs.
    _PRIVATE_MEMBERS = frozenset({"d", "p", "q", "dp", "dq", "qi", "oth", "k"})

    def __init__(self, jwk_data, algorithm=None):
        """Parse *jwk_data* and resolve the algorithm used for the key.

        Args:
            jwk_data: mapping holding the JWK members (``kty``, ``alg``, ...).
            algorithm: explicit algorithm name; when falsy, the JWK's ``alg``
                member is used, and failing that one is inferred from
                ``kty``/``crv``.

        Raises:
            InvalidKeyError: *jwk_data* is not a mapping, lacks ``kty``, or has
                an unsupported ``kty``/``crv`` combination.
            PyJWKError: no implementation is registered for the resolved
                algorithm name.
        """
        self._algorithms = get_default_algorithms()
        self._jwk_data = jwk_data

        try:
            kty = self._jwk_data.get("kty", None)
        except AttributeError:
            # Non-mapping input (e.g. a raw JSON string) previously escaped as
            # a bare AttributeError; report it as an invalid key instead.
            raise InvalidKeyError(
                f"JWK data must be a mapping, not {type(jwk_data).__name__}"
            ) from None
        if not kty:
            raise InvalidKeyError(f"kty is not found: {self._redacted(self._jwk_data)}")

        # Read "alg" the same way "kty" was read above, so any mapping works.
        if not algorithm:
            algorithm = self._jwk_data.get("alg", None)

        if not algorithm:
            algorithm = self._infer_algorithm(self._jwk_data)

        self.Algorithm = self._algorithms.get(algorithm)

        if not self.Algorithm:
            raise PyJWKError(
                f"Unable to find a algorithm for key: {self._redacted(self._jwk_data)}"
            )

        self.key = self.Algorithm.from_jwk(self._jwk_data)

    @classmethod
    def _redacted(cls, jwk_data):
        """Return a plain-dict copy of *jwk_data* with private/secret members masked."""
        return {
            name: ("<redacted>" if name in cls._PRIVATE_MEMBERS else value)
            for name, value in jwk_data.items()
        }

    @classmethod
    def _infer_algorithm(cls, jwk_data):
        """Pick a default algorithm name from the JWK's ``kty`` (and ``crv``).

        Used only when neither the caller nor the JWK's ``alg`` member names one.

        Raises:
            InvalidKeyError: the ``kty``/``crv`` combination is not supported.
        """
        kty = jwk_data.get("kty", None)
        crv = jwk_data.get("crv", None)

        if kty == "EC":
            # An EC key without a "crv" member is treated as P-256.
            if crv == "P-256" or not crv:
                return "ES256"
            if crv == "P-384":
                return "ES384"
            if crv == "P-521":
                return "ES512"
            if crv == "secp256k1":
                return "ES256K"
            raise InvalidKeyError(f"Unsupported crv: {crv}")
        if kty == "RSA":
            return "RS256"
        if kty == "oct":
            return "HS256"
        if kty == "OKP":
            if not crv:
                raise InvalidKeyError(f"crv is not found: {cls._redacted(jwk_data)}")
            if crv == "Ed25519":
                return "EdDSA"
            raise InvalidKeyError(f"Unsupported crv: {crv}")
        raise InvalidKeyError(f"Unsupported kty: {kty}")

    @staticmethod
    def from_dict(obj, algorithm=None):
        """Build a PyJWK from an already-decoded JWK mapping."""
        return PyJWK(obj, algorithm)

    @staticmethod
    def from_json(data, algorithm=None):
        """Decode *data* with ``json.loads`` and build a PyJWK from the result."""
        obj = json.loads(data)
        return PyJWK.from_dict(obj, algorithm)

    @property
    def key_type(self):
        """The JWK ``kty`` member, or None when absent."""
        return self._jwk_data.get("kty", None)

    @property
    def key_id(self):
        """The JWK ``kid`` member, or None when absent."""
        return self._jwk_data.get("kid", None)

    @property
    def public_key_use(self):
        """The JWK ``use`` member, or None when absent."""
        return self._jwk_data.get("use", None)
|
||
|
|
||
|
|
||
|
class PyJWKSet:
    """A JSON Web Key Set (RFC 7517 §5): an ordered list of PyJWK objects."""

    def __init__(self, keys):
        """Build a PyJWK for every entry of *keys*.

        Args:
            keys: list of JWK mappings (the value of a JWKS ``keys`` member).

        Raises:
            PyJWKSetError: *keys* is not a list, or it is an empty list.
            InvalidKeyError / PyJWKError: propagated from PyJWK for a bad entry.
        """
        self.keys = []

        # Check the type before emptiness: the previous combined
        # `not keys or not isinstance(keys, list)` test swallowed the empty-list
        # case, so the dedicated "no keys" error below was unreachable.
        if not isinstance(keys, list):
            raise PyJWKSetError("Invalid JWK Set value")

        if not keys:
            raise PyJWKSetError("The JWK Set did not contain any keys")

        for key in keys:
            self.keys.append(PyJWK(key))

    @staticmethod
    def from_dict(obj):
        """Build a PyJWKSet from a decoded JWKS mapping (reads its ``keys`` member)."""
        keys = obj.get("keys", [])
        return PyJWKSet(keys)

    @staticmethod
    def from_json(data):
        """Decode *data* with ``json.loads`` and build a PyJWKSet from the result."""
        obj = json.loads(data)
        return PyJWKSet.from_dict(obj)

    def __getitem__(self, kid):
        """Return the first key whose ``key_id`` equals *kid*.

        Raises:
            KeyError: no key in the set has that ``kid``.
        """
        for key in self.keys:
            if key.key_id == kid:
                return key
        raise KeyError(f"keyset has no key for kid: {kid}")
|