diff --git a/best_number_workers.py b/best_number_workers.py index efa35f3..b1e9038 100644 --- a/best_number_workers.py +++ b/best_number_workers.py @@ -9,11 +9,10 @@ from torch.utils.data import DataLoader, RandomSampler if __name__ == '__main__': transform = transforms.Compose([ torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize((0.1307,), (0.3081,)) ]) - trainset = torchvision.datasets.MNIST( - root='./Dataset/MNIST/', + trainset = torchvision.datasets.CIFAR10( + root='./Dataset/', train=True, # 如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。 download=True, # 如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。 transform=transform