add imagenet
This commit is contained in:
parent
579d2b0d8e
commit
4f9c0063f7
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,4 +1,5 @@
|
||||
test.py
|
||||
*.pyc
|
||||
*.log
|
||||
Dataset/*
|
||||
Dataset/*
|
||||
*.ipynb
|
||||
2
eval.py
2
eval.py
@ -22,7 +22,7 @@ def config_parser():
|
||||
def main():
|
||||
args = config_parser()
|
||||
transform = transforms.Compose([transforms.ToTensor()])
|
||||
args.saved = './saved/model_cifar10_0.33_19.00_40.pth' # to be deleted
|
||||
# args.saved = './saved/model_cifar10_0.33_19.00_40.pth' # to be deleted
|
||||
test_image = Image.open(args.test_image)
|
||||
test_image.load()
|
||||
test_image = transform(test_image)
|
||||
|
||||
26
train.py
26
train.py
@ -39,6 +39,7 @@ def config_parser():
|
||||
parser.add_argument('--parallel', default=False, type=bool, help='parallel')
|
||||
parser.add_argument('--if_scheduler', default=True, type=bool, help='if_scheduler')
|
||||
parser.add_argument('--step_size', default=640, type=int, help='scheduler')
|
||||
parser.add_argument('--device', default='cuda:0', type=str, help='device')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -54,7 +55,7 @@ def main():
|
||||
|
||||
def train(args: config_parser(), ratio: float, snr: float):
|
||||
|
||||
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
|
||||
# load data
|
||||
if args.dataset == 'cifar10':
|
||||
transform = transforms.Compose([transforms.ToTensor(), ])
|
||||
@ -68,14 +69,13 @@ def train(args: config_parser(), ratio: float, snr: float):
|
||||
test_loader = DataLoader(test_dataset, shuffle=True,
|
||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||
elif args.dataset == 'imagenet':
|
||||
transform = transforms.Compose([transforms.ToTensor(), ])
|
||||
train_dataset = datasets.ImageNet(root='./Dataset/', train=True,
|
||||
download=True, transform=transform)
|
||||
transform = transforms.Compose(
|
||||
[transforms.ToTensor(), transforms.Resize((128, 128))]) # the size of paper is 128
|
||||
train_dataset = datasets.ImageFolder(root='./Dataset/ImageNet/train',transform=transform)
|
||||
|
||||
train_loader = DataLoader(train_dataset, shuffle=True,
|
||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||
test_dataset = datasets.ImageNet(root='./Dataset/', train=False,
|
||||
download=True, transform=transform)
|
||||
test_dataset = datasets.ImageNet(root='./Dataset/ImageNet/val', transform=transform)
|
||||
test_loader = DataLoader(test_dataset, shuffle=True,
|
||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||
else:
|
||||
@ -90,9 +90,11 @@ def train(args: config_parser(), ratio: float, snr: float):
|
||||
|
||||
if args.parallel and torch.cuda.device_count() > 1:
|
||||
model = DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
|
||||
model = model.cuda()
|
||||
|
||||
criterion = nn.MSELoss(reduction='mean').cuda()
|
||||
model = model.cuda()
|
||||
criterion = nn.MSELoss(reduction='mean').cuda()
|
||||
else:
|
||||
model = model.to(device)
|
||||
criterion = nn.MSELoss(reduction='mean').to(device)
|
||||
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||
if args.if_scheduler:
|
||||
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1)
|
||||
@ -102,7 +104,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
||||
run_loss = 0.0
|
||||
for images, _ in tqdm((train_loader), leave=False):
|
||||
optimizer.zero_grad()
|
||||
images = images.cuda()
|
||||
images = images.cuda() if args.parallel else images.to(device)
|
||||
outputs = model(images)
|
||||
outputs = image_normalization('denormalization')(outputs)
|
||||
images = image_normalization('denormalization')(images)
|
||||
@ -116,7 +118,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
||||
model.eval()
|
||||
test_mse = 0.0
|
||||
for images, _ in tqdm((test_loader), leave=False):
|
||||
images = images.cuda()
|
||||
images = images if args.parallel else images.to(device)
|
||||
outputs = model(images)
|
||||
images = image_normalization('denormalization')(images)
|
||||
outputs = image_normalization('denormalization')(outputs)
|
||||
@ -125,7 +127,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
||||
model.train()
|
||||
epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
|
||||
save_model(model, args.saved, args.saved +
|
||||
'/model_{}_{}_{:.2f}_{:.2f}_{}.pth'.format(args.dataset, args.epochs,ratio, snr, c))
|
||||
'/{}_{}_{:.2f}_{:.2f}_{}.pth'.format(args.dataset, args.epochs, ratio, snr, c))
|
||||
|
||||
|
||||
def save_model(model, dir, path):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user