要训练一个模型,主要分成几个部分,如下。
数据预处理入门的话肯定是拿 MNIST 手写数据集先练习。pytorch 中有帮助我们制作数据生成器的模块,其中有 Dataset、TensorDataset、DataLoader 等类可以来创建数据入口。之前在 tensorflow 中可以用 dataset.from_generator() 的形式,pytorch 中也类似,目前我了解到的有两种方法可以实现。
第一种就继承 pytorch 定义的 dataset,改写其中的方法即可。如下,就获得了一个 DataLoader 生成器。
class MyDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __getitem__(self, index):
return self.data[index], self.labels[index]
def __len__(self):
return len(self.labels)
train_dataset = MyDataset(train_data, train_label)
train_loader = DataLoader(dataset = train_dataset,
batch_size = 1,
shuffle = True)
第二种就是转换,先把我们准备好的数据转化成 pytorch 的变量(或者是 Tensor),然后传入 TensorDataset,再构造 DataLoader。
X = torch.from_numpy(train_data).float()
Y = torch.from_numpy(train_label).float()
train_dataset = TensorDataset(X, Y)
train_loader = DataLoader(dataset = train_dataset,
batch_size = 1,
shuffle = True)
模型定义class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 3)
self.conv2 = nn.Co