add imagenet

This commit is contained in:
chun 2023-12-24 17:01:31 +08:00
parent 6fae67b556
commit 89659b8292

View File

@ -18,6 +18,7 @@ from utils import image_normalization
from fractions import Fraction
from dataset import Vanilla
def config_parser():
import argparse
parser = argparse.ArgumentParser()
@ -71,11 +72,11 @@ def train(args: config_parser(), ratio: float, snr: float):
elif args.dataset == 'imagenet':
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Resize((128, 128))]) # the size of paper is 128
train_dataset = datasets.ImageFolder(root='./Dataset/ImageNet/train',transform=transform)
train_dataset = datasets.ImageFolder(root='./Dataset/ImageNet/train', transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True,
batch_size=args.batch_size, num_workers=args.num_workers)
test_dataset = datasets.ImageFolder(root='./Dataset/ImageNet/val', transform=transform)
test_dataset = Vanilla(root='./Dataset/ImageNet/val', transform=transform)
test_loader = DataLoader(test_dataset, shuffle=True,
batch_size=args.batch_size, num_workers=args.num_workers)
else: