loss function modified

This commit is contained in:
chun 2023-12-23 00:16:27 +08:00
parent 18d5e0f8f5
commit 8b79e816d1
2 changed files with 4 additions and 6 deletions

View File

@ -87,8 +87,6 @@ class _Encoder(nn.Module):
k = torch.tensor(k)
z_temp = z_hat.reshape(batch_size, 1, 1, -1)
z_trans = z_hat.reshape(batch_size, 1, -1, 1)
temp = z_temp@z_trans
temp = torch.sqrt((z_temp @ z_trans))
tensor = torch.sqrt(P * k) * z_hat / torch.sqrt((z_temp @ z_trans))
return tensor
return _inner

View File

@ -44,7 +44,7 @@ def main():
def train(args: config_parser(), ratio: float, snr: float):
device = torch.device('cuda:1')
device = torch.device('cuda')
# load data
transform = transforms.Compose([transforms.ToTensor(), ])
train_dataset = datasets.CIFAR10(root='./Dataset/', train=True,
@ -56,13 +56,13 @@ def train(args: config_parser(), ratio: float, snr: float):
download=True, transform=transform)
test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size)
print("training with ratio: {}, snr_db: {}, channel: {}".format(ratio, snr, args.channel))
print("training with ratio: {:2f}, snr_db: {}, channel: {}".format(ratio, snr, args.channel))
image_fisrt = train_dataset.__getitem__(0)[0]
c = ratio2filtersize(image_fisrt, ratio)
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)
criterion = nn.MSELoss().cuda(device=device)
criterion = nn.MSELoss(reduction='sum').cuda(device=device)
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
@ -72,7 +72,7 @@ def train(args: config_parser(), ratio: float, snr: float):
optimizer.zero_grad()
images = images.cuda(device=device)
outputs = model(images)
loss = criterion(outputs, images)
loss = criterion(outputs, images) / args.batch_size
loss.backward()
optimizer.step()
run_loss += loss.item()