From 9744980eded4bc167fb2d3990252f7b715058a6e Mon Sep 17 00:00:00 2001 From: chun Date: Fri, 22 Dec 2023 00:22:06 +0800 Subject: [PATCH] train.py modified --- train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index 57b1b72..a8862fc 100644 --- a/train.py +++ b/train.py @@ -44,7 +44,6 @@ def main(): def train(args: config_parser(), ratio: float, snr: float): - print("training with ratio: {}, snr_db: {}, channel: {}".format(ratio, snr, args.channel)) device = torch.device('cuda:1') # load data @@ -56,6 +55,9 @@ def train(args: config_parser(), ratio: float, snr: float): test_dataset = datasets.CIFAR10(root='./Dataset/', train=False, download=True, transform=transform) test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size) + + print("training with ratio: {}, snr_db: {}, channel: {}".format(ratio, snr, args.channel)) + image_fisrt = train_dataset.__getitem__(0)[0] c = ratio2filtersize(image_fisrt, ratio) model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)