介绍DDP(DistributedDataParallel)使用
1.初始化进程组与配置(解析local_rank,绑定GPU,初始化NCCL进程组)
# Parse the per-process GPU index (--local_rank, injected by the DDP launcher),
# bind this process to that GPU, then join the NCCL process group.
import argparse          # was missing: snippet used argparse without importing it

import torch             # was missing: snippet used torch.cuda without importing it
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=-1, type=int)
FLAGS = parser.parse_args()
local_rank = FLAGS.local_rank

# Set the device BEFORE init_process_group so NCCL binds to the right GPU.
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')
2.构造数据
# Shard the dataset across ranks: DistributedSampler hands each process a
# disjoint slice, so the DataLoader must NOT also shuffle (sampler owns order).
train_sampler = torch.utils.data.distributed.DistributedSampler(my_trainset)
trainloader = torch.utils.data.DataLoader(
    my_trainset,
    batch_size=16,
    num_workers=2,
    sampler=train_sampler,
)
3.模型初始化
# Build the model on this rank's GPU, then wrap it in DDP so gradients are
# all-reduced across processes during backward.
# NOTE(review): Transformer(46, 7004, 0, 0) — constructor args are project-defined;
# presumably vocab/feature sizes, confirm against the Transformer definition.
model = Transformer(46, 7004, 0, 0).to(local_rank)
model = DDP(model, device_ids=[local_rank], output_device=local_rank)
4.训练设置
# Per-epoch training setup. set_epoch() re-seeds the sampler's shuffle so each
# epoch sees a different shard ordering; without it every epoch repeats epoch 0.
# FIX: the loader was created as `trainloader` (section 2); the original text
# called it `train_loader`, which would raise NameError.
trainloader.sampler.set_epoch(epoch)
# Move each batch to this rank's GPU before the forward pass.
zhuyin = zhuyin.to(local_rank)
text = text.to(local_rank)
5.保存模型
# Save only on rank 0 (all ranks hold identical weights after all-reduce);
# unwrap `model.module` so the checkpoint loads into a non-DDP model too.
if dist.get_rank() == 0:
    import os
    # FIX: torch.save raises if the target directory is missing — create it.
    os.makedirs("checkpoints", exist_ok=True)
    torch.save(model.module.state_dict(), "checkpoints/transformer.pth")