get_num_workers.py added
This commit is contained in:
parent
b62dc1b83f
commit
9d6480311c
31
best_number_workers.py
Normal file
31
best_number_workers.py
Normal file
@ -0,0 +1,31 @@
|
||||
from time import time
|
||||
import multiprocessing as mp
|
||||
import torch
|
||||
import torchvision
|
||||
from torchvision import transforms
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
transform = transforms.Compose([
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize((0.1307,), (0.3081,))
|
||||
])
|
||||
|
||||
trainset = torchvision.datasets.MNIST(
|
||||
root='./Dataset/MNIST/',
|
||||
train=True, # 如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。
|
||||
download=True, # 如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
|
||||
transform=transform
|
||||
)
|
||||
|
||||
print(f"num of CPU: {mp.cpu_count()}")
|
||||
for num_workers in range(2, mp.cpu_count(), 2):
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
trainset, shuffle=True, num_workers=num_workers, batch_size=64, pin_memory=True)
|
||||
start = time()
|
||||
for epoch in range(1, 3):
|
||||
for i, data in enumerate(train_loader, 0):
|
||||
pass
|
||||
end = time()
|
||||
print("Finish with:{} second, num_workers={}".format(end - start, num_workers))
|
||||
6
train.py
6
train.py
@ -23,12 +23,12 @@ def config_parser():
|
||||
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
|
||||
parser.add_argument('--epochs', default=100, type=int, help='number of epochs')
|
||||
parser.add_argument('--batch_size', default=64, type=int, help='batch size')
|
||||
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
|
||||
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
|
||||
parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
|
||||
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
|
||||
parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list')
|
||||
parser.add_argument('--ratio_list', default=[1/6, 1/12], type=list, help='ratio_list')
|
||||
parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -44,14 +44,14 @@ def main():
|
||||
|
||||
def train(args: config_parser(), ratio: float, snr: float):
|
||||
|
||||
|
||||
device = torch.device('cuda:1')
|
||||
# load data
|
||||
transform = transforms.Compose([transforms.ToTensor(), ])
|
||||
train_dataset = datasets.CIFAR10(root='./Dataset/', train=True,
|
||||
download=True, transform=transform)
|
||||
|
||||
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)
|
||||
train_loader = DataLoader(train_dataset, shuffle=True,
|
||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||
test_dataset = datasets.CIFAR10(root='./Dataset/', train=False,
|
||||
download=True, transform=transform)
|
||||
test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user