Better support for PyJWT 2.4

This commit is contained in:
allegroai 2022-07-08 17:44:17 +03:00
parent 52c0c4d438
commit ee87778548
2 changed files with 11 additions and 15 deletions

View File

@ -4,6 +4,7 @@ import attr
import jsonmodels.models import jsonmodels.models
import jwt import jwt
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
from jwt.algorithms import get_default_algorithms
from apiserver.bll.event.event_common import ( from apiserver.bll.event.event_common import (
check_empty_data, check_empty_data,
@ -77,10 +78,7 @@ class EventsIterator:
with translate_errors_context(), TimingContext("es", "count_task_events"): with translate_errors_context(), TimingContext("es", "count_task_events"):
es_result = count_company_events( es_result = count_company_events(
self.es, self.es, company_id=company_id, event_type=event_type, body=es_req,
company_id=company_id,
event_type=event_type,
body=es_req,
) )
return es_result["count"] return es_result["count"]
@ -117,10 +115,7 @@ class EventsIterator:
with translate_errors_context(), TimingContext("es", "get_task_events"): with translate_errors_context(), TimingContext("es", "get_task_events"):
es_result = search_company_events( es_result = search_company_events(
self.es, self.es, company_id=company_id, event_type=event_type, body=es_req,
company_id=company_id,
event_type=event_type,
body=es_req,
) )
hits = es_result["hits"]["hits"] hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"]["value"] hits_total = es_result["hits"]["total"]["value"]
@ -140,10 +135,7 @@ class EventsIterator:
}, },
} }
es_result = search_company_events( es_result = search_company_events(
self.es, self.es, company_id=company_id, event_type=event_type, body=es_req,
company_id=company_id,
event_type=event_type,
body=es_req,
) )
last_second_hits = es_result["hits"]["hits"] last_second_hits = es_result["hits"]["hits"]
if not last_second_hits or len(last_second_hits) < 2: if not last_second_hits or len(last_second_hits) < 2:
@ -199,7 +191,7 @@ class Scroll(jsonmodels.models.Base):
key=config.get( key=config.get(
"services.events.events_retrieval.scroll_id_key", "1234567890" "services.events.events_retrieval.scroll_id_key", "1234567890"
), ),
algorithms=["HS256"], algorithms=get_default_algorithms(),
) )
) )
except jwt.PyJWTError: except jwt.PyJWTError:

View File

@ -2,6 +2,8 @@ import jwt
from datetime import datetime, timedelta from datetime import datetime, timedelta
from jwt.algorithms import get_default_algorithms
from apiserver.apierrors import errors from apiserver.apierrors import errors
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.database.model.auth import Role from apiserver.database.model.auth import Role
@ -11,7 +13,6 @@ from .payload import Payload
token_secret = config.get("secure.auth.token_secret") token_secret = config.get("secure.auth.token_secret")
log = config.logger(__file__) log = config.logger(__file__)
@ -72,7 +73,10 @@ class Token(Payload):
{"verify_signature": False, "verify_exp": True} if not verify else None {"verify_signature": False, "verify_exp": True} if not verify else None
) )
return jwt.decode( return jwt.decode(
encoded_token, token_secret, algorithms=["HS256"], options=options encoded_token,
token_secret,
algorithms=get_default_algorithms(),
options=options,
) )
@classmethod @classmethod