Checkpointing and fault tolerance
Previously, we discussed two different implementations of data parallel training, namely DistributedDataParallel() and DataParallel(). One thing we are still missing is fault tolerance, which is important in distributed systems. Since DistributedDataParallel() performs better than DataParallel() (it runs one process per GPU instead of a single multithreaded process), we will illustrate our checkpointing implementation in the DistributedDataParallel() setting. In this setting, each process is responsible for checkpointing the model replica held on one GPU.
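Before we look at the checkpointing function itself, it may help to recall the per-process structure this assumes. The sketch below is our own illustration of a standard torch.multiprocessing/DistributedDataParallel launch, not code from this text; the MASTER_ADDR/MASTER_PORT values and the placeholder Linear model are assumptions. It launches one process per GPU so that each rank can later save its own checkpoint:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def worker(rank, world_size):
    # One process per GPU; the rank doubles as the GPU index (our assumption).
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    net = DDP(torch.nn.Linear(16, 16).to(rank), device_ids=[rank])  # placeholder model
    # ... training loop; each process later calls checkpointing(rank, ...) for its replica
    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size,), nprocs=world_size)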
Model checkpointing
First, we will discuss how we can achieve in-parallel model saving, also known as model checkpointing.
The checkpointing function in the multi-processing setting is defined as follows:
def checkpointing(rank, epoch, net, optimizer, loss):
    path = f"model{rank}.pt"
    torch.save({
        'epoch': epoch,
        'model_state'...
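The snippet is truncated after the 'model_state' key. A complete version would follow the standard torch.save() checkpoint pattern; the sketch below is our reconstruction, and the exact key names ('model_state_dict', 'optimizer_state_dict', 'loss') are assumptions rather than the original code:

import torch

def checkpointing(rank, epoch, net, optimizer, loss):
    # Each process saves the model replica it owns to its own file,
    # so checkpoints from different ranks never overwrite each other.
    path = f"model{rank}.pt"
    torch.save({
        'epoch': epoch,                                  # last finished epoch
        'model_state_dict': net.state_dict(),            # model weights (assumed key name)
        'optimizer_state_dict': optimizer.state_dict(),  # optimizer state (assumed key name)
        'loss': loss,                                     # most recent loss value
    }, path)

Note that when net is wrapped in DistributedDataParallel(), net.state_dict() returns keys prefixed with module.; saving net.module.state_dict() instead avoids that prefix and keeps the checkpoint loadable into an unwrapped model.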