train.py modified

This commit is contained in:
chun 2023-12-22 00:22:06 +08:00
parent 9c60ca0e2c
commit 9744980ede

View File

@ -44,7 +44,6 @@ def main():
def train(args: config_parser(), ratio: float, snr: float): def train(args: config_parser(), ratio: float, snr: float):
print("training with ratio: {}, snr_db: {}, channel: {}".format(ratio, snr, args.channel))
device = torch.device('cuda:1') device = torch.device('cuda:1')
# load data # load data
@ -56,6 +55,9 @@ def train(args: config_parser(), ratio: float, snr: float):
test_dataset = datasets.CIFAR10(root='./Dataset/', train=False, test_dataset = datasets.CIFAR10(root='./Dataset/', train=False,
download=True, transform=transform) download=True, transform=transform)
test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size) test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size)
print("training with ratio: {}, snr_db: {}, channel: {}".format(ratio, snr, args.channel))
image_fisrt = train_dataset.__getitem__(0)[0] image_fisrt = train_dataset.__getitem__(0)[0]
c = ratio2filtersize(image_fisrt, ratio) c = ratio2filtersize(image_fisrt, ratio)
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device) model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)