add imagenet
This commit is contained in:
parent
6fae67b556
commit
89659b8292
5
train.py
5
train.py
@ -18,6 +18,7 @@ from utils import image_normalization
|
|||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from dataset import Vanilla
|
from dataset import Vanilla
|
||||||
|
|
||||||
|
|
||||||
def config_parser():
|
def config_parser():
|
||||||
import argparse
|
import argparse
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -71,11 +72,11 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
elif args.dataset == 'imagenet':
|
elif args.dataset == 'imagenet':
|
||||||
transform = transforms.Compose(
|
transform = transforms.Compose(
|
||||||
[transforms.ToTensor(), transforms.Resize((128, 128))]) # the size of paper is 128
|
[transforms.ToTensor(), transforms.Resize((128, 128))]) # the size of paper is 128
|
||||||
train_dataset = datasets.ImageFolder(root='./Dataset/ImageNet/train',transform=transform)
|
train_dataset = datasets.ImageFolder(root='./Dataset/ImageNet/train', transform=transform)
|
||||||
|
|
||||||
train_loader = DataLoader(train_dataset, shuffle=True,
|
train_loader = DataLoader(train_dataset, shuffle=True,
|
||||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||||
test_dataset = datasets.ImageFolder(root='./Dataset/ImageNet/val', transform=transform)
|
test_dataset = Vanilla(root='./Dataset/ImageNet/val', transform=transform)
|
||||||
test_loader = DataLoader(test_dataset, shuffle=True,
|
test_loader = DataLoader(test_dataset, shuffle=True,
|
||||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user