train.py modified
This commit is contained in:
parent
9c60ca0e2c
commit
9744980ede
4
train.py
4
train.py
@ -44,7 +44,6 @@ def main():
|
|||||||
|
|
||||||
def train(args: config_parser(), ratio: float, snr: float):
|
def train(args: config_parser(), ratio: float, snr: float):
|
||||||
|
|
||||||
print("training with ratio: {}, snr_db: {}, channel: {}".format(ratio, snr, args.channel))
|
|
||||||
|
|
||||||
device = torch.device('cuda:1')
|
device = torch.device('cuda:1')
|
||||||
# load data
|
# load data
|
||||||
@ -56,6 +55,9 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
test_dataset = datasets.CIFAR10(root='./Dataset/', train=False,
|
test_dataset = datasets.CIFAR10(root='./Dataset/', train=False,
|
||||||
download=True, transform=transform)
|
download=True, transform=transform)
|
||||||
test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size)
|
test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size)
|
||||||
|
|
||||||
|
print("training with ratio: {}, snr_db: {}, channel: {}".format(ratio, snr, args.channel))
|
||||||
|
|
||||||
image_fisrt = train_dataset.__getitem__(0)[0]
|
image_fisrt = train_dataset.__getitem__(0)[0]
|
||||||
c = ratio2filtersize(image_fisrt, ratio)
|
c = ratio2filtersize(image_fisrt, ratio)
|
||||||
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)
|
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user