softlink dataset
This commit is contained in:
parent
52b7efb407
commit
14cd4ed4e4
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,9 +1,10 @@
|
|||||||
test.py
|
test.py
|
||||||
*.pyc
|
*.pyc
|
||||||
*.log
|
*.log
|
||||||
Dataset/*
|
dataset
|
||||||
*.ipynb
|
*.ipynb
|
||||||
*.swp
|
*.swp
|
||||||
.vscode/*
|
.vscode/*
|
||||||
input.txt
|
input.txt
|
||||||
output.txt
|
output.txt
|
||||||
|
*.json
|
||||||
@ -12,7 +12,7 @@ if __name__ == '__main__':
|
|||||||
])
|
])
|
||||||
|
|
||||||
trainset = torchvision.datasets.CIFAR10(
|
trainset = torchvision.datasets.CIFAR10(
|
||||||
root='./Dataset/',
|
root='./dataset/',
|
||||||
train=True, # 如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。
|
train=True, # 如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。
|
||||||
download=True, # 如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
|
download=True, # 如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
|
||||||
transform=transform
|
transform=transform
|
||||||
|
|||||||
12
dataset.py
12
dataset.py
@ -21,21 +21,21 @@ class Vanilla(Dataset):
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
data_path = './Dataset'
|
data_path = './dataset'
|
||||||
os.makedirs(data_path, exist_ok=True)
|
os.makedirs(data_path, exist_ok=True)
|
||||||
# ILSVRC2012_img_train.tar and ILSVRC2012_img_val.tar should be downloaded from https://image-net.org/
|
# ILSVRC2012_img_train.tar and ILSVRC2012_img_val.tar should be downloaded from https://image-net.org/
|
||||||
if not os.path.exists('./Dataset/ILSVRC2012_img_train.tar') or not os.path.exists('./Dataset/ILSVRC2012_img_val.tar'):
|
if not os.path.exists('./dataset/ILSVRC2012_img_train.tar') or not os.path.exists('./dataset/ILSVRC2012_img_val.tar'):
|
||||||
print('ILSVRC2012_img_train.tar and ILSVRC2012_img_val.tar should be downloaded from https://image-net.org/')
|
print('ILSVRC2012_img_train.tar and ILSVRC2012_img_val.tar should be downloaded from https://image-net.org/')
|
||||||
print('Please download the dataset from https://image-net.org/challenges/LSVRC/2012/2012-downloads and put it in ./Dataset')
|
print('Please download the dataset from https://image-net.org/challenges/LSVRC/2012/2012-downloads and put it in ./dataset')
|
||||||
raise Exception('not find dataset')
|
raise Exception('not find dataset')
|
||||||
phases = ['train', 'val']
|
phases = ['train', 'val']
|
||||||
for phase in phases:
|
for phase in phases:
|
||||||
print("extracting {} dataset".format(phase))
|
print("extracting {} dataset".format(phase))
|
||||||
path = './Dataset/ImageNet/{}'.format(phase)
|
path = './dataset/ImageNet/{}'.format(phase)
|
||||||
print('path is {}'.format(path))
|
print('path is {}'.format(path))
|
||||||
os.makedirs(path, exist_ok=True)
|
os.makedirs(path, exist_ok=True)
|
||||||
print('tar -xf ./Dataset/ILSVRC2012_img_{}.tar -C {}'.format(phase, path))
|
print('tar -xf ./dataset/ILSVRC2012_img_{}.tar -C {}'.format(phase, path))
|
||||||
os.system('tar -xf ./Dataset/ILSVRC2012_img_{}.tar -C {}'.format(phase, path))
|
os.system('tar -xf ./dataset/ILSVRC2012_img_{}.tar -C {}'.format(phase, path))
|
||||||
if phase == 'train':
|
if phase == 'train':
|
||||||
for tar in os.listdir(path):
|
for tar in os.listdir(path):
|
||||||
print('tar -xf {}/{} -C {}/{}'.format(path, tar, path, tar.split('.')[0]))
|
print('tar -xf {}/{} -C {}/{}'.format(path, tar, path, tar.split('.')[0]))
|
||||||
|
|||||||
8
train.py
8
train.py
@ -72,12 +72,12 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
# load data
|
# load data
|
||||||
if args.dataset == 'cifar10':
|
if args.dataset == 'cifar10':
|
||||||
transform = transforms.Compose([transforms.ToTensor(), ])
|
transform = transforms.Compose([transforms.ToTensor(), ])
|
||||||
train_dataset = datasets.CIFAR10(root='./Dataset/', train=True,
|
train_dataset = datasets.CIFAR10(root='./dataset/', train=True,
|
||||||
download=True, transform=transform)
|
download=True, transform=transform)
|
||||||
|
|
||||||
train_loader = DataLoader(train_dataset, shuffle=True,
|
train_loader = DataLoader(train_dataset, shuffle=True,
|
||||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||||
test_dataset = datasets.CIFAR10(root='./Dataset/', train=False,
|
test_dataset = datasets.CIFAR10(root='./dataset/', train=False,
|
||||||
download=True, transform=transform)
|
download=True, transform=transform)
|
||||||
test_loader = DataLoader(test_dataset, shuffle=True,
|
test_loader = DataLoader(test_dataset, shuffle=True,
|
||||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||||
@ -85,11 +85,11 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
transform = transforms.Compose(
|
transform = transforms.Compose(
|
||||||
[transforms.ToTensor(), transforms.Resize((128, 128))]) # the size of paper is 128
|
[transforms.ToTensor(), transforms.Resize((128, 128))]) # the size of paper is 128
|
||||||
print("loading data of imagenet")
|
print("loading data of imagenet")
|
||||||
train_dataset = datasets.ImageFolder(root='./Dataset/ImageNet/train', transform=transform)
|
train_dataset = datasets.ImageFolder(root='./dataset/ImageNet/train', transform=transform)
|
||||||
|
|
||||||
train_loader = DataLoader(train_dataset, shuffle=True,
|
train_loader = DataLoader(train_dataset, shuffle=True,
|
||||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||||
test_dataset = Vanilla(root='./Dataset/ImageNet/val', transform=transform)
|
test_dataset = Vanilla(root='./dataset/ImageNet/val', transform=transform)
|
||||||
test_loader = DataLoader(test_dataset, shuffle=True,
|
test_loader = DataLoader(test_dataset, shuffle=True,
|
||||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user