PyTorch回归训练 傷城~ 2022-01-23 23:07 234阅读 0赞 ### 1. 创建用于回归的虚拟数据集 ### ### 2. 划分训练集和测试集 ### ### 3. 参数初始化比较 ### ### 4 批训练方法 ### #!/usr/bin/env python # -*- coding: utf-8 -*- """ __title__ = '' """ import torch from torch import nn import torch.nn.functional as F import torch.utils.data as Data import matplotlib.pyplot as plt from sklearn.model_selection import train_test_split import numpy as np from torchvision import datasets, transforms from torch.nn import init #创建fake data # torch.manual_seed(99) # x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1) # y = x.pow(2) + 0.1 * torch.normal(torch.zeros(x.size())) # plt.scatter(x.numpy(), y.numpy()) # plt.show() np.random.seed(666) X = np.linspace(-1, 1, 1000) y = np.power(X, 2) + 0.1 * np.random.normal(0, 1, X.size) print(X.shape) print(y.shape) # plt.scatter(X, y) # plt.show() # 创建训练集和测试集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1024) X_train = torch.from_numpy(X_train).type(torch.FloatTensor) X_train = torch.unsqueeze(X_train, dim=1) #转换成二维 y_train = torch.from_numpy(y_train).type(torch.FloatTensor) y_train = torch.unsqueeze(y_train, dim=1) print(X_train.type) X_test = torch.from_numpy(X_test).type(torch.FloatTensor) X_test = torch.unsqueeze(X_test, dim=1) #转换成二维 # train_size = int(0.7 * len(X)) # test_size = len(X) - train_size # X_train, X_test = Data.random_split(X, [train_size, test_size]) # print(len(X_train), len(X_test)) BATCH_SIZE = 50 LR = 0.02 EPOCH = 5 #将数据装载镜data中, 对数据进行分批训练 torch_data = Data.TensorDataset(X_train, y_train) loader = Data.DataLoader(dataset=torch_data, batch_size=BATCH_SIZE, shuffle=True) #创建自己的nn class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.hidden = nn.Linear(1, 20) self.predict = nn.Linear(20, 1) def forward(self, x): x = F.relu(self.hidden(x)) x = self.predict(x) return x def weights_init(m): if isinstance(m, nn.Linear): init.kaiming_normal(m.weight.data) # init.xavier_normal(m.bias.data) adam_net = Net() # adam_net.apply(weights_init) # 对整个网络层进行参数初始化 # 有初始化损失 xavier_normal : #{1: tensor(0.0972, grad_fn=<MseLossBackward>), 2: tensor(0.0731, grad_fn=<MseLossBackward>), # 3: tensor(0.0881, grad_fn=<MseLossBackward>), 4: tensor(0.1120, grad_fn=<MseLossBackward>), # 5: tensor(0.1012, grad_fn=<MseLossBackward>)} # 有初始化损失 kaiming_normal : 表现相对较好--- #{1: tensor(0.1476, grad_fn=<MseLossBackward>), 2: tensor(0.0234, grad_fn=<MseLossBackward>), # 3: tensor(0.0162, grad_fn=<MseLossBackward>), 4: tensor(0.0170, grad_fn=<MseLossBackward>), # 5: tensor(0.0218, grad_fn=<MseLossBackward>)} # 没有初始化损失 # {1: tensor(0.0265, grad_fn=<MseLossBackward>), 2: tensor(0.0121, grad_fn=<MseLossBackward>), # 3: tensor(0.0096, grad_fn=<MseLossBackward>), 4: tensor(0.0109, grad_fn=<MseLossBackward>), # 5: tensor(0.0104, grad_fn=<MseLossBackward>)} #设置优化器和损失函数 opt_adam = torch.optim.Adam(adam_net.parameters(), lr=LR) loss_func = nn.MSELoss() #对数据进行分批训练 # 在神经网络中传递完整的数据集一次是不够的, # 而且我们需要将完整的数据集在同样的神经网络中传递多次。 # 但是请记住,我们使用的是有限的数据集, # 并且我们使用一个迭代过程即梯度下降。因此仅仅更新权重一次或者说使用一个 epoch 是不够的。 # 比如对于一个有 2000 个训练样本的数据集。将 2000 个样本分成大小为 500 的 batch,那么完成一个 epoch 需要 4 个 iteration。 all_loss = {} for epoch in range(EPOCH): print('epoch', epoch) for step, (b_x, b_y) in enumerate(loader): print('step', step) pre = adam_net(b_x) loss = loss_func(pre, b_y) opt_adam.zero_grad() loss.backward() opt_adam.step() # print(loss) all_loss[epoch+1] = loss print(all_loss) #对测试集进行预测 adam_net.eval() predict = adam_net(X_test) predict = predict.data.numpy() plt.scatter(X_test.numpy(), y_test, label='origin') plt.scatter(X_test.numpy(), predict, color='red', label='predict') plt.legend() plt.show() ![watermark_type_ZmFuZ3poZW5naGVpdGk_shadow_10_text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzQwNTg3NTc1_size_16_color_FFFFFF_t_70][] **没有使用参数初始化的结果** ![watermark_type_ZmFuZ3poZW5naGVpdGk_shadow_10_text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzQwNTg3NTc1_size_16_color_FFFFFF_t_70 1][] **使用参数初始化的结果** ### 项目推荐: ### **[2000多G的计算机各行业电子资源分享(持续更新)][2000_G]** **[2020年微信小程序全栈项目之喵喵交友【附课件和源码】][2020]** **[Spring Boot开发小而美的个人博客【附课件和源码】][Spring Boot]** **[Java微服务实战296集大型视频-谷粒商城【附代码和课件】][Java_296_-]** **[Java开发微服务畅购商城实战【全357集大项目】-附代码和课件][Java_357_-]** **[最全最详细数据结构与算法视频-【附课件和源码】][-]** ![在这里插入图片描述][watermark_type_ZmFuZ3poZW5naGVpdGk_shadow_10_text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzQwNTg3NTc1_size_16_color_FFFFFF_t_70_pic_center] [watermark_type_ZmFuZ3poZW5naGVpdGk_shadow_10_text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzQwNTg3NTc1_size_16_color_FFFFFF_t_70]: /images/20220124/87bb9f78ad35405ca471654daac526f6.png [watermark_type_ZmFuZ3poZW5naGVpdGk_shadow_10_text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzQwNTg3NTc1_size_16_color_FFFFFF_t_70 1]: /images/20220124/40fcd58a5aee469ca7d6bc8c029724c9.png [2000_G]: https://mp.weixin.qq.com/s/sP4JgGWkCzpgwKr9sAV2_Q [2020]: https://mp.weixin.qq.com/s?__biz=MzIyNTI3NDQ4NQ==&mid=2247487704&idx=1&sn=5f4b2127c4d49fd07ae072a0721424a2&chksm=e8036fc2df74e6d489c4aa9b06f917ef7cee6027f13e150fca53cf79d5d188d4ccc1af49e098&scene=21#wechat_redirect [Spring Boot]: https://mp.weixin.qq.com/s?__biz=MzIyNTI3NDQ4NQ==&mid=2247487798&idx=2&sn=ac0293b996521b872a9dba5fbb3e65e6&chksm=e8036e2cdf74e73aba104a9a994a5b2e31483e8dcbe0f1d9936f6d5173b887e1560f59d2819c&scene=21#wechat_redirect [Java_296_-]: https://mp.weixin.qq.com/s?__biz=MzIyNTI3NDQ4NQ==&mid=2247487674&idx=1&sn=7aff0bdf2bb727303f3d3618995aef21&chksm=e8036fa0df74e6b6d872c7e6ece179c524ed463a4a6b74c96875475c9a3d5ddb903427dd993b&scene=21#wechat_redirect [Java_357_-]: https://mp.weixin.qq.com/s?__biz=MzIyNTI3NDQ4NQ==&mid=2247486376&idx=1&sn=d1fef270c463ea8ac663f6fbfedd70a0&chksm=e80374b2df74fda4d3bafba878a106a19e18c5fcda266008f4f37975847a21bc612ffcd5ff39&scene=21#wechat_redirect [-]: https://mp.weixin.qq.com/s?__biz=MzIyNTI3NDQ4NQ==&mid=2247487750&idx=1&sn=747bccbb5f5ea6b58915198de40da777&chksm=e8036e1cdf74e70ae97a5e8e265b49d7236d904d291203309159d07ba1724033062c0e370843&scene=21#wechat_redirect [watermark_type_ZmFuZ3poZW5naGVpdGk_shadow_10_text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzQwNTg3NTc1_size_16_color_FFFFFF_t_70_pic_center]: /images/20220124/cfa4dd52ec014080acb81a1b8b85374a.png
还没有评论,来说两句吧...