JSCC/best_number_workers.py
2024-01-23 22:18:11 +08:00

31 lines
1.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from time import time
import multiprocessing as mp
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, RandomSampler
if __name__ == '__main__':
    # Benchmark how long it takes to iterate over the CIFAR-10 training set
    # for different DataLoader `num_workers` settings, printing one timing
    # line per setting so the fastest value can be picked by inspection.

    # Local import: perf_counter is a monotonic, high-resolution clock, so the
    # measurement is not disturbed by system clock adjustments.
    from time import perf_counter

    transform = transforms.Compose([
        torchvision.transforms.ToTensor(),
    ])
    trainset = torchvision.datasets.CIFAR10(
        root='./dataset/',
        train=True,  # True selects the training split, False the test split.
        download=True,  # Download into `root` unless it is already present.
        transform=transform,
    )

    cpu_total = mp.cpu_count()
    print(f"num of CPU: {cpu_total}")

    # Try every even worker count from 2 up to AND including the core count.
    # The upper bound is `cpu_total + 1` because range() excludes its stop
    # value; without it a 2-core machine would run no measurement at all.
    for num_workers in range(2, cpu_total + 1, 2):
        train_loader = DataLoader(
            trainset,
            shuffle=True,
            num_workers=num_workers,
            batch_size=64,
            pin_memory=True,
        )
        start = perf_counter()
        # Two full passes over the data; the batches themselves are discarded
        # because only the loading time is of interest.
        for _ in range(2):
            for _ in train_loader:
                pass
        elapsed = perf_counter() - start
        print("Finish with:{} second, num_workers={}".format(elapsed, num_workers))