softlink dataset

This commit is contained in:
chun 2024-01-23 22:18:11 +08:00
parent 52b7efb407
commit 14cd4ed4e4
4 changed files with 13 additions and 12 deletions

3
.gitignore vendored
View File

@ -1,9 +1,10 @@
test.py
*.pyc
*.log
Dataset/*
dataset
*.ipynb
*.swp
.vscode/*
input.txt
output.txt
*.json

View File

@ -12,7 +12,7 @@ if __name__ == '__main__':
])
trainset = torchvision.datasets.CIFAR10(
root='./Dataset/',
root='./dataset/',
train=True, # 如果为True从 training.pt 创建数据,否则从 test.pt 创建数据。
download=True, # 如果为true则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
transform=transform

View File

@ -21,21 +21,21 @@ class Vanilla(Dataset):
def main():
data_path = './Dataset'
data_path = './dataset'
os.makedirs(data_path, exist_ok=True)
# ILSVRC2012_img_train.tar and ILSVRC2012_img_val.tar should be downloaded from https://image-net.org/
if not os.path.exists('./Dataset/ILSVRC2012_img_train.tar') or not os.path.exists('./Dataset/ILSVRC2012_img_val.tar'):
if not os.path.exists('./dataset/ILSVRC2012_img_train.tar') or not os.path.exists('./dataset/ILSVRC2012_img_val.tar'):
print('ILSVRC2012_img_train.tar and ILSVRC2012_img_val.tar should be downloaded from https://image-net.org/')
print('Please download the dataset from https://image-net.org/challenges/LSVRC/2012/2012-downloads and put it in ./Dataset')
print('Please download the dataset from https://image-net.org/challenges/LSVRC/2012/2012-downloads and put it in ./dataset')
raise Exception('not find dataset')
phases = ['train', 'val']
for phase in phases:
print("extracting {} dataset".format(phase))
path = './Dataset/ImageNet/{}'.format(phase)
path = './dataset/ImageNet/{}'.format(phase)
print('path is {}'.format(path))
os.makedirs(path, exist_ok=True)
print('tar -xf ./Dataset/ILSVRC2012_img_{}.tar -C {}'.format(phase, path))
os.system('tar -xf ./Dataset/ILSVRC2012_img_{}.tar -C {}'.format(phase, path))
print('tar -xf ./dataset/ILSVRC2012_img_{}.tar -C {}'.format(phase, path))
os.system('tar -xf ./dataset/ILSVRC2012_img_{}.tar -C {}'.format(phase, path))
if phase == 'train':
for tar in os.listdir(path):
print('tar -xf {}/{} -C {}/{}'.format(path, tar, path, tar.split('.')[0]))

View File

@ -72,12 +72,12 @@ def train(args: config_parser(), ratio: float, snr: float):
# load data
if args.dataset == 'cifar10':
transform = transforms.Compose([transforms.ToTensor(), ])
train_dataset = datasets.CIFAR10(root='./Dataset/', train=True,
train_dataset = datasets.CIFAR10(root='./dataset/', train=True,
download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True,
batch_size=args.batch_size, num_workers=args.num_workers)
test_dataset = datasets.CIFAR10(root='./Dataset/', train=False,
test_dataset = datasets.CIFAR10(root='./dataset/', train=False,
download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=True,
batch_size=args.batch_size, num_workers=args.num_workers)
@ -85,11 +85,11 @@ def train(args: config_parser(), ratio: float, snr: float):
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Resize((128, 128))]) # the size of paper is 128
print("loading data of imagenet")
train_dataset = datasets.ImageFolder(root='./Dataset/ImageNet/train', transform=transform)
train_dataset = datasets.ImageFolder(root='./dataset/ImageNet/train', transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True,
batch_size=args.batch_size, num_workers=args.num_workers)
test_dataset = Vanilla(root='./Dataset/ImageNet/val', transform=transform)
test_dataset = Vanilla(root='./dataset/ImageNet/val', transform=transform)
test_loader = DataLoader(test_dataset, shuffle=True,
batch_size=args.batch_size, num_workers=args.num_workers)
else: