1.4模型实现

import torch
import torchvision # 常用数据集包
import torchvision.transforms as transforms # 数据转归一化
transform = transforms.Compose( # 自定义一个转换器,先变成张量,再归一化(每个通道的均值序列,标准差序列)
    [transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # (0-0.5)/0.5 = -1

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, 
                                         transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True,
                                         num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
                                       transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False,
                                        num_workers=2)
# num_workers根据计算机的CPU和内存来设置,充足可以设置多一些
# 设为0表示不用内存

classes = ('plane','car', 'bird','cat','deer', 'dog','frog','horse','ship','truck')
Files already downloaded and verified
Files already downloaded and verified
png
png

最后更新于

这有帮助吗?