softlink dataset
This commit is contained in:
parent
52b7efb407
commit
14cd4ed4e4
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,9 +1,10 @@
|
||||
test.py
|
||||
*.pyc
|
||||
*.log
|
||||
Dataset/*
|
||||
dataset
|
||||
*.ipynb
|
||||
*.swp
|
||||
.vscode/*
|
||||
input.txt
|
||||
output.txt
|
||||
*.json
|
||||
@ -12,7 +12,7 @@ if __name__ == '__main__':
|
||||
])
|
||||
|
||||
trainset = torchvision.datasets.CIFAR10(
|
||||
root='./Dataset/',
|
||||
root='./dataset/',
|
||||
train=True, # 如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。
|
||||
download=True, # 如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
|
||||
transform=transform
|
||||
|
||||
12
dataset.py
12
dataset.py
@ -21,21 +21,21 @@ class Vanilla(Dataset):
|
||||
|
||||
|
||||
def main():
|
||||
data_path = './Dataset'
|
||||
data_path = './dataset'
|
||||
os.makedirs(data_path, exist_ok=True)
|
||||
# ILSVRC2012_img_train.tar and ILSVRC2012_img_val.tar should be downloaded from https://image-net.org/
|
||||
if not os.path.exists('./Dataset/ILSVRC2012_img_train.tar') or not os.path.exists('./Dataset/ILSVRC2012_img_val.tar'):
|
||||
if not os.path.exists('./dataset/ILSVRC2012_img_train.tar') or not os.path.exists('./dataset/ILSVRC2012_img_val.tar'):
|
||||
print('ILSVRC2012_img_train.tar and ILSVRC2012_img_val.tar should be downloaded from https://image-net.org/')
|
||||
print('Please download the dataset from https://image-net.org/challenges/LSVRC/2012/2012-downloads and put it in ./Dataset')
|
||||
print('Please download the dataset from https://image-net.org/challenges/LSVRC/2012/2012-downloads and put it in ./dataset')
|
||||
raise Exception('not find dataset')
|
||||
phases = ['train', 'val']
|
||||
for phase in phases:
|
||||
print("extracting {} dataset".format(phase))
|
||||
path = './Dataset/ImageNet/{}'.format(phase)
|
||||
path = './dataset/ImageNet/{}'.format(phase)
|
||||
print('path is {}'.format(path))
|
||||
os.makedirs(path, exist_ok=True)
|
||||
print('tar -xf ./Dataset/ILSVRC2012_img_{}.tar -C {}'.format(phase, path))
|
||||
os.system('tar -xf ./Dataset/ILSVRC2012_img_{}.tar -C {}'.format(phase, path))
|
||||
print('tar -xf ./dataset/ILSVRC2012_img_{}.tar -C {}'.format(phase, path))
|
||||
os.system('tar -xf ./dataset/ILSVRC2012_img_{}.tar -C {}'.format(phase, path))
|
||||
if phase == 'train':
|
||||
for tar in os.listdir(path):
|
||||
print('tar -xf {}/{} -C {}/{}'.format(path, tar, path, tar.split('.')[0]))
|
||||
|
||||
8
train.py
8
train.py
@ -72,12 +72,12 @@ def train(args: config_parser(), ratio: float, snr: float):
|
||||
# load data
|
||||
if args.dataset == 'cifar10':
|
||||
transform = transforms.Compose([transforms.ToTensor(), ])
|
||||
train_dataset = datasets.CIFAR10(root='./Dataset/', train=True,
|
||||
train_dataset = datasets.CIFAR10(root='./dataset/', train=True,
|
||||
download=True, transform=transform)
|
||||
|
||||
train_loader = DataLoader(train_dataset, shuffle=True,
|
||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||
test_dataset = datasets.CIFAR10(root='./Dataset/', train=False,
|
||||
test_dataset = datasets.CIFAR10(root='./dataset/', train=False,
|
||||
download=True, transform=transform)
|
||||
test_loader = DataLoader(test_dataset, shuffle=True,
|
||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||
@ -85,11 +85,11 @@ def train(args: config_parser(), ratio: float, snr: float):
|
||||
transform = transforms.Compose(
|
||||
[transforms.ToTensor(), transforms.Resize((128, 128))]) # the size of paper is 128
|
||||
print("loading data of imagenet")
|
||||
train_dataset = datasets.ImageFolder(root='./Dataset/ImageNet/train', transform=transform)
|
||||
train_dataset = datasets.ImageFolder(root='./dataset/ImageNet/train', transform=transform)
|
||||
|
||||
train_loader = DataLoader(train_dataset, shuffle=True,
|
||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||
test_dataset = Vanilla(root='./Dataset/ImageNet/val', transform=transform)
|
||||
test_dataset = Vanilla(root='./dataset/ImageNet/val', transform=transform)
|
||||
test_loader = DataLoader(test_dataset, shuffle=True,
|
||||
batch_size=args.batch_size, num_workers=args.num_workers)
|
||||
else:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user