PyTorch 单机多卡训练方法

介绍 DDP (DistributedDataParallel) 的使用方法

1.初始化模型,数据,配置

import argparse

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Parse the per-process GPU index that torch.distributed.launch passes
# as --local_rank. NOTE(review): newer launchers (torchrun) export the
# LOCAL_RANK environment variable instead — confirm which launcher is used.
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=-1, type=int)
FLAGS = parser.parse_args()
local_rank = FLAGS.local_rank

# DDP initialization: pin this process to its own GPU *before* joining
# the process group, then initialize the NCCL backend (rendezvous info —
# MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE — comes from the launcher's
# environment).
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')

2.构造数据

# DistributedSampler splits the dataset into disjoint shards, one per
# rank, so each process trains on different data. It shuffles by default,
# re-seeded each epoch via sampler.set_epoch() in the training loop.
train_sampler = torch.utils.data.distributed.DistributedSampler(my_trainset)
# shuffle must stay unset here — shuffling is the sampler's job, and
# DataLoader forbids shuffle=True together with a custom sampler.
# (Renamed from `trainloader` to `train_loader` to match the name the
# training loop actually uses.)
train_loader = torch.utils.data.DataLoader(my_trainset, batch_size=16, num_workers=2, sampler=train_sampler)

3.模型初始化

# Build the model and move it onto this process's GPU first — DDP expects
# the module to already live on device `local_rank`.
# NOTE(review): the positional args (46, 7004, 0, 0) are presumably
# vocabulary/embedding sizes and padding ids — confirm against the
# Transformer class definition.
model = Transformer(46, 7004, 0, 0).to(local_rank)

# Wrap with DDP: gradients are all-reduced across ranks during backward,
# and this replica's inputs/outputs live on local_rank's device.
model = DDP(model, device_ids=[local_rank], output_device=local_rank)

4.训练设置

# Re-seed the sampler's shuffle at the start of every epoch; without
# this call each epoch replays the same permutation on every rank.
train_loader.sampler.set_epoch(epoch)
# Move the batch tensors to this process's GPU before the forward pass.
# NOTE(review): `zhuyin`/`text` come from the (unseen) loader loop —
# presumably phonetic-input and target-text tensors.
zhuyin = zhuyin.to(local_rank)
text = text.to(local_rank)

5.保存模型

# Save only from rank 0: after the gradient all-reduce every replica
# holds identical weights, so one copy suffices — and letting all ranks
# write the same path would race/corrupt the file.
if dist.get_rank() == 0:
    # `.module` unwraps the DDP wrapper so the checkpoint's state_dict
    # keys carry no "module." prefix and load into a plain (non-DDP)
    # model. (Indentation restored — the original snippet had lost it.)
    torch.save(model.module.state_dict(), "checkpoints/transformer.pth")