diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index a346e31a..2f7d52a0 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -450,12 +450,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): # type: () -> Model return self._get_output_model(upload_required=False, force=True) - def _get_output_model(self, upload_required=True, force=False): - # type: (bool, bool) -> Model + def _get_output_model(self, upload_required=True, force=False, model_id=None): + # type: (bool, bool, Optional[str]) -> Model return Model( session=self.session, - model_id=None if force else self._get_task_property( - 'output.model', raise_on_error=False, log_on_error=False), + model_id=model_id or (None if force else self._get_task_property( + 'output.model', raise_on_error=False, log_on_error=False)), cache_dir=self.cache_dir, upload_storage_uri=self.storage_uri or self.get_output_destination( raise_on_error=upload_required, log_on_error=upload_required),