from torch.multiprocessing import Process
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
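
# A toy dataset of `num` random feature vectors, each of dimension `dim`.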
class Random_Dataset(Dataset):
    def __init__(self, num=1200, dim=300):
        self.num = num
        self.dim = dim
        self.data = torch.rand(num, dim)

    def __len__(self):
        return self.num

    def __getitem__(self, idx):
        return self.data[idx]
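
# LSTM wrapper. The forward pass packs the padded batch (lengths must be
# sorted in decreasing order), runs the LSTM, and pads the output back to
# the original sequence length via total_length so the output shape stays
# fixed.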
class Model(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(Model, self).__init__()
        self.rnn = nn.LSTM(
            input_size=input_size, hidden_size=hidden_size,
            num_layers=num_layers, batch_first=True
        )

    def forward(self, x, lengths):
        total_length = x.size(1)
        self.rnn.flatten_parameters()
        x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
        outputs, (h_n, c_n) = self.rnn(x)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True, total_length=total_length
        )
        return outputs
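
# 1-D convolutional "generator": adds a channel dimension to each
# (batch, dim) input, applies Conv1d with same-length padding, and squashes
# the result with a sigmoid, yielding (batch, out_channels, dim).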
class Generator(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super(Generator, self).__init__()
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size,
            padding=padding
        )

    def forward(self, x):
        x = torch.unsqueeze(x, 1)
        x = self.conv(x)
        x = torch.sigmoid(x)
        return x
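
# Per-process entry point: sets the rendezvous address/port, joins the
# default process group, builds a group containing both ranks, and then
# calls the worker function.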
def init_process(rank, world_size, backend, func, params):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    group = dist.new_group([0, 1])
    func(*params, group)
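
# Producer (rank 0): runs the generator over the dataloader and broadcasts
# each batch of generator outputs to the group, with itself as the source.
# Note: passing several device_ids relies on DDP's legacy single-process
# multi-GPU mode.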
def generate_data(dataloader, num_epochs, rank, device_ids, network_params,
                  use_cuda, group):
    generator = Generator(*network_params)
    if use_cuda:
        generator = generator.cuda()
    print("Network initialized!")
    generator = nn.parallel.DistributedDataParallel(
        generator, device_ids=device_ids
    )
    for epoch in range(num_epochs):
        for i_batch, batch in enumerate(dataloader):
            if use_cuda:
                batch = batch.cuda()
            rnn_input = generator(batch)
            dist.broadcast(rnn_input, rank, group)
            print("epoch:{}, batch_num:{}, broadcast finished!".format(
                epoch, i_batch
            ))
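
# Consumer (rank 1): receives each generator batch via broadcast, feeds it
# through the LSTM with random (descending) sequence lengths, and takes an
# Adam step on a dummy sum-of-outputs loss; the weights are saved at the end.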
def run_rnn(num_batchs, num_epochs, src_rank, device_ids, network_params,
            input_size, use_cuda, group):
    rnn = Model(*network_params)
    if use_cuda:
        rnn = rnn.cuda()
    print("Network initialized!")
    rnn = nn.parallel.DistributedDataParallel(
        rnn, device_ids=device_ids
    )
    optimizer = optim.Adam(rnn.parameters(), lr=1e-4)
    for epoch in range(num_epochs):
        for i_batch in range(num_batchs):
            optimizer.zero_grad()
            # Buffer that receives the broadcast generator output.
            rnn_input = torch.empty(input_size).cuda()
            dist.broadcast(rnn_input, src_rank, group)
            print("epoch:{}, batch_num:{}, receive finished!".format(
                epoch, i_batch
            ))
            batch_size = rnn_input.size(0)
            # Random lengths, sorted in decreasing order as required by
            # pack_padded_sequence; kept on the CPU.
            lengths = np.random.randint(low=3, high=30, size=batch_size)
            lengths = -np.sort(-lengths)
            lengths = torch.from_numpy(lengths).long()
            out = rnn(rnn_input, lengths)
            out = torch.sum(out)
            out.backward()
            optimizer.step()
            print("out:{}".format(out.item()))
    rnn = rnn.cpu()
    # torch.save takes the object first and the file name second, and
    # state_dict must be called.
    torch.save(rnn.state_dict(), 'rnn.net')
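
# Launches the two workers: rank 0 (generator, GPUs 0-1) produces data and
# rank 1 (LSTM, GPUs 2-3) consumes it; both join the same process group.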
def main(use_cuda=False):
    world_size = 2
    processes = []
    dataset = Random_Dataset()
    dataloader = DataLoader(
        dataset, batch_size=12, shuffle=True, num_workers=1
    )
    num_epochs = 2
    num_batchs = 100
    generator_device_ids = [0, 1]
    rnn_device_ids = [2, 3]
    generator_params = (1, 50, 5)
    rnn_params = (300, 300, 3)
    # Both ranks must initialize the process group with the same backend;
    # NCCL is used here since CUDA tensors are broadcast.
    p1 = Process(
        target=init_process,
        args=(0, world_size, 'nccl', generate_data,
              (dataloader, num_epochs, 0, generator_device_ids,
               generator_params, use_cuda))
    )
    p1.start()
    processes.append(p1)
    p2 = Process(
        target=init_process,
        args=(1, world_size, 'nccl', run_rnn,
              (num_batchs, num_epochs, 0,
               rnn_device_ids, rnn_params, torch.Size((12, 50, 300)),
               use_cuda))
    )
    p2.start()
    processes.append(p2)
    for p in processes:
        p.join()


if __name__ == '__main__':
    main(True)