add imagenet

This commit is contained in:
chun 2023-12-24 17:01:15 +08:00
parent 3ac47a6329
commit 6fae67b556
2 changed files with 20 additions and 1 deletions

View File

@ -1,4 +1,23 @@
import os
from torch.utils.data import Dataset
from PIL import Image
class Vanilla(Dataset):
def __init__(self, root, transform=None):
self.root = root
self.transform = transform
self.imgs = os.listdir(root)
def __getitem__(self, index):
img_path = os.path.join(self.root, self.imgs[index])
img = Image.open(img_path).convert('RGB')
if self.transform is not None:
img = self.transform(img)
return img, 0 # 0 is a fake label not important
def __len__(self):
return len(self.imgs)
def main():

View File

@ -16,7 +16,7 @@ from model import DeepJSCC, ratio2filtersize
from torch.nn.parallel import DataParallel
from utils import image_normalization
from fractions import Fraction
from dataset import Vanilla
def config_parser():
import argparse