JSCC/dataset.py
2024-02-03 20:05:40 +08:00

49 lines
2.0 KiB
Python

import os
from torch.utils.data import Dataset
from PIL import Image
class Vanilla(Dataset):
    """Unlabelled image dataset backed by the entries of one directory.

    Every name returned by ``os.listdir(root)`` is treated as an image file;
    no filtering by extension or file type is performed. Each sample is
    paired with the constant label ``0``.
    """

    def __init__(self, root, transform=None):
        """Record the directory and index its entries.

        Args:
            root: Directory whose entries are the images of the dataset.
            transform: Optional callable applied to each loaded RGB image
                before it is returned from ``__getitem__``.
        """
        self.root = root
        self.transform = transform
        # os.listdir yields entries in arbitrary order; sorting makes a given
        # index refer to the same file on every run and every machine.
        self.imgs = sorted(os.listdir(root))

    def __getitem__(self, index):
        """Load the image at ``index`` and return ``(image, 0)``.

        The image is converted to RGB and, when a transform was supplied,
        passed through it. The second element is a placeholder label.
        """
        img_path = os.path.join(self.root, self.imgs[index])
        # Image.open is lazy; the context manager guarantees the underlying
        # file is closed once convert() has produced the RGB image.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, 0  # 0 is a fake label not important

    def __len__(self):
        """Return the number of directory entries indexed at construction."""
        return len(self.imgs)
def _extract_tar(archive, dest):
    """Extract ``archive`` into ``dest`` with the system ``tar``; raise on failure."""
    import subprocess  # local import keeps the file's top-level imports untouched

    print('tar -xf {} -C {}'.format(archive, dest))
    # Argument list (no shell) so paths with spaces or metacharacters are safe;
    # check=True turns a non-zero tar exit status into CalledProcessError.
    subprocess.run(['tar', '-xf', archive, '-C', dest], check=True)


def main(data_path='./dataset'):
    """Unpack the ILSVRC2012 train/val archives found in ``data_path``.

    ``ILSVRC2012_img_train.tar`` and ``ILSVRC2012_img_val.tar`` are extracted
    into ``<data_path>/ImageNet/train`` and ``<data_path>/ImageNet/val``.
    The train archive contains one tar per class; each of those is extracted
    into a directory named after the archive and then deleted.

    Args:
        data_path: Directory holding the two downloaded archives. It is
            created if missing. Defaults to ``'./dataset'``.

    Raises:
        FileNotFoundError: If either archive is missing from ``data_path``.
        subprocess.CalledProcessError: If a ``tar`` invocation fails. The
            per-class archive being processed is not deleted in that case.
    """
    os.makedirs(data_path, exist_ok=True)
    phases = ['train', 'val']
    archives = {
        phase: os.path.join(data_path, 'ILSVRC2012_img_{}.tar'.format(phase))
        for phase in phases
    }
    # ILSVRC2012_img_train.tar and ILSVRC2012_img_val.tar should be downloaded from https://image-net.org/
    if not all(os.path.exists(archive) for archive in archives.values()):
        print('ILSVRC2012_img_train.tar and ILSVRC2012_img_val.tar should be downloaded from https://image-net.org/')
        print('Please download the dataset from https://image-net.org/challenges/LSVRC/2012/2012-downloads and put it in {}'.format(data_path))
        raise FileNotFoundError('not find dataset')

    for phase in phases:
        print("extracting {} dataset".format(phase))
        path = os.path.join(data_path, 'ImageNet', phase)
        print('path is {}'.format(path))
        os.makedirs(path, exist_ok=True)
        _extract_tar(archives[phase], path)
        if phase != 'train':
            continue
        # The train archive holds one tar per class. Only *.tar entries are
        # processed so that a re-run skips class directories that a previous
        # run already extracted.
        for entry in sorted(os.listdir(path)):
            if not entry.endswith('.tar'):
                continue
            class_archive = os.path.join(path, entry)
            class_dir = os.path.join(path, entry.split('.')[0])
            os.makedirs(class_dir, exist_ok=True)
            _extract_tar(class_archive, class_dir)
            # Reached only when extraction succeeded (check=True above).
            os.remove(class_archive)
# Run the extraction only when this file is executed as a script, so that
# importing the module (e.g. to use Vanilla) has no side effects.
if __name__ == "__main__":
    main()