Model serving in data parallelism
So far, we have covered the full training pipeline with data parallelism. We will now walk through the implementation details of data-parallel serving.
First, we need to define our test dataset:
```python
test_set = datasets.MNIST(
    './mnist_data', download=True, train=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]))
```
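The two constants passed to `Normalize` are the commonly used per-pixel mean and standard deviation of the MNIST training images. As a stdlib-only sketch of what the transform does to each pixel value (the helper `normalize` is hypothetical, introduced here just for illustration):

```python
MNIST_MEAN = 0.1307  # assumed dataset-wide mean of MNIST pixels in [0, 1]
MNIST_STD = 0.3081   # assumed dataset-wide standard deviation

def normalize(pixel, mean=MNIST_MEAN, std=MNIST_STD):
    """Mirrors the per-channel operation of transforms.Normalize:
    output = (input - mean) / std."""
    return (pixel - mean) / std

# The mean-valued pixel maps to exactly zero after normalization.
print(normalize(MNIST_MEAN))  # 0.0
```

This centering keeps the network's inputs in a numerically friendly range, matching whatever normalization was applied during training.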
Next, we need to load the trained model onto each of our GPUs using the load_checkpoint() function, which we defined previously.
Then, we need to define our in-parallel model test function, as follows:
```python
def test(local_rank, args):
    world_size = args.machines * args.gpus
    rank = args.mid * args.gpus + local_rank
    ...
    torch.cuda...
```
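The rank arithmetic at the top of test() determines which portion of the work each process handles. A minimal stdlib-only sketch of that arithmetic, plus one simple way a test set could be split across ranks (the strided sharding scheme and the helper names `global_rank` and `shard_indices` are assumptions for illustration, not necessarily the exact split used here):

```python
def global_rank(mid, gpus, local_rank):
    # Same arithmetic as in test(): the machine index times the number of
    # GPUs per machine, plus the GPU index local to this machine.
    return mid * gpus + local_rank

def shard_indices(num_samples, world_size, rank):
    # Each process takes every world_size-th sample, starting at its own
    # rank, so all samples are covered exactly once across the group.
    return list(range(rank, num_samples, world_size))

machines, gpus = 2, 2
world_size = machines * gpus                 # 4 processes in total
rank = global_rank(mid=1, gpus=gpus, local_rank=0)
print(rank)                                  # 2
print(shard_indices(10, world_size, rank))   # [2, 6]
```

Because every process computes a distinct global rank, each GPU evaluates a disjoint slice of the test set, and the per-rank results can then be aggregated into the overall accuracy.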