mirror of
https://github.com/clearml/clearml-serving
synced 2025-06-26 18:16:00 +00:00
177 lines
5.5 KiB
Python
177 lines
5.5 KiB
Python
# ClearML - TorchScript example code: trace a model with automatic logging,
# then store the TorchScript model so it can be served by clearml-serving
|
|
import argparse
|
|
from pathlib import Path
|
|
|
|
import requests
|
|
import torch
|
|
from clearml import OutputModel, Task
|
|
from PIL import Image
|
|
from torch import nn
|
|
from torchvision import transforms
|
|
|
|
|
|
class AlexNet(nn.Module):
    """Hard-coded AlexNet image classification network.

    The attribute names (``features``, ``avgpool``, ``classifier``) and the
    position of each layer inside the ``nn.Sequential`` containers define the
    ``state_dict`` keys, so they must not change if an existing checkpoint is
    to be loaded with ``load_state_dict``.
    """

    def __init__(self, num_classes: int = 1000) -> None:
        """Create all layers.

        :param num_classes: number of outputs of the last linear layer.
        """
        super().__init__()

        # Convolutional feature extractor (layer order is significant).
        conv_layers = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*conv_layers)

        # Pool to a fixed 6x6 grid so the classifier input size is constant.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        # Fully connected classification head (layer order is significant).
        head_layers = [
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run features -> pooling -> flatten (from dim 1) -> classifier."""
        pooled = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(pooled, 1))
|
|
|
|
|
|
# Image preprocessing pipeline applied before inference:
# resize to 256, center-crop to 224, convert to a tensor, then normalize
# each of the three channels with the mean / std values below.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
|
|
def _render_tensor_section(section, name, data_type, dims):
    """Render one ``input`` / ``output`` section of the config as text."""
    return (
        "%s [\n"
        "    {\n"
        "        name: \"%s\"\n"
        "        data_type: %s\n"
        "        dims: %s\n"
        "    }\n"
        "]\n"
    ) % (section, name, data_type, str(list(dims)))


def create_config_pbtxt(config_pbtxt_file, input_dims=(-1, 3, 224, 224), output_dims=(-1, 1000)):
    """Write the ``config.pbtxt`` that lets Triton serve the traced model.

    The file declares the ``pytorch_libtorch`` platform with a single FP32
    input named ``INPUT__0`` and a single FP32 output named ``OUTPUT__0``.

    :param config_pbtxt_file: path of the file to create (overwritten if it exists).
    :param input_dims: dimensions of the input tensor; defaults to the
        ``[-1, 3, 224, 224]`` shape this example has always written.
    :param output_dims: dimensions of the output tensor; defaults to
        ``[-1, 1000]``. Override it when the model has a different number
        of classes.
    """
    platform = "pytorch_libtorch"
    input_name = "INPUT__0"
    output_name = "OUTPUT__0"
    input_data_type = "TYPE_FP32"
    output_data_type = "TYPE_FP32"

    config_pbtxt = (
        "\n"
        + "platform: \"%s\"\n" % platform
        + _render_tensor_section("input", input_name, input_data_type, input_dims)
        + _render_tensor_section("output", output_name, output_data_type, output_dims)
    )

    # Explicit encoding keeps the output independent of the platform default.
    with open(config_pbtxt_file, "w", encoding="utf-8") as config_file:
        config_file.write(config_pbtxt)
|
|
|
|
|
|
def preprocess(url):
    """Download the image at *url* and convert it into a model input batch.

    The downloaded bytes are stored in ``sample.jpg`` in the current working
    directory, then loaded, passed through the module-level ``transform``
    pipeline and given a leading batch dimension.

    :param url: HTTP(S) address of the image to download.
    :return: the transformed image tensor with an extra leading dimension of size 1.
    :raises requests.HTTPError: if the server answers with an error status.
    """
    # A timeout prevents hanging forever on a stalled connection.
    response = requests.get(url, timeout=60)
    # Fail here instead of saving an error page as if it were an image.
    response.raise_for_status()

    filename = "sample.jpg"
    with open(filename, "wb") as f:
        f.write(response.content)

    # The context manager releases the file handle once the image is decoded.
    with Image.open(filename) as input_image:
        # The normalization step is defined for three channels, so force RGB
        # (grayscale / RGBA / palette images would otherwise fail).
        input_tensor = transform(input_image.convert("RGB"))

    # create a mini-batch as expected by the model
    return input_tensor.unsqueeze(0)
|
|
|
|
|
|
def postprocess(output):
    """Return the five highest softmax probabilities and their class indices.

    Only the first entry of *output* (``output[0]``) is evaluated.

    :param output: model output; ``output[0]`` holds the per-class scores.
    :return: tuple ``(probabilities, class_indices)`` of the top five classes.
    """
    class_scores = output[0]
    top5_prob, top5_catid = class_scores.softmax(dim=0).topk(5)
    return top5_prob, top5_catid
|
|
|
|
|
|
def get_mnist_labels():
    """Download the class-label list used to name the model outputs.

    Despite the function name, the labels are fetched from
    ``imagenet_classes.txt`` in the ``pytorch/hub`` repository.

    :return: list of label strings, one per non-empty line of the file,
        in file order.
    :raises requests.HTTPError: if the server answers with an error status.
    """
    response = requests.get(
        "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt",
        timeout=60,
    )
    # Do not turn an error page into a list of bogus labels.
    response.raise_for_status()
    # splitlines() plus the emptiness filter avoids the spurious '' label that
    # split("\n") produces when the file ends with a newline.
    return [line.strip() for line in response.text.splitlines() if line.strip()]
|
|
|
|
|
|
def get_alexnet_state_dict():
    """Return the AlexNet weights, downloading them on first use.

    The checkpoint is cached as ``alexnet_weights.pt`` in the current working
    directory; later calls load the cached file without touching the network.

    :return: the object stored in the checkpoint, as returned by ``torch.load``.
    :raises requests.HTTPError: if the download answers with an error status.
    """
    weights_path = Path("alexnet_weights.pt")
    if not weights_path.exists():
        response = requests.get(
            "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
            timeout=300,
        )
        # Never cache an error page as if it were the checkpoint.
        response.raise_for_status()
        # Write to a temporary name first and rename afterwards, so an
        # interrupted write cannot leave a truncated file under the cache name.
        partial_path = weights_path.with_name(weights_path.name + ".part")
        partial_path.write_bytes(response.content)
        partial_path.replace(weights_path)

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # obtained from a trusted source.
    return torch.load(str(weights_path))
|
|
|
|
|
|
def main():
    """Trace AlexNet to TorchScript and register it with ClearML for serving.

    Steps: start a ClearML task, build the model and load its weights, attach
    the label enumeration, trace the model on a sample image, print the top
    five predictions, then upload the traced model and its ``config.pbtxt``.
    """
    # No options are defined; parsing still provides -h/--help and rejects
    # unknown command line arguments.
    cli = argparse.ArgumentParser(description='Torchscript MNIST Example - serving torchscript model')
    cli.parse_args()

    # Connecting ClearML with the current process,
    # from here on everything is logged automatically
    task = Task.init(project_name='examples', task_name='Torchscript MNIST serve example', output_uri=True)

    # torch.hub.load('pytorch/vision:v0.9.0', 'alexnet', pretrained=True) could
    # work as well, but the github api limits the number of downloads, so the
    # hardcoded AlexNet model is used instead.
    model = AlexNet()
    model.load_state_dict(get_alexnet_state_dict())
    model.eval()

    # Advanced: setting model class enumeration
    class_names = get_mnist_labels()
    task.set_model_label_enumeration(
        {name: index for index, name in enumerate(class_names)}
    )

    # Get an input image for the model
    sample_batch = preprocess('https://github.com/pytorch/hub/raw/master/images/dog.jpg')

    # Trace and save the model in a format that can be served
    traced_model = torch.jit.trace(model, sample_batch)
    traced_model.save('serving_model')

    # Predict class using traced model on the sample input
    scores, class_ids = postprocess(traced_model(sample_batch))
    for class_id, score in zip(class_ids, scores):
        print(class_names[class_id], score.item())

    # create the config.pbtxt for triton to be able to serve the model
    create_config_pbtxt(config_pbtxt_file='config.pbtxt')

    task.update_output_model(model_path='serving_model')

    # store the configuration on the creating Task,
    # this will allow us to skip over manually setting the config.pbtxt for `clearml-serving`
    task.connect_configuration(configuration=Path('config.pbtxt'), name='config.pbtxt')
|
|
|
|
|
|
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|