get_num_workers.py added

This commit is contained in:
chun 2023-12-22 00:36:18 +08:00
parent b62dc1b83f
commit 9d6480311c
3 changed files with 37 additions and 6 deletions

31
best_number_workers.py Normal file
View File

@ -0,0 +1,31 @@
from time import time
import multiprocessing as mp
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, RandomSampler
if __name__ == '__main__':
transform = transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
trainset = torchvision.datasets.MNIST(
root='./Dataset/MNIST/',
train=True, # 如果为True从 training.pt 创建数据,否则从 test.pt 创建数据。
download=True, # 如果为true则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
transform=transform
)
print(f"num of CPU: {mp.cpu_count()}")
for num_workers in range(2, mp.cpu_count(), 2):
train_loader = torch.utils.data.DataLoader(
trainset, shuffle=True, num_workers=num_workers, batch_size=64, pin_memory=True)
start = time()
for epoch in range(1, 3):
for i, data in enumerate(train_loader, 0):
pass
end = time()
print("Finish with:{} second, num_workers={}".format(end - start, num_workers))

View File

@ -1 +1 @@
# to be implemented
# to be implemented

View File

@ -23,12 +23,12 @@ def config_parser():
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('--epochs', default=100, type=int, help='number of epochs')
parser.add_argument('--batch_size', default=64, type=int, help='batch size')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list')
parser.add_argument('--ratio_list', default=[1/6, 1/12], type=list, help='ratio_list')
parser.add_argument('--num_workers', default=0, type=int, help='num_workers')
return parser.parse_args()
@ -44,20 +44,20 @@ def main():
def train(args: config_parser(), ratio: float, snr: float):
device = torch.device('cuda:1')
# load data
transform = transforms.Compose([transforms.ToTensor(), ])
train_dataset = datasets.CIFAR10(root='./Dataset/', train=True,
download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)
train_loader = DataLoader(train_dataset, shuffle=True,
batch_size=args.batch_size, num_workers=args.num_workers)
test_dataset = datasets.CIFAR10(root='./Dataset/', train=False,
download=True, transform=transform)
test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size)
print("training with ratio: {}, snr_db: {}, channel: {}".format(ratio, snr, args.channel))
image_fisrt = train_dataset.__getitem__(0)[0]
c = ratio2filtersize(image_fisrt, ratio)
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)