mirror of
https://github.com/clearml/clearml-serving
synced 2025-06-26 18:16:00 +00:00
add torchscript mnist example
This commit is contained in:
parent
b5f5d72046
commit
6a9acdab3d
7
examples/torchscript/requirements.txt
Normal file
7
examples/torchscript/requirements.txt
Normal file
@ -0,0 +1,7 @@
|
||||
clearml
|
||||
torch
|
||||
torchvision
|
||||
Pillow
|
||||
nvidia-pyindex
|
||||
tritonclient
|
||||
requests
|
176
examples/torchscript/torchscript_mnist.py
Normal file
176
examples/torchscript/torchscript_mnist.py
Normal file
@ -0,0 +1,176 @@
|
||||
# ClearML - Torchscript example code, automatic logging traced model
|
||||
# Then store the torchscript model to be served by clearml-serving
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from clearml import OutputModel, Task
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
class AlexNet(nn.Module):
|
||||
|
||||
def __init__(self, num_classes: int = 1000) -> None:
|
||||
super(AlexNet, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(kernel_size=3, stride=2),
|
||||
nn.Conv2d(64, 192, kernel_size=5, padding=2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(kernel_size=3, stride=2),
|
||||
nn.Conv2d(192, 384, kernel_size=3, padding=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(384, 256, kernel_size=3, padding=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(256, 256, kernel_size=3, padding=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(kernel_size=3, stride=2),
|
||||
)
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Dropout(),
|
||||
nn.Linear(256 * 6 * 6, 4096),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout(),
|
||||
nn.Linear(4096, 4096),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(4096, num_classes),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.features(x)
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
|
||||
|
||||
def create_config_pbtxt(config_pbtxt_file):
|
||||
platform = "pytorch_libtorch"
|
||||
input_name = 'INPUT__0'
|
||||
output_name = 'OUTPUT__0'
|
||||
input_data_type = "TYPE_FP32"
|
||||
output_data_type = "TYPE_FP32"
|
||||
input_dims = str([-1, 3, 224, 224])
|
||||
output_dims = str([-1, 1000])
|
||||
|
||||
config_pbtxt = """
|
||||
platform: "%s"
|
||||
input [
|
||||
{
|
||||
name: "%s"
|
||||
data_type: %s
|
||||
dims: %s
|
||||
}
|
||||
]
|
||||
output [
|
||||
{
|
||||
name: "%s"
|
||||
data_type: %s
|
||||
dims: %s
|
||||
}
|
||||
]
|
||||
""" % (
|
||||
platform,
|
||||
input_name, input_data_type, input_dims,
|
||||
output_name, output_data_type, output_dims
|
||||
)
|
||||
|
||||
with open(config_pbtxt_file, "w") as config_file:
|
||||
config_file.write(config_pbtxt)
|
||||
|
||||
|
||||
def preprocess(url):
|
||||
response = requests.get(url)
|
||||
filename = "sample.jpg"
|
||||
with open(filename, 'wb') as f:
|
||||
f.write(response.content)
|
||||
|
||||
input_image = Image.open(filename)
|
||||
input_tensor = transform(input_image)
|
||||
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
|
||||
return input_batch
|
||||
|
||||
|
||||
def postprocess(output):
|
||||
probabilities = torch.nn.functional.softmax(output[0], dim=0)
|
||||
top5_prob, top5_catid = torch.topk(probabilities, 5)
|
||||
return top5_prob, top5_catid
|
||||
|
||||
|
||||
def get_mnist_labels():
|
||||
data = requests.get("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")
|
||||
return data.text.split("\n")
|
||||
|
||||
|
||||
def get_alexnet_state_dict():
|
||||
filename = "alexnet_weights.pt"
|
||||
if not Path(filename).exists():
|
||||
response = requests.get("https://download.pytorch.org/models/alexnet-owt-7be5be79.pth")
|
||||
with open(filename, 'wb') as f:
|
||||
f.write(response.content)
|
||||
return torch.load(filename)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Torchscript MNIST Example - serving torchscript model')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Connecting ClearML with the current process,
|
||||
# from here on everything is logged automatically
|
||||
task = Task.init(project_name='examples', task_name='Torchscript MNIST serve example', output_uri=True)
|
||||
|
||||
# This could work, but the github api limits the number of downloads
|
||||
# model = torch.hub.load('pytorch/vision:v0.9.0', 'alexnet', pretrained=True)
|
||||
|
||||
# Instead we use hardcoded AlexNet model
|
||||
model = AlexNet()
|
||||
state_dict = get_alexnet_state_dict()
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
model.eval()
|
||||
|
||||
# Advanced: setting model class enumeration
|
||||
mnist_labels_list = get_mnist_labels()
|
||||
labels = {label: i for i, label in enumerate(mnist_labels_list)}
|
||||
task.set_model_label_enumeration(labels)
|
||||
|
||||
# Get a input image for the model
|
||||
url = 'https://github.com/pytorch/hub/raw/master/images/dog.jpg'
|
||||
input = preprocess(url)
|
||||
|
||||
# Trace and save the model in a format that can be served
|
||||
jit_model = torch.jit.trace(model, input)
|
||||
jit_model.save('serving_model')
|
||||
|
||||
# Predict class using traced model on input
|
||||
output = jit_model(input)
|
||||
top5_prob, top5_catid = postprocess(output)
|
||||
|
||||
for i in range(top5_prob.size(0)):
|
||||
print(mnist_labels_list[top5_catid[i]], top5_prob[i].item())
|
||||
|
||||
# create the config.pbtxt for triton to be able to serve the model
|
||||
create_config_pbtxt(config_pbtxt_file='config.pbtxt')
|
||||
|
||||
task.update_output_model(model_path='serving_model')
|
||||
|
||||
# store the configuration on the creating Task,
|
||||
# this will allow us to skip over manually setting the config.pbtxt for `clearml-serving`
|
||||
task.connect_configuration(configuration=Path('config.pbtxt'), name='config.pbtxt')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue
Block a user