add imagenet

This commit is contained in:
chun 2023-12-24 16:11:30 +08:00
parent 579d2b0d8e
commit 4f9c0063f7
3 changed files with 17 additions and 14 deletions

3
.gitignore vendored
View File

@ -1,4 +1,5 @@
test.py
*.pyc
*.log
Dataset/*
Dataset/*
*.ipynb

View File

@ -22,7 +22,7 @@ def config_parser():
def main():
args = config_parser()
transform = transforms.Compose([transforms.ToTensor()])
args.saved = './saved/model_cifar10_0.33_19.00_40.pth' # to be deleted
# args.saved = './saved/model_cifar10_0.33_19.00_40.pth' # to be deleted
test_image = Image.open(args.test_image)
test_image.load()
test_image = transform(test_image)

View File

@ -39,6 +39,7 @@ def config_parser():
parser.add_argument('--parallel', default=False, type=bool, help='parallel')
parser.add_argument('--if_scheduler', default=True, type=bool, help='if_scheduler')
parser.add_argument('--step_size', default=640, type=int, help='scheduler')
parser.add_argument('--device', default='cuda:0', type=str, help='device')
return parser.parse_args()
@ -54,7 +55,7 @@ def main():
def train(args: config_parser(), ratio: float, snr: float):
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
# load data
if args.dataset == 'cifar10':
transform = transforms.Compose([transforms.ToTensor(), ])
@ -68,14 +69,13 @@ def train(args: config_parser(), ratio: float, snr: float):
test_loader = DataLoader(test_dataset, shuffle=True,
batch_size=args.batch_size, num_workers=args.num_workers)
elif args.dataset == 'imagenet':
transform = transforms.Compose([transforms.ToTensor(), ])
train_dataset = datasets.ImageNet(root='./Dataset/', train=True,
download=True, transform=transform)
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Resize((128, 128))]) # the size of paper is 128
train_dataset = datasets.ImageFolder(root='./Dataset/ImageNet/train',transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True,
batch_size=args.batch_size, num_workers=args.num_workers)
test_dataset = datasets.ImageNet(root='./Dataset/', train=False,
download=True, transform=transform)
test_dataset = datasets.ImageNet(root='./Dataset/ImageNet/val', transform=transform)
test_loader = DataLoader(test_dataset, shuffle=True,
batch_size=args.batch_size, num_workers=args.num_workers)
else:
@ -90,9 +90,11 @@ def train(args: config_parser(), ratio: float, snr: float):
if args.parallel and torch.cuda.device_count() > 1:
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
model = model.cuda()
criterion = nn.MSELoss(reduction='mean').cuda()
model = model.cuda()
criterion = nn.MSELoss(reduction='mean').cuda()
else:
model = model.to(device)
criterion = nn.MSELoss(reduction='mean').to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
if args.if_scheduler:
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1)
@ -102,7 +104,7 @@ def train(args: config_parser(), ratio: float, snr: float):
run_loss = 0.0
for images, _ in tqdm((train_loader), leave=False):
optimizer.zero_grad()
images = images.cuda()
images = images.cuda() if args.parallel else images.to(device)
outputs = model(images)
outputs = image_normalization('denormalization')(outputs)
images = image_normalization('denormalization')(images)
@ -116,7 +118,7 @@ def train(args: config_parser(), ratio: float, snr: float):
model.eval()
test_mse = 0.0
for images, _ in tqdm((test_loader), leave=False):
images = images.cuda()
images = images if args.parallel else images.to(device)
outputs = model(images)
images = image_normalization('denormalization')(images)
outputs = image_normalization('denormalization')(outputs)
@ -125,7 +127,7 @@ def train(args: config_parser(), ratio: float, snr: float):
model.train()
epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
save_model(model, args.saved, args.saved +
'/model_{}_{}_{:.2f}_{:.2f}_{}.pth'.format(args.dataset, args.epochs,ratio, snr, c))
'/{}_{}_{:.2f}_{:.2f}_{}.pth'.format(args.dataset, args.epochs, ratio, snr, c))
def save_model(model, dir, path):