add imagenet
This commit is contained in:
parent
3ac47a6329
commit
6fae67b556
19
dataset.py
19
dataset.py
@ -1,4 +1,23 @@
|
|||||||
import os
|
import os
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class Vanilla(Dataset):
|
||||||
|
def __init__(self, root, transform=None):
|
||||||
|
self.root = root
|
||||||
|
self.transform = transform
|
||||||
|
self.imgs = os.listdir(root)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
img_path = os.path.join(self.root, self.imgs[index])
|
||||||
|
img = Image.open(img_path).convert('RGB')
|
||||||
|
if self.transform is not None:
|
||||||
|
img = self.transform(img)
|
||||||
|
return img, 0 # 0 is a fake label not important
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.imgs)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
2
train.py
2
train.py
@ -16,7 +16,7 @@ from model import DeepJSCC, ratio2filtersize
|
|||||||
from torch.nn.parallel import DataParallel
|
from torch.nn.parallel import DataParallel
|
||||||
from utils import image_normalization
|
from utils import image_normalization
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
|
from dataset import Vanilla
|
||||||
|
|
||||||
def config_parser():
|
def config_parser():
|
||||||
import argparse
|
import argparse
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user