pytorch模型训练流程中遇到的一些坑(持续更新)

发布时间:2024-12-17 23:30

训练过程中的进步,让人持续兴奋 #生活乐趣# #运动乐趣# #运动训练#

最新推荐文章于 2024-10-20 21:59:13 发布

渡边君 于 2019-12-29 23:34:25 发布

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。

       要训练一个模型,主要分成几个部分,如下。

数据预处理

      入门的话肯定是拿 MNIST 手写数据集先练习。pytorch 中有帮助我们制作数据生成器的模块,其中有 Dataset、TensorDataset、DataLoader 等类可以来创建数据入口。之前在 tensorflow 中可以用 dataset.from_generator() 的形式,pytorch 中也类似,目前我了解到的有两种方法可以实现。

       第一种就继承 pytorch 定义的 dataset,改写其中的方法即可。如下,就获得了一个 DataLoader 生成器。 

class MyDataset(Dataset):

def __init__(self, data, labels):

self.data = data

self.labels = labels

def __getitem__(self, index):

return self.data[index], self.labels[index]

def __len__(self):

return len(self.labels)

train_dataset = MyDataset(train_data, train_label)

train_loader = DataLoader(dataset = train_dataset,

batch_size = 1,

shuffle = True)

        第二种就是转换,先把我们准备好的数据转化成 pytorch 的变量(或者是 Tensor),然后传入 TensorDataset,再构造 DataLoader。

X = torch.from_numpy(train_data).float()

Y = torch.from_numpy(train_label).float()

train_dataset = TensorDataset(X, Y)

train_loader = DataLoader(dataset = train_dataset,

batch_size = 1,

shuffle = True)

模型定义

class Net(nn.Module):

def __init__(self):

super(Net, self).__init__()

self.conv1 = nn.Conv2d(1, 6, 3)

self.conv2 = nn.Co

网址:pytorch模型训练流程中遇到的一些坑(持续更新) https://www.yuejiaxmz.com/news/view/504192

相关内容

那些年,Pytorh的坑(持续更新)
节省显存新思路,在 PyTorch 里使用 2 bit 激活压缩训练神经网络
pytorch 1.1.0升级
深入了解PyTorch中的语音识别和语音生成
持续改进指南:模型、流程和计划 – PingCode
pytorch中的model=model.to(device)使用说明
PyTorch 节省显存的策略总结
PyTorch 深度学习框架简介:灵活、高效的 AI 开发工具
深入理解PyTorch的语音识别与语音合成1.背景介绍 语音识别和语音合成是人工智能领域中的两个重要技术,它们在现实生活
把显存用在刀刃上!17 种 pytorch 节约显存技巧

随便看看