From 9d6480311c605fb5f35485195b850836eec0f4fa Mon Sep 17 00:00:00 2001
From: chun
Date: Fri, 22 Dec 2023 00:36:18 +0800
Subject: [PATCH] best_number_workers.py added

---
 best_number_workers.py | 31 +++++++++++++++++++++++++++++++
 eval.py                |  2 +-
 train.py               | 10 +++++-----
 3 files changed, 37 insertions(+), 6 deletions(-)
 create mode 100644 best_number_workers.py

diff --git a/best_number_workers.py b/best_number_workers.py
new file mode 100644
index 0000000..efa35f3
--- /dev/null
+++ b/best_number_workers.py
@@ -0,0 +1,31 @@
+from time import time
+import multiprocessing as mp
+import torch
+import torchvision
+from torchvision import transforms
+from torch.utils.data import DataLoader, RandomSampler
+
+
+if __name__ == '__main__':
+    transform = transforms.Compose([
+        torchvision.transforms.ToTensor(),
+        torchvision.transforms.Normalize((0.1307,), (0.3081,))
+    ])
+
+    trainset = torchvision.datasets.MNIST(
+        root='./Dataset/MNIST/',
+        train=True,  # If True, build the dataset from training.pt; otherwise from test.pt.
+        download=True,  # If True, download the dataset from the Internet into root; skipped if already downloaded.
+        transform=transform
+    )
+
+    print(f"num of CPU: {mp.cpu_count()}")
+    for num_workers in range(2, mp.cpu_count(), 2):
+        train_loader = torch.utils.data.DataLoader(
+            trainset, shuffle=True, num_workers=num_workers, batch_size=64, pin_memory=True)
+        start = time()
+        for epoch in range(1, 3):
+            for i, data in enumerate(train_loader, 0):
+                pass
+        end = time()
+        print("Finish with:{} second, num_workers={}".format(end - start, num_workers))
diff --git a/eval.py b/eval.py
index f60246f..2624b4c 100644
--- a/eval.py
+++ b/eval.py
@@ -1 +1 @@
-# to be implemented
\ No newline at end of file
+# to be implemented
diff --git a/train.py b/train.py
index a8862fc..4ff4100 100644
--- a/train.py
+++ b/train.py
@@ -23,12 +23,12 @@ def config_parser():
     parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
     parser.add_argument('--epochs', default=100, type=int, help='number of epochs')
     parser.add_argument('--batch_size', default=64, type=int, help='batch size')
-    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
     parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
     parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
     parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
     parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list')
     parser.add_argument('--ratio_list', default=[1/6, 1/12], type=list, help='ratio_list')
+    parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
     return parser.parse_args()
 
 
@@ -44,16 +44,17 @@ def main():
 
 
 def train(args: config_parser(), ratio: float, snr: float):
     device = torch.device('cuda:1')
     # load data
     transform = transforms.Compose([transforms.ToTensor(), ])
     train_dataset = datasets.CIFAR10(root='./Dataset/', train=True, download=True, transform=transform)
-    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)
+    train_loader = DataLoader(train_dataset, shuffle=True,
+                              batch_size=args.batch_size, num_workers=args.num_workers)
     test_dataset = datasets.CIFAR10(root='./Dataset/', train=False, download=True, transform=transform)
     test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size)
-    
+
     print("training with ratio: {}, snr_db: {}, channel: {}".format(ratio, snr, args.channel))
-    
+
     image_fisrt = train_dataset.__getitem__(0)[0]
     c = ratio2filtersize(image_fisrt, ratio)
     model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)