mirror of
https://github.com/clearml/clearml
synced 2025-05-07 06:14:31 +00:00
Add scikit-learn support (joblib) and xgboost support
This commit is contained in:
parent
1bb06c0190
commit
19c5f05912
25
examples/joblib_example.py
Normal file
25
examples/joblib_example.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
import joblib
|
||||||
|
|
||||||
|
from sklearn import datasets
|
||||||
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
||||||
|
|
||||||
|
from trains import Task
|
||||||
|
|
||||||
|
task = Task.init(project_name="examples", task_name="joblib test")
|
||||||
|
|
||||||
|
iris = datasets.load_iris()
|
||||||
|
X = iris.data
|
||||||
|
y = iris.target
|
||||||
|
|
||||||
|
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
||||||
|
|
||||||
|
model = LogisticRegression() # sklearn LogisticRegression class
|
||||||
|
model.fit(X_train, y_train)
|
||||||
|
|
||||||
|
joblib.dump(model, 'model.pkl', compress=True)
|
||||||
|
|
||||||
|
loaded_model = joblib.load('model.pkl')
|
||||||
|
result = loaded_model.score(X_test, y_test)
|
||||||
|
print(result)
|
@ -6,29 +6,28 @@
|
|||||||
# 2 seconds per epoch on a K520 GPU.
|
# 2 seconds per epoch on a K520 GPU.
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import io
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow
|
|
||||||
|
|
||||||
from keras.callbacks import TensorBoard, ModelCheckpoint
|
from keras.callbacks import TensorBoard, ModelCheckpoint
|
||||||
from keras.datasets import mnist
|
from keras.datasets import mnist
|
||||||
from keras.models import Sequential, Model
|
from keras.models import Sequential
|
||||||
from keras.layers.core import Dense, Dropout, Activation
|
from keras.layers.core import Dense, Activation
|
||||||
from keras.optimizers import SGD, Adam, RMSprop
|
from keras.optimizers import RMSprop
|
||||||
from keras.utils import np_utils
|
from keras.utils import np_utils
|
||||||
|
# TODO: test these methods binding
|
||||||
from keras.models import load_model, save_model, model_from_json
|
from keras.models import load_model, save_model, model_from_json
|
||||||
|
import tensorflow as tf
|
||||||
from trains import Task
|
from trains import Task
|
||||||
|
|
||||||
|
|
||||||
class TensorBoardImage(TensorBoard):
|
class TensorBoardImage(TensorBoard):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def make_image(tensor):
|
def make_image(tensor):
|
||||||
import tensorflow as tf
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
tensor = np.stack((tensor, tensor, tensor), axis=2)
|
tensor = np.stack((tensor, tensor, tensor), axis=2)
|
||||||
height, width, channels = tensor.shape
|
height, width, channels = tensor.shape
|
||||||
image = Image.fromarray(tensor)
|
image = Image.fromarray(tensor)
|
||||||
import io
|
|
||||||
output = io.BytesIO()
|
output = io.BytesIO()
|
||||||
image.save(output, format='PNG')
|
image.save(output, format='PNG')
|
||||||
image_string = output.getvalue()
|
image_string = output.getvalue()
|
||||||
@ -38,9 +37,10 @@ class TensorBoardImage(TensorBoard):
|
|||||||
colorspace=channels,
|
colorspace=channels,
|
||||||
encoded_image_string=image_string)
|
encoded_image_string=image_string)
|
||||||
|
|
||||||
def on_epoch_end(self, epoch, logs={}):
|
def on_epoch_end(self, epoch, logs=None):
|
||||||
|
if logs is None:
|
||||||
|
logs = {}
|
||||||
super(TensorBoardImage, self).on_epoch_end(epoch, logs)
|
super(TensorBoardImage, self).on_epoch_end(epoch, logs)
|
||||||
import tensorflow as tf
|
|
||||||
images = self.validation_data[0] # 0 - data; 1 - labels
|
images = self.validation_data[0] # 0 - data; 1 - labels
|
||||||
img = (255 * images[0].reshape(28, 28)).astype('uint8')
|
img = (255 * images[0].reshape(28, 28)).astype('uint8')
|
||||||
|
|
||||||
|
15
examples/requirements.txt
Normal file
15
examples/requirements.txt
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
absl-py>=0.7.1
|
||||||
|
Keras>=2.2.4
|
||||||
|
joblib>=0.13.2
|
||||||
|
matplotlib>=3.1.1
|
||||||
|
seaborn>=0.9.0
|
||||||
|
sklearn>=0.0
|
||||||
|
tensorboard>=1.14.0
|
||||||
|
tensorboardX>=1.8
|
||||||
|
tensorflow>=1.14.0
|
||||||
|
torch>=1.1.0
|
||||||
|
torchvision>=0.3.0
|
||||||
|
xgboost>=0.90
|
||||||
|
|
||||||
|
# sudo apt-get install graphviz
|
||||||
|
graphviz>=0.8
|
@ -3,9 +3,6 @@
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
from time import sleep
|
|
||||||
#import tensorflow.compat.v1 as tf
|
|
||||||
#tf.disable_v2_behavior()
|
|
||||||
|
|
||||||
from trains import Task
|
from trains import Task
|
||||||
task = Task.init(project_name='examples', task_name='tensorboard toy example')
|
task = Task.init(project_name='examples', task_name='tensorboard toy example')
|
||||||
|
59
examples/xgboost_sample.py
Normal file
59
examples/xgboost_sample.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import xgboost as xgb
|
||||||
|
from sklearn import datasets
|
||||||
|
from sklearn.metrics import accuracy_score
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from xgboost import plot_tree
|
||||||
|
|
||||||
|
from trains import Task
|
||||||
|
|
||||||
|
task = Task.init(project_name='examples', task_name='XGBoost simple example')
|
||||||
|
iris = datasets.load_iris()
|
||||||
|
X = iris.data
|
||||||
|
y = iris.target
|
||||||
|
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
||||||
|
dtrain = xgb.DMatrix(X_train, label=y_train)
|
||||||
|
dtest = xgb.DMatrix(X_test, label=y_test)
|
||||||
|
param = {
|
||||||
|
'max_depth': 3, # the maximum depth of each tree
|
||||||
|
'eta': 0.3, # the training step for each iteration
|
||||||
|
'silent': 1, # logging mode - quiet
|
||||||
|
'objective': 'multi:softprob', # error evaluation for multiclass training
|
||||||
|
'num_class': 3} # the number of classes that exist in this datset
|
||||||
|
num_round = 20 # the number of training iterations
|
||||||
|
|
||||||
|
try:
|
||||||
|
# try to load a model
|
||||||
|
bst = xgb.Booster(params=param, model_file='xgb.01.model')
|
||||||
|
bst.load_model('xgb.01.model')
|
||||||
|
except:
|
||||||
|
bst = None
|
||||||
|
|
||||||
|
# if we dont have one train a model
|
||||||
|
if bst is None:
|
||||||
|
bst = xgb.train(param, dtrain, num_round)
|
||||||
|
|
||||||
|
# store trained model model v1
|
||||||
|
bst.save_model('xgb.01.model')
|
||||||
|
bst.dump_model('xgb.01.raw.txt')
|
||||||
|
|
||||||
|
# build classifier
|
||||||
|
model = xgb.XGBClassifier()
|
||||||
|
model.fit(X_train, y_train)
|
||||||
|
|
||||||
|
# store trained classifier model
|
||||||
|
model.save_model('xgb.02.model')
|
||||||
|
|
||||||
|
# make predictions for test data
|
||||||
|
y_pred = model.predict(X_test)
|
||||||
|
predictions = [round(value) for value in y_pred]
|
||||||
|
|
||||||
|
# evaluate predictions
|
||||||
|
accuracy = accuracy_score(y_test, predictions)
|
||||||
|
print("Accuracy: %.2f%%" % (accuracy * 100.0))
|
||||||
|
labels = dtest.get_label()
|
||||||
|
|
||||||
|
# plot results
|
||||||
|
xgb.plot_importance(model)
|
||||||
|
plot_tree(model)
|
||||||
|
plt.show()
|
@ -144,7 +144,11 @@ class Session(TokenManager):
|
|||||||
|
|
||||||
# update api version from server response
|
# update api version from server response
|
||||||
try:
|
try:
|
||||||
api_version = jwt.decode(self.token, verify=False).get('api_version', Session.api_version)
|
token_dict = jwt.decode(self.token, verify=False)
|
||||||
|
api_version = token_dict.get('api_version')
|
||||||
|
if not api_version:
|
||||||
|
api_version = '2.2' if token_dict.get('env', '') == 'prod' else Session.api_version
|
||||||
|
|
||||||
Session.api_version = str(api_version)
|
Session.api_version = str(api_version)
|
||||||
except (jwt.DecodeError, ValueError):
|
except (jwt.DecodeError, ValueError):
|
||||||
pass
|
pass
|
||||||
|
@ -36,7 +36,7 @@ def or_(*converters, **kwargs):
|
|||||||
"""
|
"""
|
||||||
Wrapper that implements an "optional converter" pattern. Allows specifying a converter
|
Wrapper that implements an "optional converter" pattern. Allows specifying a converter
|
||||||
for which a set of exceptions is ignored (and the original value is returned)
|
for which a set of exceptions is ignored (and the original value is returned)
|
||||||
:param converter: A converter callable
|
:param converters: A converter callable
|
||||||
:param exceptions: A tuple of exception types to ignore
|
:param exceptions: A tuple of exception types to ignore
|
||||||
"""
|
"""
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
import os
|
|
||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import hashlib
|
import hashlib
|
||||||
from tempfile import mkstemp, mkdtemp
|
from tempfile import mkdtemp
|
||||||
from threading import Thread, Event
|
from threading import Thread, Event
|
||||||
from multiprocessing.pool import ThreadPool
|
from multiprocessing.pool import ThreadPool
|
||||||
|
|
||||||
|
52
trains/binding/frameworks/base_bind.py
Normal file
52
trains/binding/frameworks/base_bind.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
from abc import ABCMeta, abstractmethod
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
|
|
||||||
|
@six.add_metaclass(ABCMeta)
|
||||||
|
class PatchBaseModelIO(object):
|
||||||
|
"""
|
||||||
|
Base class for patched models
|
||||||
|
|
||||||
|
:param __main_task: Task to run (Experiment)
|
||||||
|
:type __main_task: Task
|
||||||
|
:param __patched: True if the model is patched
|
||||||
|
:type __patched: bool
|
||||||
|
"""
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def __main_task(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def __patched(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def update_current_task(task, **kwargs):
|
||||||
|
"""
|
||||||
|
Update the model task to run
|
||||||
|
:param task: the experiment to do
|
||||||
|
:type task: Task
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def _patch_model_io():
|
||||||
|
"""
|
||||||
|
Patching the load and save functions
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def _save(original_fn, obj, f, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def _load(original_fn, f, *args, **kwargs):
|
||||||
|
pass
|
@ -3,13 +3,14 @@ import sys
|
|||||||
import six
|
import six
|
||||||
from pathlib2 import Path
|
from pathlib2 import Path
|
||||||
|
|
||||||
|
from trains.binding.frameworks.base_bind import PatchBaseModelIO
|
||||||
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
|
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
|
||||||
from ..import_bind import PostImportHookPatching
|
from ..import_bind import PostImportHookPatching
|
||||||
from ...config import running_remotely
|
from ...config import running_remotely
|
||||||
from ...model import Framework
|
from ...model import Framework
|
||||||
|
|
||||||
|
|
||||||
class PatchPyTorchModelIO(object):
|
class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||||
__main_task = None
|
__main_task = None
|
||||||
__patched = None
|
__patched = None
|
||||||
|
|
||||||
|
@ -9,8 +9,6 @@ from typing import Any
|
|||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import six
|
|
||||||
from pathlib2 import Path
|
|
||||||
|
|
||||||
from ..frameworks import _patched_call, WeightsFileHandler, _Empty, TrainsFrameworkAdapter
|
from ..frameworks import _patched_call, WeightsFileHandler, _Empty, TrainsFrameworkAdapter
|
||||||
from ..import_bind import PostImportHookPatching
|
from ..import_bind import PostImportHookPatching
|
||||||
|
101
trains/binding/frameworks/xgboost_bind.py
Normal file
101
trains/binding/frameworks/xgboost_bind.py
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
|
import six
|
||||||
|
from pathlib2 import Path
|
||||||
|
|
||||||
|
from trains.binding.frameworks.base_bind import PatchBaseModelIO
|
||||||
|
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
|
||||||
|
from ..import_bind import PostImportHookPatching
|
||||||
|
from ...config import running_remotely
|
||||||
|
from ...model import Framework
|
||||||
|
|
||||||
|
|
||||||
|
class PatchXGBoostModelIO(PatchBaseModelIO):
|
||||||
|
__main_task = None
|
||||||
|
__patched = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update_current_task(task, **kwargs):
|
||||||
|
PatchXGBoostModelIO.__main_task = task
|
||||||
|
PatchXGBoostModelIO._patch_model_io()
|
||||||
|
PostImportHookPatching.add_on_import('xgboost', PatchXGBoostModelIO._patch_model_io)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _patch_model_io():
|
||||||
|
if PatchXGBoostModelIO.__patched:
|
||||||
|
return
|
||||||
|
|
||||||
|
if 'xgboost' not in sys.modules:
|
||||||
|
return
|
||||||
|
PatchXGBoostModelIO.__patched = True
|
||||||
|
try:
|
||||||
|
import xgboost as xgb
|
||||||
|
bst = xgb.Booster
|
||||||
|
bst.save_model = _patched_call(bst.save_model, PatchXGBoostModelIO._save)
|
||||||
|
bst.load_model = _patched_call(bst.load_model, PatchXGBoostModelIO._load)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _save(original_fn, obj, f, *args, **kwargs):
|
||||||
|
ret = original_fn(obj, f, *args, **kwargs)
|
||||||
|
if not PatchXGBoostModelIO.__main_task:
|
||||||
|
return ret
|
||||||
|
|
||||||
|
if isinstance(f, six.string_types):
|
||||||
|
filename = f
|
||||||
|
elif hasattr(f, 'name'):
|
||||||
|
filename = f.name
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
f.flush()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
filename = None
|
||||||
|
|
||||||
|
# give the model a descriptive name based on the file name
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
model_name = Path(filename).stem
|
||||||
|
except Exception:
|
||||||
|
model_name = None
|
||||||
|
WeightsFileHandler.create_output_model(obj, filename, Framework.xgboost, PatchXGBoostModelIO.__main_task,
|
||||||
|
singlefile=True, model_name=model_name)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load(original_fn, f, *args, **kwargs):
|
||||||
|
if isinstance(f, six.string_types):
|
||||||
|
filename = f
|
||||||
|
elif hasattr(f, 'name'):
|
||||||
|
filename = f.name
|
||||||
|
elif len(args) == 1 and isinstance(args[0], six.string_types):
|
||||||
|
filename = args[0]
|
||||||
|
else:
|
||||||
|
filename = None
|
||||||
|
|
||||||
|
if not PatchXGBoostModelIO.__main_task:
|
||||||
|
return original_fn(f, *args, **kwargs)
|
||||||
|
|
||||||
|
# register input model
|
||||||
|
empty = _Empty()
|
||||||
|
if running_remotely():
|
||||||
|
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
|
||||||
|
PatchXGBoostModelIO.__main_task)
|
||||||
|
model = original_fn(filename or f, *args, **kwargs)
|
||||||
|
else:
|
||||||
|
# try to load model before registering, in case we fail
|
||||||
|
model = original_fn(f, *args, **kwargs)
|
||||||
|
WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
|
||||||
|
PatchXGBoostModelIO.__main_task)
|
||||||
|
|
||||||
|
if empty.trains_in_model:
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
model.trains_in_model = empty.trains_in_model
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return model
|
110
trains/binding/joblib_bind.py
Normal file
110
trains/binding/joblib_bind.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
try:
|
||||||
|
import joblib
|
||||||
|
except ImportError as e:
|
||||||
|
joblib = None
|
||||||
|
|
||||||
|
import six
|
||||||
|
from pathlib2 import Path
|
||||||
|
|
||||||
|
from trains.binding.frameworks import _patched_call, _Empty, WeightsFileHandler
|
||||||
|
from trains.config import running_remotely
|
||||||
|
from trains.debugging.log import LoggerRoot
|
||||||
|
|
||||||
|
|
||||||
|
class PatchedJoblib(object):
|
||||||
|
_patched_original_dump = None
|
||||||
|
_patched_original_load = None
|
||||||
|
_current_task = None
|
||||||
|
_current_framework = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def patch_joblib():
|
||||||
|
if PatchedJoblib._patched_original_dump is not None and PatchedJoblib._patched_original_load is not None:
|
||||||
|
# We don't need to patch anything else, so we are done
|
||||||
|
return True
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump)
|
||||||
|
joblib.load = _patched_call(joblib.load, PatchedJoblib._load)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update_current_task(task):
|
||||||
|
if PatchedJoblib.patch_joblib():
|
||||||
|
PatchedJoblib._current_task = task
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _dump(original_fn, obj, f, *args, **kwargs):
|
||||||
|
ret = original_fn(obj, f, *args, **kwargs)
|
||||||
|
if not PatchedJoblib._current_task:
|
||||||
|
return ret
|
||||||
|
|
||||||
|
if isinstance(f, six.string_types):
|
||||||
|
filename = f
|
||||||
|
elif hasattr(f, 'name'):
|
||||||
|
filename = f.name
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
f.flush()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
filename = None
|
||||||
|
|
||||||
|
# give the model a descriptive name based on the file name
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
model_name = Path(filename).stem
|
||||||
|
except Exception:
|
||||||
|
model_name = None
|
||||||
|
PatchedJoblib._current_framework = PatchedJoblib.get_model_framework(obj)
|
||||||
|
WeightsFileHandler.create_output_model(obj, filename, PatchedJoblib._current_framework,
|
||||||
|
PatchedJoblib._current_task, singlefile=True, model_name=model_name)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load(original_fn, f, *args, **kwargs):
|
||||||
|
if isinstance(f, six.string_types):
|
||||||
|
filename = f
|
||||||
|
elif hasattr(f, 'name'):
|
||||||
|
filename = f.name
|
||||||
|
else:
|
||||||
|
filename = None
|
||||||
|
|
||||||
|
if not PatchedJoblib._current_task:
|
||||||
|
return original_fn(f, *args, **kwargs)
|
||||||
|
|
||||||
|
# register input model
|
||||||
|
empty = _Empty()
|
||||||
|
if running_remotely():
|
||||||
|
filename = WeightsFileHandler.restore_weights_file(empty, filename, PatchedJoblib._current_framework,
|
||||||
|
PatchedJoblib._current_task)
|
||||||
|
model = original_fn(filename or f, *args, **kwargs)
|
||||||
|
else:
|
||||||
|
# try to load model before registering, in case we fail
|
||||||
|
model = original_fn(f, *args, **kwargs)
|
||||||
|
WeightsFileHandler.restore_weights_file(empty, filename, PatchedJoblib._current_framework,
|
||||||
|
PatchedJoblib._current_task)
|
||||||
|
|
||||||
|
if empty.trains_in_model:
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
model.trains_in_model = empty.trains_in_model
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return model
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_model_framework(obj):
|
||||||
|
object_orig_module = obj.__module__
|
||||||
|
framework = object_orig_module
|
||||||
|
try:
|
||||||
|
framework = object_orig_module.partition(".")[0]
|
||||||
|
except Exception as _:
|
||||||
|
LoggerRoot.get_base_logger().warning(
|
||||||
|
"Can't get model framework, model framework will be: {} ".format(object_orig_module))
|
||||||
|
finally:
|
||||||
|
return framework
|
@ -1,6 +1,5 @@
|
|||||||
import abc
|
import abc
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import tarfile
|
import tarfile
|
||||||
import zipfile
|
import zipfile
|
||||||
from tempfile import mkdtemp, mkstemp
|
from tempfile import mkdtemp, mkstemp
|
||||||
@ -40,6 +39,7 @@ class Framework(Options):
|
|||||||
darknet = 'Darknet'
|
darknet = 'Darknet'
|
||||||
paddlepaddle = 'PaddlePaddle'
|
paddlepaddle = 'PaddlePaddle'
|
||||||
scikitlearn = 'ScikitLearn'
|
scikitlearn = 'ScikitLearn'
|
||||||
|
xgboost = 'XGBoost'
|
||||||
|
|
||||||
__file_extensions_mapping = {
|
__file_extensions_mapping = {
|
||||||
'.pb': (tensorflow, tensorflowjs, onnx, ),
|
'.pb': (tensorflow, tensorflowjs, onnx, ),
|
||||||
@ -59,13 +59,13 @@ class Framework(Options):
|
|||||||
'.h5': (keras, ),
|
'.h5': (keras, ),
|
||||||
'.hdf5': (keras, ),
|
'.hdf5': (keras, ),
|
||||||
'.keras': (keras, ),
|
'.keras': (keras, ),
|
||||||
'.model': (mknet, cntk, ),
|
'.model': (mknet, cntk, xgboost),
|
||||||
'-symbol.json': (mknet, ),
|
'-symbol.json': (mknet, ),
|
||||||
'.cntk': (cntk, ),
|
'.cntk': (cntk, ),
|
||||||
'.t7': (torch, ),
|
'.t7': (torch, ),
|
||||||
'.cfg': (darknet, ),
|
'.cfg': (darknet, ),
|
||||||
'__model__': (paddlepaddle, ),
|
'__model__': (paddlepaddle, ),
|
||||||
'.pkl': (scikitlearn, keras, ),
|
'.pkl': (scikitlearn, keras, xgboost),
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -10,6 +10,7 @@ from collections import OrderedDict, Callable
|
|||||||
import psutil
|
import psutil
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from trains.binding.joblib_bind import PatchedJoblib
|
||||||
from .backend_api.services import tasks, projects
|
from .backend_api.services import tasks, projects
|
||||||
from .backend_api.session.session import Session
|
from .backend_api.session.session import Session
|
||||||
from .backend_interface.model import Model as BackendModel
|
from .backend_interface.model import Model as BackendModel
|
||||||
@ -34,8 +35,9 @@ from .utilities.args import argparser_parseargs_called, get_argparser_last_args,
|
|||||||
from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
||||||
from .binding.frameworks.tensorflow_bind import PatchSummaryToEventTransformer, PatchTensorFlowEager, \
|
from .binding.frameworks.tensorflow_bind import PatchSummaryToEventTransformer, PatchTensorFlowEager, \
|
||||||
PatchKerasModelIO, PatchTensorflowModelIO
|
PatchKerasModelIO, PatchTensorflowModelIO
|
||||||
from .utilities.resource_monitor import ResourceMonitor
|
from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
|
||||||
from .binding.matplotlib_bind import PatchedMatplotlib
|
from .binding.matplotlib_bind import PatchedMatplotlib
|
||||||
|
from .utilities.resource_monitor import ResourceMonitor
|
||||||
from .utilities.seed import make_deterministic
|
from .utilities.seed import make_deterministic
|
||||||
|
|
||||||
NotSet = object()
|
NotSet = object()
|
||||||
@ -118,15 +120,15 @@ class Task(_Task):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init(
|
def init(
|
||||||
cls,
|
cls,
|
||||||
project_name=None,
|
project_name=None,
|
||||||
task_name=None,
|
task_name=None,
|
||||||
task_type=TaskTypes.training,
|
task_type=TaskTypes.training,
|
||||||
reuse_last_task_id=True,
|
reuse_last_task_id=True,
|
||||||
output_uri=None,
|
output_uri=None,
|
||||||
auto_connect_arg_parser=True,
|
auto_connect_arg_parser=True,
|
||||||
auto_connect_frameworks=True,
|
auto_connect_frameworks=True,
|
||||||
auto_resource_monitoring=True,
|
auto_resource_monitoring=True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Return the Task object for the main execution task (task context).
|
Return the Task object for the main execution task (task context).
|
||||||
@ -239,14 +241,15 @@ class Task(_Task):
|
|||||||
# patch OS forking
|
# patch OS forking
|
||||||
PatchOsFork.patch_fork()
|
PatchOsFork.patch_fork()
|
||||||
if auto_connect_frameworks:
|
if auto_connect_frameworks:
|
||||||
|
PatchedJoblib.update_current_task(task)
|
||||||
PatchedMatplotlib.update_current_task(Task.__main_task)
|
PatchedMatplotlib.update_current_task(Task.__main_task)
|
||||||
PatchAbsl.update_current_task(Task.__main_task)
|
PatchAbsl.update_current_task(Task.__main_task)
|
||||||
PatchSummaryToEventTransformer.update_current_task(task)
|
PatchSummaryToEventTransformer.update_current_task(task)
|
||||||
# PatchModelCheckPointCallback.update_current_task(task)
|
|
||||||
PatchTensorFlowEager.update_current_task(task)
|
PatchTensorFlowEager.update_current_task(task)
|
||||||
PatchKerasModelIO.update_current_task(task)
|
PatchKerasModelIO.update_current_task(task)
|
||||||
PatchTensorflowModelIO.update_current_task(task)
|
PatchTensorflowModelIO.update_current_task(task)
|
||||||
PatchPyTorchModelIO.update_current_task(task)
|
PatchPyTorchModelIO.update_current_task(task)
|
||||||
|
PatchXGBoostModelIO.update_current_task(task)
|
||||||
if auto_resource_monitoring:
|
if auto_resource_monitoring:
|
||||||
task._resource_monitor = ResourceMonitor(task)
|
task._resource_monitor = ResourceMonitor(task)
|
||||||
task._resource_monitor.start()
|
task._resource_monitor.start()
|
||||||
@ -277,10 +280,10 @@ class Task(_Task):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
cls,
|
cls,
|
||||||
task_name=None,
|
task_name=None,
|
||||||
project_name=None,
|
project_name=None,
|
||||||
task_type=TaskTypes.training,
|
task_type=TaskTypes.training,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a new Task object, regardless of the main execution task (Task.init).
|
Create a new Task object, regardless of the main execution task (Task.init).
|
||||||
@ -345,7 +348,7 @@ class Task(_Task):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# if we force no task reuse from os environment
|
# if we force no task reuse from os environment
|
||||||
if DEV_TASK_NO_REUSE.get():
|
if DEV_TASK_NO_REUSE.get() or reuse_last_task_id:
|
||||||
default_task = None
|
default_task = None
|
||||||
else:
|
else:
|
||||||
# if we have a previous session to use, get the task id from it
|
# if we have a previous session to use, get the task id from it
|
||||||
@ -364,7 +367,6 @@ class Task(_Task):
|
|||||||
default_task_id = reuse_last_task_id
|
default_task_id = reuse_last_task_id
|
||||||
elif not reuse_last_task_id or not cls.__task_is_relevant(default_task):
|
elif not reuse_last_task_id or not cls.__task_is_relevant(default_task):
|
||||||
default_task_id = None
|
default_task_id = None
|
||||||
closed_old_task = cls.__close_timed_out_task(default_task)
|
|
||||||
else:
|
else:
|
||||||
default_task_id = default_task.get('id') if default_task else None
|
default_task_id = default_task.get('id') if default_task else None
|
||||||
|
|
||||||
@ -693,7 +695,7 @@ class Task(_Task):
|
|||||||
If `config_text` is not None, `config_dict` must not be provided.
|
If `config_text` is not None, `config_dict` must not be provided.
|
||||||
"""
|
"""
|
||||||
config_text = self.get_model_config_text()
|
config_text = self.get_model_config_text()
|
||||||
return OutputModel._text_to_config_dict(config_text)
|
return OutputModel._text_to_config_dict(config_text)
|
||||||
|
|
||||||
def set_model_label_enumeration(self, enumeration=None):
|
def set_model_label_enumeration(self, enumeration=None):
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user