31 lines
1.1 KiB
Python
31 lines
1.1 KiB
Python
from time import time
|
||
import multiprocessing as mp
|
||
import torch
|
||
import torchvision
|
||
from torchvision import transforms
|
||
from torch.utils.data import DataLoader, RandomSampler
|
||
|
||
|
||
if __name__ == '__main__':
    # Benchmark script: time two full passes over the CIFAR-10 training set
    # for several DataLoader worker counts, to help pick a `num_workers` value.

    # Imported locally so this block is self-contained. perf_counter() is a
    # monotonic, high-resolution clock; wall-clock time() can jump if the
    # system clock is adjusted while a measurement is running.
    from time import perf_counter

    transform = transforms.Compose([
        torchvision.transforms.ToTensor(),
    ])

    trainset = torchvision.datasets.CIFAR10(
        root='./dataset/',
        train=True,  # If True, build the dataset from the training split, otherwise from the test split.
        download=True,  # If True, download the dataset into `root`; an already-downloaded dataset is not fetched again.
        transform=transform,
    )

    cpu_total = mp.cpu_count()
    print(f"num of CPU: {cpu_total}")

    # Even worker counts from 2 up to AND including the CPU count. The upper
    # bound is clamped to at least 2 so that machines with one or two CPUs
    # still produce a measurement instead of silently skipping the loop.
    for num_workers in range(2, max(cpu_total, 2) + 1, 2):
        train_loader = DataLoader(
            trainset,
            shuffle=True,
            num_workers=num_workers,
            batch_size=64,
            pin_memory=True,
        )

        start = perf_counter()
        for _epoch in range(1, 3):  # two passes over the whole dataset
            for _batch in train_loader:
                pass  # batches are discarded; only loading time is measured
        end = perf_counter()

        print("Finish with:{} second, num_workers={}".format(end - start, num_workers))