add imagenet

This commit is contained in:
chun 2023-12-24 17:01:31 +08:00
parent 6fae67b556
commit 89659b8292

View File

@ -18,6 +18,7 @@ from utils import image_normalization
from fractions import Fraction
from dataset import Vanilla
def config_parser():
import argparse
parser = argparse.ArgumentParser()
@ -75,7 +76,7 @@ def train(args: config_parser(), ratio: float, snr: float):
train_loader = DataLoader(train_dataset, shuffle=True,
batch_size=args.batch_size, num_workers=args.num_workers)
test_dataset = datasets.ImageFolder(root='./Dataset/ImageNet/val', transform=transform)
test_dataset = Vanilla(root='./Dataset/ImageNet/val', transform=transform)
test_loader = DataLoader(test_dataset, shuffle=True,
batch_size=args.batch_size, num_workers=args.num_workers)
else: