diff --git a/examples/torchscript/requirements.txt b/examples/torchscript/requirements.txt new file mode 100644 index 0000000..0cf6632 --- /dev/null +++ b/examples/torchscript/requirements.txt @@ -0,0 +1,7 @@ +clearml +torch +torchvision +Pillow +nvidia-pyindex +tritonclient +requests \ No newline at end of file diff --git a/examples/torchscript/torchscript_mnist.py b/examples/torchscript/torchscript_mnist.py new file mode 100644 index 0000000..3e6edaa --- /dev/null +++ b/examples/torchscript/torchscript_mnist.py @@ -0,0 +1,176 @@ +# ClearML - Torchscript example code, automatic logging traced model +# Then store the torchscript model to be served by clearml-serving +import argparse +from pathlib import Path + +import requests +import torch +from clearml import OutputModel, Task +from PIL import Image +from torch import nn +from torchvision import transforms + + +class AlexNet(nn.Module): + + def __init__(self, num_classes: int = 1000) -> None: + super(AlexNet, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(64, 192, kernel_size=5, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(192, 384, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(384, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + ) + self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) + self.classifier = nn.Sequential( + nn.Dropout(), + nn.Linear(256 * 6 * 6, 4096), + nn.ReLU(inplace=True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(inplace=True), + nn.Linear(4096, num_classes), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.features(x) + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.classifier(x) + return x + + +transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), +]) + + +def create_config_pbtxt(config_pbtxt_file): + platform = "pytorch_libtorch" + input_name = 'INPUT__0' + output_name = 'OUTPUT__0' + input_data_type = "TYPE_FP32" + output_data_type = "TYPE_FP32" + input_dims = str([-1, 3, 224, 224]) + output_dims = str([-1, 1000]) + + config_pbtxt = """ + platform: "%s" + input [ + { + name: "%s" + data_type: %s + dims: %s + } + ] + output [ + { + name: "%s" + data_type: %s + dims: %s + } + ] + """ % ( + platform, + input_name, input_data_type, input_dims, + output_name, output_data_type, output_dims + ) + + with open(config_pbtxt_file, "w") as config_file: + config_file.write(config_pbtxt) + + +def preprocess(url): + response = requests.get(url) + filename = "sample.jpg" + with open(filename, 'wb') as f: + f.write(response.content) + + input_image = Image.open(filename) + input_tensor = transform(input_image) + input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model + return input_batch + + +def postprocess(output): + probabilities = torch.nn.functional.softmax(output[0], dim=0) + top5_prob, top5_catid = torch.topk(probabilities, 5) + return top5_prob, top5_catid + + +def get_mnist_labels(): + data = requests.get("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt") + return data.text.split("\n") + + +def get_alexnet_state_dict(): + filename = "alexnet_weights.pt" + if not Path(filename).exists(): + response = requests.get("https://download.pytorch.org/models/alexnet-owt-7be5be79.pth") + with open(filename, 'wb') as f: + f.write(response.content) + return torch.load(filename) + + +def main(): + parser = argparse.ArgumentParser(description='Torchscript MNIST Example - serving torchscript model') + args = parser.parse_args() + + # Connecting ClearML with the current process, + # from here on everything is logged automatically + task = Task.init(project_name='examples', task_name='Torchscript MNIST serve example', output_uri=True) + + # This could work, but the github api limits the number of downloads + # model = torch.hub.load('pytorch/vision:v0.9.0', 'alexnet', pretrained=True) + + # Instead we use hardcoded AlexNet model + model = AlexNet() + state_dict = get_alexnet_state_dict() + model.load_state_dict(state_dict) + + model.eval() + + # Advanced: setting model class enumeration + mnist_labels_list = get_mnist_labels() + labels = {label: i for i, label in enumerate(mnist_labels_list)} + task.set_model_label_enumeration(labels) + + # Get a input image for the model + url = 'https://github.com/pytorch/hub/raw/master/images/dog.jpg' + input = preprocess(url) + + # Trace and save the model in a format that can be served + jit_model = torch.jit.trace(model, input) + jit_model.save('serving_model') + + # Predict class using traced model on input + output = jit_model(input) + top5_prob, top5_catid = postprocess(output) + + for i in range(top5_prob.size(0)): + print(mnist_labels_list[top5_catid[i]], top5_prob[i].item()) + + # create the config.pbtxt for triton to be able to serve the model + create_config_pbtxt(config_pbtxt_file='config.pbtxt') + + task.update_output_model(model_path='serving_model') + + # store the configuration on the creating Task, + # this will allow us to skip over manually setting the config.pbtxt for `clearml-serving` + task.connect_configuration(configuration=Path('config.pbtxt'), name='config.pbtxt') + + +if __name__ == '__main__': + main()