Fix get_or_create_project() might crash when called in parallel

This commit is contained in:
allegroai 2023-02-28 17:14:53 +02:00
parent aff7f87975
commit fba1f4c9a6

View File

@ -84,34 +84,52 @@ def rename_project(session, project_name, new_project_name):
def get_or_create_project(session, project_name, description=None, system_tags=None, project_id=None):
# noinspection PyBroadException
try:
return _get_or_create_project(session, project_name, description=description, system_tags=system_tags, project_id=project_id)
except Exception:
# we only get here if the following race happens:
# imagine there are 2 processes the call `_get_or_create_project` and the project requested doesn't exist.
# because it doesn't exist, it needs to be created. the 2 processes will call the `CreateRequest` one after
# another, one of which will raise an Exception because the project has already been created.
# so we need to retry in this case (the retry should now succeed)
return _get_or_create_project(session, project_name, description=description, system_tags=system_tags, project_id=project_id)
def _get_or_create_project(session, project_name, description=None, system_tags=None, project_id=None):
"""Return the ID of an existing project, or if it does not exist, make a new one and return that ID instead."""
project_system_tags = []
if not project_id:
res = session.send(projects.GetAllRequest(
res = session.send(
projects.GetAllRequest(
name=exact_match_regex(project_name),
only_fields=['id', 'system_tags'] if system_tags else ['id'],
search_hidden=True, _allow_extra_fields_=True))
only_fields=["id", "system_tags"] if system_tags else ["id"],
search_hidden=True,
_allow_extra_fields_=True,
)
)
if res and res.response and res.response.projects:
project_id = res.response.projects[0].id
if system_tags:
project_system_tags = res.response.projects[0].system_tags
if project_id and system_tags and (not project_system_tags or
set(project_system_tags) & set(system_tags) != set(system_tags)):
if (
project_id
and system_tags
and (not project_system_tags or not set(system_tags).issubset(project_system_tags))
):
# set system_tags
session.send(
projects.UpdateRequest(
project=project_id, system_tags=list(set((project_system_tags or []) + system_tags))
)
projects.UpdateRequest(project=project_id, system_tags=list(set((project_system_tags or []) + system_tags)))
)
if project_id:
return project_id
# Project was not found, so create a new one
res = session.send(projects.CreateRequest(
name=project_name, description=description or '', system_tags=system_tags))
# project was not found, so create a new one
res = session.send(
projects.CreateRequest(name=project_name, description=description or "", system_tags=system_tags)
)
return res.response.id if res else None