mirror of https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00

	Add scikit-learn support (joblib) and xgboost support
This commit is contained in:
		
							parent
							
								
									1bb06c0190
								
							
						
					
					
						commit
						19c5f05912
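Adds automatic model logging for scikit-learn (through joblib) and for XGBoost: joblib.dump/joblib.load and xgboost.Booster.save_model/load_model are patched so that saved weights are registered as output models and loaded weights as input models of the current Task. A minimal usage sketch, based on the joblib example added in this commit (it assumes a reachable trains server; project and task names are placeholders):

import joblib
from sklearn.linear_model import LogisticRegression
from trains import Task

# Task.init() with the default auto_connect_frameworks=True installs the
# joblib and xgboost bindings introduced below
task = Task.init(project_name="examples", task_name="joblib test")

model = LogisticRegression()
model.fit([[0.0], [1.0], [2.0], [3.0]], [0, 0, 1, 1])

joblib.dump(model, 'model.pkl')      # intercepted by PatchedJoblib._dump -> output model
restored = joblib.load('model.pkl')  # intercepted by PatchedJoblib._load -> input model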
					
				
							
								
								
									
25  examples/joblib_example.py  Normal file
@@ -0,0 +1,25 @@
import joblib

from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split


from trains import Task

task = Task.init(project_name="examples", task_name="joblib test")

iris = datasets.load_iris()
X = iris.data
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

model = LogisticRegression()  # sklearn LogisticRegression class
model.fit(X_train, y_train)

joblib.dump(model, 'model.pkl', compress=True)

loaded_model = joblib.load('model.pkl')
result = loaded_model.score(X_test, y_test)
print(result)
@@ -6,29 +6,28 @@
# 2 seconds per epoch on a K520 GPU.
from __future__ import print_function

import io
import numpy as np
import tensorflow

from keras.callbacks import TensorBoard, ModelCheckpoint
from keras.datasets import mnist
from keras.models import Sequential, Model
from keras.layers.core import Dense, Dropout, Activation
from keras.optimizers import SGD, Adam, RMSprop
from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.optimizers import RMSprop
from keras.utils import np_utils
# TODO: test these methods binding
from keras.models import load_model, save_model, model_from_json

import tensorflow as tf
from trains import Task


class TensorBoardImage(TensorBoard):
    @staticmethod
    def make_image(tensor):
        import tensorflow as tf
        from PIL import Image
        tensor = np.stack((tensor, tensor, tensor), axis=2)
        height, width, channels = tensor.shape
        image = Image.fromarray(tensor)
        import io
        output = io.BytesIO()
        image.save(output, format='PNG')
        image_string = output.getvalue()
@@ -38,9 +37,10 @@ class TensorBoardImage(TensorBoard):
                                colorspace=channels,
                                encoded_image_string=image_string)

    def on_epoch_end(self, epoch, logs={}):
    def on_epoch_end(self, epoch, logs=None):
        if logs is None:
            logs = {}
        super(TensorBoardImage, self).on_epoch_end(epoch, logs)
        import tensorflow as tf
        images = self.validation_data[0]  # 0 - data; 1 - labels
        img = (255 * images[0].reshape(28, 28)).astype('uint8')
							
								
								
									
15  examples/requirements.txt  Normal file
@@ -0,0 +1,15 @@
absl-py>=0.7.1
Keras>=2.2.4
joblib>=0.13.2
matplotlib>=3.1.1
seaborn>=0.9.0
sklearn>=0.0
tensorboard>=1.14.0
tensorboardX>=1.8
tensorflow>=1.14.0
torch>=1.1.0
torchvision>=0.3.0
xgboost>=0.90

# sudo apt-get install graphviz
graphviz>=0.8
@@ -3,9 +3,6 @@
import tensorflow as tf
import numpy as np
import cv2
from time import sleep
#import tensorflow.compat.v1 as tf
#tf.disable_v2_behavior()

from trains import Task
task = Task.init(project_name='examples', task_name='tensorboard toy example')
							
								
								
									
59  examples/xgboost_sample.py  Normal file
@@ -0,0 +1,59 @@
import matplotlib.pyplot as plt
import xgboost as xgb
from sklearn import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from xgboost import plot_tree

from trains import Task

task = Task.init(project_name='examples', task_name='XGBoost simple example')
iris = datasets.load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)
param = {
    'max_depth': 3,  # the maximum depth of each tree
    'eta': 0.3,  # the training step for each iteration
    'silent': 1,  # logging mode - quiet
    'objective': 'multi:softprob',  # error evaluation for multiclass training
    'num_class': 3}  # the number of classes that exist in this dataset
num_round = 20  # the number of training iterations

try:
    # try to load a model
    bst = xgb.Booster(params=param, model_file='xgb.01.model')
    bst.load_model('xgb.01.model')
except Exception:
    bst = None

# if we don't have one, train a model
if bst is None:
    bst = xgb.train(param, dtrain, num_round)

# store trained model v1
bst.save_model('xgb.01.model')
bst.dump_model('xgb.01.raw.txt')

# build classifier
model = xgb.XGBClassifier()
model.fit(X_train, y_train)

# store trained classifier model
model.save_model('xgb.02.model')

# make predictions for test data
y_pred = model.predict(X_test)
predictions = [round(value) for value in y_pred]

# evaluate predictions
accuracy = accuracy_score(y_test, predictions)
print("Accuracy: %.2f%%" % (accuracy * 100.0))
labels = dtest.get_label()

# plot results
xgb.plot_importance(model)
plot_tree(model)
plt.show()
@@ -144,7 +144,11 @@ class Session(TokenManager):

        # update api version from server response
        try:
            api_version = jwt.decode(self.token, verify=False).get('api_version', Session.api_version)
            token_dict = jwt.decode(self.token, verify=False)
            api_version = token_dict.get('api_version')
            if not api_version:
                api_version = '2.2' if token_dict.get('env', '') == 'prod' else Session.api_version

            Session.api_version = str(api_version)
        except (jwt.DecodeError, ValueError):
            pass
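The new api_version fallback is easier to read in isolation: an explicit api_version claim in the decoded token wins, otherwise a production server is assumed to speak API 2.2, and anything else keeps the client default. A small standalone sketch of that decision (resolve_api_version is a hypothetical helper, not part of the codebase):

def resolve_api_version(token_dict, default_api_version):
    # explicit claim first, then environment-based guess, then client default
    api_version = token_dict.get('api_version')
    if not api_version:
        api_version = '2.2' if token_dict.get('env', '') == 'prod' else default_api_version
    return str(api_version)

print(resolve_api_version({'env': 'prod'}, '2.1'))         # -> 2.2
print(resolve_api_version({'api_version': '2.3'}, '2.1'))  # -> 2.3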
@@ -36,7 +36,7 @@ def or_(*converters, **kwargs):
    """
    Wrapper that implements an "optional converter" pattern. Allows specifying a converter
    for which a set of exceptions is ignored (and the original value is returned)
    :param converter: A converter callable
    :param converters: A converter callable
    :param exceptions: A tuple of exception types to ignore
    """
    # noinspection PyUnresolvedReferences
@@ -1,9 +1,8 @@
import os
import weakref

import numpy as np
import hashlib
from tempfile import mkstemp, mkdtemp
from tempfile import mkdtemp
from threading import Thread, Event
from multiprocessing.pool import ThreadPool
							
								
								
									
52  trains/binding/frameworks/base_bind.py  Normal file
@@ -0,0 +1,52 @@
from abc import ABCMeta, abstractmethod

import six


@six.add_metaclass(ABCMeta)
class PatchBaseModelIO(object):
    """
    Base class for patched models

    :param __main_task: Task to run (Experiment)
    :type __main_task: Task
    :param __patched: True if the model is patched
    :type __patched: bool
    """
    @property
    @abstractmethod
    def __main_task(self):
        pass

    @property
    @abstractmethod
    def __patched(self):
        pass

    @staticmethod
    @abstractmethod
    def update_current_task(task, **kwargs):
        """
        Update the model task to run
        :param task: the experiment to do
        :type task: Task
        """
        pass

    @staticmethod
    @abstractmethod
    def _patch_model_io():
        """
        Patching the load and save functions
        """
        pass

    @staticmethod
    @abstractmethod
    def _save(original_fn, obj, f, *args, **kwargs):
        pass

    @staticmethod
    @abstractmethod
    def _load(original_fn, f, *args, **kwargs):
        pass
@@ -3,13 +3,14 @@ import sys
import six
from pathlib2 import Path

from trains.binding.frameworks.base_bind import PatchBaseModelIO
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
from ..import_bind import PostImportHookPatching
from ...config import running_remotely
from ...model import Framework


class PatchPyTorchModelIO(object):
class PatchPyTorchModelIO(PatchBaseModelIO):
    __main_task = None
    __patched = None
@@ -9,8 +9,6 @@ from typing import Any

import cv2
import numpy as np
import six
from pathlib2 import Path

from ..frameworks import _patched_call, WeightsFileHandler, _Empty, TrainsFrameworkAdapter
from ..import_bind import PostImportHookPatching
							
								
								
									
101  trains/binding/frameworks/xgboost_bind.py  Normal file
@@ -0,0 +1,101 @@
import sys

import six
from pathlib2 import Path

from trains.binding.frameworks.base_bind import PatchBaseModelIO
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
from ..import_bind import PostImportHookPatching
from ...config import running_remotely
from ...model import Framework


class PatchXGBoostModelIO(PatchBaseModelIO):
    __main_task = None
    __patched = None

    @staticmethod
    def update_current_task(task, **kwargs):
        PatchXGBoostModelIO.__main_task = task
        PatchXGBoostModelIO._patch_model_io()
        PostImportHookPatching.add_on_import('xgboost', PatchXGBoostModelIO._patch_model_io)

    @staticmethod
    def _patch_model_io():
        if PatchXGBoostModelIO.__patched:
            return

        if 'xgboost' not in sys.modules:
            return
        PatchXGBoostModelIO.__patched = True
        try:
            import xgboost as xgb
            bst = xgb.Booster
            bst.save_model = _patched_call(bst.save_model, PatchXGBoostModelIO._save)
            bst.load_model = _patched_call(bst.load_model, PatchXGBoostModelIO._load)
        except ImportError:
            pass
        except Exception:
            pass

    @staticmethod
    def _save(original_fn, obj, f, *args, **kwargs):
        ret = original_fn(obj, f, *args, **kwargs)
        if not PatchXGBoostModelIO.__main_task:
            return ret

        if isinstance(f, six.string_types):
            filename = f
        elif hasattr(f, 'name'):
            filename = f.name
            # noinspection PyBroadException
            try:
                f.flush()
            except Exception:
                pass
        else:
            filename = None

        # give the model a descriptive name based on the file name
        # noinspection PyBroadException
        try:
            model_name = Path(filename).stem
        except Exception:
            model_name = None
        WeightsFileHandler.create_output_model(obj, filename, Framework.xgboost, PatchXGBoostModelIO.__main_task,
                                               singlefile=True, model_name=model_name)
        return ret

    @staticmethod
    def _load(original_fn, f, *args, **kwargs):
        if isinstance(f, six.string_types):
            filename = f
        elif hasattr(f, 'name'):
            filename = f.name
        elif len(args) == 1 and isinstance(args[0], six.string_types):
            filename = args[0]
        else:
            filename = None

        if not PatchXGBoostModelIO.__main_task:
            return original_fn(f, *args, **kwargs)

        # register input model
        empty = _Empty()
        if running_remotely():
            filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
                                                               PatchXGBoostModelIO.__main_task)
            model = original_fn(filename or f, *args, **kwargs)
        else:
            # try to load model before registering, in case we fail
            model = original_fn(f, *args, **kwargs)
            WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
                                                    PatchXGBoostModelIO.__main_task)

        if empty.trains_in_model:
            # noinspection PyBroadException
            try:
                model.trains_in_model = empty.trains_in_model
            except Exception:
                pass
        return model
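PatchXGBoostModelIO relies on _patched_call (imported from ..frameworks and not shown in this diff) to wrap Booster.save_model/load_model so that the original method still runs while the hook receives it as original_fn. A rough sketch of that wrapping pattern, with patched_call as a hypothetical stand-in for the real helper:

import functools

def patched_call(original_fn, new_fn):
    # returns a callable that forwards to new_fn with the original function
    # prepended to the arguments, mirroring how _save/_load above receive
    # original_fn first
    @functools.wraps(original_fn)
    def _inner(*args, **kwargs):
        return new_fn(original_fn, *args, **kwargs)
    return _inner

# usage mirroring _patch_model_io():
#   xgb.Booster.save_model = patched_call(xgb.Booster.save_model, PatchXGBoostModelIO._save)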
							
								
								
									
110  trains/binding/joblib_bind.py  Normal file
@@ -0,0 +1,110 @@
try:
    import joblib
except ImportError as e:
    joblib = None

import six
from pathlib2 import Path

from trains.binding.frameworks import _patched_call, _Empty, WeightsFileHandler
from trains.config import running_remotely
from trains.debugging.log import LoggerRoot


class PatchedJoblib(object):
    _patched_original_dump = None
    _patched_original_load = None
    _current_task = None
    _current_framework = None

    @staticmethod
    def patch_joblib():
        if PatchedJoblib._patched_original_dump is not None and PatchedJoblib._patched_original_load is not None:
            # We don't need to patch anything else, so we are done
            return True
        # noinspection PyBroadException
        try:
            joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump)
            joblib.load = _patched_call(joblib.load, PatchedJoblib._load)

        except Exception:
            return False
        return True

    @staticmethod
    def update_current_task(task):
        if PatchedJoblib.patch_joblib():
            PatchedJoblib._current_task = task

    @staticmethod
    def _dump(original_fn, obj, f, *args, **kwargs):
        ret = original_fn(obj, f, *args, **kwargs)
        if not PatchedJoblib._current_task:
            return ret

        if isinstance(f, six.string_types):
            filename = f
        elif hasattr(f, 'name'):
            filename = f.name
            # noinspection PyBroadException
            try:
                f.flush()
            except Exception:
                pass
        else:
            filename = None

        # give the model a descriptive name based on the file name
        # noinspection PyBroadException
        try:
            model_name = Path(filename).stem
        except Exception:
            model_name = None
        PatchedJoblib._current_framework = PatchedJoblib.get_model_framework(obj)
        WeightsFileHandler.create_output_model(obj, filename, PatchedJoblib._current_framework,
                                               PatchedJoblib._current_task, singlefile=True, model_name=model_name)
        return ret

    @staticmethod
    def _load(original_fn, f, *args, **kwargs):
        if isinstance(f, six.string_types):
            filename = f
        elif hasattr(f, 'name'):
            filename = f.name
        else:
            filename = None

        if not PatchedJoblib._current_task:
            return original_fn(f, *args, **kwargs)

        # register input model
        empty = _Empty()
        if running_remotely():
            filename = WeightsFileHandler.restore_weights_file(empty, filename, PatchedJoblib._current_framework,
                                                               PatchedJoblib._current_task)
            model = original_fn(filename or f, *args, **kwargs)
        else:
            # try to load model before registering, in case we fail
            model = original_fn(f, *args, **kwargs)
            WeightsFileHandler.restore_weights_file(empty, filename, PatchedJoblib._current_framework,
                                                    PatchedJoblib._current_task)

        if empty.trains_in_model:
            # noinspection PyBroadException
            try:
                model.trains_in_model = empty.trains_in_model
            except Exception:
                pass
        return model

    @staticmethod
    def get_model_framework(obj):
        object_orig_module = obj.__module__
        framework = object_orig_module
        try:
            framework = object_orig_module.partition(".")[0]
        except Exception as _:
            LoggerRoot.get_base_logger().warning(
                "Can't get model framework, model framework will be: {} ".format(object_orig_module))
        finally:
            return framework
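Unlike the XGBoost binding, which always reports Framework.xgboost, PatchedJoblib infers the framework from the dumped object's module, so anything pickled through joblib is tagged with its own top-level package. A small illustration of that inference (the exact sklearn submodule name varies by version):

from sklearn.linear_model import LogisticRegression

model = LogisticRegression()
print(model.__module__)                    # e.g. 'sklearn.linear_model.logistic'
print(model.__module__.partition('.')[0])  # 'sklearn' -> reported as the model framework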
@@ -1,6 +1,5 @@
import abc
import os
import re
import tarfile
import zipfile
from tempfile import mkdtemp, mkstemp
@@ -40,6 +39,7 @@ class Framework(Options):
    darknet = 'Darknet'
    paddlepaddle = 'PaddlePaddle'
    scikitlearn = 'ScikitLearn'
    xgboost = 'XGBoost'

    __file_extensions_mapping = {
        '.pb': (tensorflow, tensorflowjs, onnx, ),
@@ -59,13 +59,13 @@ class Framework(Options):
        '.h5': (keras, ),
        '.hdf5': (keras, ),
        '.keras': (keras, ),
        '.model': (mknet, cntk, ),
        '.model': (mknet, cntk, xgboost),
        '-symbol.json': (mknet, ),
        '.cntk': (cntk, ),
        '.t7': (torch, ),
        '.cfg': (darknet, ),
        '__model__': (paddlepaddle, ),
        '.pkl': (scikitlearn, keras, ),
        '.pkl': (scikitlearn, keras, xgboost),
    }

    @classmethod
@@ -10,6 +10,7 @@ from collections import OrderedDict, Callable
import psutil
import six

from trains.binding.joblib_bind import PatchedJoblib
from .backend_api.services import tasks, projects
from .backend_api.session.session import Session
from .backend_interface.model import Model as BackendModel
@@ -34,8 +35,9 @@ from .utilities.args import argparser_parseargs_called, get_argparser_last_args,
from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
from .binding.frameworks.tensorflow_bind import PatchSummaryToEventTransformer, PatchTensorFlowEager, \
    PatchKerasModelIO, PatchTensorflowModelIO
from .utilities.resource_monitor import ResourceMonitor
from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
from .binding.matplotlib_bind import PatchedMatplotlib
from .utilities.resource_monitor import ResourceMonitor
from .utilities.seed import make_deterministic

NotSet = object()
@@ -118,15 +120,15 @@ class Task(_Task):

    @classmethod
    def init(
        cls,
        project_name=None,
        task_name=None,
        task_type=TaskTypes.training,
        reuse_last_task_id=True,
        output_uri=None,
        auto_connect_arg_parser=True,
        auto_connect_frameworks=True,
        auto_resource_monitoring=True,
            cls,
            project_name=None,
            task_name=None,
            task_type=TaskTypes.training,
            reuse_last_task_id=True,
            output_uri=None,
            auto_connect_arg_parser=True,
            auto_connect_frameworks=True,
            auto_resource_monitoring=True,
    ):
        """
        Return the Task object for the main execution task (task context).
@@ -239,14 +241,15 @@ class Task(_Task):
            # patch OS forking
            PatchOsFork.patch_fork()
            if auto_connect_frameworks:
                PatchedJoblib.update_current_task(task)
                PatchedMatplotlib.update_current_task(Task.__main_task)
                PatchAbsl.update_current_task(Task.__main_task)
                PatchSummaryToEventTransformer.update_current_task(task)
                # PatchModelCheckPointCallback.update_current_task(task)
                PatchTensorFlowEager.update_current_task(task)
                PatchKerasModelIO.update_current_task(task)
                PatchTensorflowModelIO.update_current_task(task)
                PatchPyTorchModelIO.update_current_task(task)
                PatchXGBoostModelIO.update_current_task(task)
            if auto_resource_monitoring:
                task._resource_monitor = ResourceMonitor(task)
                task._resource_monitor.start()
@@ -277,10 +280,10 @@ class Task(_Task):

    @classmethod
    def create(
        cls,
        task_name=None,
        project_name=None,
        task_type=TaskTypes.training,
            cls,
            task_name=None,
            project_name=None,
            task_type=TaskTypes.training,
    ):
        """
        Create a new Task object, regardless of the main execution task (Task.init).
@@ -345,7 +348,7 @@ class Task(_Task):
                        pass

        # if we force no task reuse from os environment
        if DEV_TASK_NO_REUSE.get():
        if DEV_TASK_NO_REUSE.get() or reuse_last_task_id:
            default_task = None
        else:
            # if we have a previous session to use, get the task id from it
@@ -364,7 +367,6 @@ class Task(_Task):
                default_task_id = reuse_last_task_id
            elif not reuse_last_task_id or not cls.__task_is_relevant(default_task):
                default_task_id = None
                closed_old_task = cls.__close_timed_out_task(default_task)
            else:
                default_task_id = default_task.get('id') if default_task else None

@@ -693,7 +695,7 @@ class Task(_Task):
            If `config_text` is not None, `config_dict` must not be provided.
        """
        config_text = self.get_model_config_text()
        return  OutputModel._text_to_config_dict(config_text)
        return OutputModel._text_to_config_dict(config_text)

    def set_model_label_enumeration(self, enumeration=None):
        """