add imagenet
This commit is contained in:
parent
6fae67b556
commit
89659b8292
3
train.py
3
train.py
@ -18,6 +18,7 @@ from utils import image_normalization
|
|||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from dataset import Vanilla
|
from dataset import Vanilla
|
||||||
|
|
||||||
|
|
||||||
def config_parser():
|
def config_parser():
|
||||||
import argparse
|
import argparse
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -75,7 +76,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
|
|
||||||
train_loader = DataLoader(train_dataset, shuffle=True,
|
train_loader = DataLoader(train_dataset, shuffle=True,
|
||||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||||
test_dataset = datasets.ImageFolder(root='./Dataset/ImageNet/val', transform=transform)
|
test_dataset = Vanilla(root='./Dataset/ImageNet/val', transform=transform)
|
||||||
test_loader = DataLoader(test_dataset, shuffle=True,
|
test_loader = DataLoader(test_dataset, shuffle=True,
|
||||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user