add imagenet
This commit is contained in:
parent
6fae67b556
commit
89659b8292
5
train.py
5
train.py
@ -18,6 +18,7 @@ from utils import image_normalization
|
||||
from fractions import Fraction
|
||||
from dataset import Vanilla
|
||||
|
||||
|
||||
def config_parser():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
@ -71,11 +72,11 @@ def train(args: config_parser(), ratio: float, snr: float):
|
||||
elif args.dataset == 'imagenet':
|
||||
transform = transforms.Compose(
|
||||
[transforms.ToTensor(), transforms.Resize((128, 128))]) # the size of paper is 128
|
||||
train_dataset = datasets.ImageFolder(root='./Dataset/ImageNet/train',transform=transform)
|
||||
train_dataset = datasets.ImageFolder(root='./Dataset/ImageNet/train', transform=transform)
|
||||
|
||||
train_loader = DataLoader(train_dataset, shuffle=True,
|
||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||
test_dataset = datasets.ImageFolder(root='./Dataset/ImageNet/val', transform=transform)
|
||||
test_dataset = Vanilla(root='./Dataset/ImageNet/val', transform=transform)
|
||||
test_loader = DataLoader(test_dataset, shuffle=True,
|
||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||
else:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user