PyTorch Distributed Training

r囧r小猫 2022-04-15 03:47
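
This example launches two worker processes with torch.multiprocessing and joins them into one process group via dist.init_process_group over 127.0.0.1:29500. Rank 0 (generate_data) runs a small Conv1d Generator over a random dataset and broadcasts each batch of generated features into the group; rank 1 (run_rnn) receives those batches, feeds them through an LSTM wrapped in DistributedDataParallel, trains it with Adam, and saves the weights at the end.
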
from torch.multiprocessing import Process
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Random_Dataset(Dataset):
    def __init__(self, num=1200, dim=300):
        self.num = num
        self.dim = dim
        self.data = torch.rand(num, dim)

    def __len__(self):
        return self.num

    def __getitem__(self, idx):
        return self.data[idx]

class Model(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(Model, self).__init__()
        self.rnn = nn.LSTM(
            input_size=input_size, hidden_size=hidden_size,
            num_layers=num_layers, batch_first=True
        )

    def forward(self, x, lengths):
        # Pack the padded batch, run the LSTM, then pad back to the
        # original sequence length so the output shape stays fixed.
        total_length = x.size(1)
        self.rnn.flatten_parameters()
        x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
        outputs, (h_n, c_n) = self.rnn(x)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True, total_length=total_length
        )
        return outputs

class Generator(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super(Generator, self).__init__()
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size,
            padding=padding
        )

    def forward(self, x):
        x = torch.unsqueeze(x, 1)
        x = self.conv(x)
        x = torch.sigmoid(x)  # F.sigmoid is deprecated
        return x

def init_process(rank, world_size, backend, func, params):
    # All ranks must agree on the rendezvous address/port and the backend.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    group = dist.new_group([0, 1])
    func(*params, group)

def generate_data(dataloader, num_epochs, rank, device_ids, network_params,
                  use_cuda, group):
    generator = Generator(*network_params)
    if use_cuda:
        generator = generator.cuda()
    print("Network initialized!")
    generator = nn.parallel.DistributedDataParallel(
        generator, device_ids=device_ids
    )
    for epoch in range(num_epochs):
        for i_batch, batch in enumerate(dataloader):
            batch = batch.cuda()
            rnn_input = generator(batch)
            # Ship this batch of generated features to the other rank.
            dist.broadcast(rnn_input, rank, group)
            print("epoch:{}, batch_num:{}, broadcast finished!".format(
                epoch, i_batch
            ))

def run_rnn(num_batchs, num_epochs, src_rank, device_ids, network_params,
            input_size, use_cuda, group):
    rnn = Model(*network_params)
    if use_cuda:
        rnn = rnn.cuda()
    print("Network initialized!")
    rnn = nn.parallel.DistributedDataParallel(
        rnn, device_ids=device_ids
    )
    optimizer = optim.Adam(rnn.parameters(), lr=1e-4)
    for epoch in range(num_epochs):
        for i_batch in range(num_batchs):
            optimizer.zero_grad()
            # Pre-allocate a receive buffer, then fill it from src_rank.
            rnn_input = torch.Tensor(input_size).cuda()
            dist.broadcast(rnn_input, src_rank, group)
            print("epoch:{}, batch_num:{}, receive finished!".format(
                epoch, i_batch
            ))
            batch_size = rnn_input.size(0)
            lengths = np.random.randint(low=3, high=30, size=(batch_size,))
            lengths = -np.sort(-lengths)  # descending order for pack_padded_sequence
            lengths = torch.from_numpy(lengths).long()  # lengths must stay on the CPU
            out = rnn(rnn_input, lengths)
            out = torch.sum(out)
            out.backward()
            optimizer.step()
            print("out:{}".format(out.item()))
    rnn = rnn.cpu()
    torch.save(rnn.state_dict(), 'rnn.net')

def main(use_cuda=False):
    world_size = 2
    processes = []
    dataset = Random_Dataset()
    dataloader = DataLoader(
        dataset, batch_size=12, shuffle=True, num_workers=1
    )
    num_epochs = 2
    num_batchs = 100
    generator_device_ids = [0, 1]
    rnn_device_ids = [2, 3]
    generator_params = (1, 50, 5)
    rnn_params = (300, 300, 3)
    # Both ranks must join the process group with the same backend;
    # nccl is used here because the broadcast tensors live on the GPU.
    p1 = Process(
        target=init_process,
        args=(0, world_size, 'nccl', generate_data,
              (dataloader, num_epochs, 0, generator_device_ids,
               generator_params, use_cuda))
    )
    p1.start()
    processes.append(p1)
    p2 = Process(
        target=init_process,
        args=(1, world_size, 'nccl', run_rnn,
              (num_batchs, num_epochs, 0,
               rnn_device_ids, rnn_params, torch.Size((12, 50, 300)),
               use_cuda))
    )
    p2.start()
    processes.append(p2)
    for p in processes:
        p.join()


if __name__ == '__main__':
    main(True)
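
The script above assumes four GPUs (two per process). The snippet below is a minimal CPU-only sketch of the same broadcast handshake using the gloo backend; the worker function, payload, and port 29501 are illustrative choices, not part of the original post. It is handy for checking that two ranks can rendezvous and exchange a tensor before scaling up to the GPU version.

import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process


def worker(rank, world_size):
    # Same rendezvous pattern as above, but on the CPU with gloo.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29501'  # assumed free port
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    buf = torch.zeros(4)
    if rank == 0:
        buf = torch.arange(4, dtype=torch.float32)  # rank 0 owns the payload
    dist.broadcast(buf, src=0)  # after this call every rank holds rank 0's values
    print("rank {} got {}".format(rank, buf.tolist()))
    dist.destroy_process_group()


if __name__ == '__main__':
    procs = [Process(target=worker, args=(r, 2)) for r in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()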
