loss function modified
This commit is contained in:
parent
18d5e0f8f5
commit
8b79e816d1
2
model.py
2
model.py
@ -87,8 +87,6 @@ class _Encoder(nn.Module):
|
|||||||
k = torch.tensor(k)
|
k = torch.tensor(k)
|
||||||
z_temp = z_hat.reshape(batch_size, 1, 1, -1)
|
z_temp = z_hat.reshape(batch_size, 1, 1, -1)
|
||||||
z_trans = z_hat.reshape(batch_size, 1, -1, 1)
|
z_trans = z_hat.reshape(batch_size, 1, -1, 1)
|
||||||
temp = z_temp@z_trans
|
|
||||||
temp = torch.sqrt((z_temp @ z_trans))
|
|
||||||
tensor = torch.sqrt(P * k) * z_hat / torch.sqrt((z_temp @ z_trans))
|
tensor = torch.sqrt(P * k) * z_hat / torch.sqrt((z_temp @ z_trans))
|
||||||
return tensor
|
return tensor
|
||||||
return _inner
|
return _inner
|
||||||
|
|||||||
8
train.py
8
train.py
@ -44,7 +44,7 @@ def main():
|
|||||||
|
|
||||||
def train(args: config_parser(), ratio: float, snr: float):
|
def train(args: config_parser(), ratio: float, snr: float):
|
||||||
|
|
||||||
device = torch.device('cuda:1')
|
device = torch.device('cuda')
|
||||||
# load data
|
# load data
|
||||||
transform = transforms.Compose([transforms.ToTensor(), ])
|
transform = transforms.Compose([transforms.ToTensor(), ])
|
||||||
train_dataset = datasets.CIFAR10(root='./Dataset/', train=True,
|
train_dataset = datasets.CIFAR10(root='./Dataset/', train=True,
|
||||||
@ -56,13 +56,13 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
download=True, transform=transform)
|
download=True, transform=transform)
|
||||||
test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size)
|
test_loader = RandomSampler(test_dataset, replacement=True, num_samples=args.batch_size)
|
||||||
|
|
||||||
print("training with ratio: {}, snr_db: {}, channel: {}".format(ratio, snr, args.channel))
|
print("training with ratio: {:2f}, snr_db: {}, channel: {}".format(ratio, snr, args.channel))
|
||||||
|
|
||||||
image_fisrt = train_dataset.__getitem__(0)[0]
|
image_fisrt = train_dataset.__getitem__(0)[0]
|
||||||
c = ratio2filtersize(image_fisrt, ratio)
|
c = ratio2filtersize(image_fisrt, ratio)
|
||||||
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)
|
model = DeepJSCC(c=c, channel_type=args.channel, snr=snr).cuda(device=device)
|
||||||
|
|
||||||
criterion = nn.MSELoss().cuda(device=device)
|
criterion = nn.MSELoss(reduction='sum').cuda(device=device)
|
||||||
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||||
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
|
epoch_loop = tqdm(range(args.epochs), total=args.epochs, leave=False)
|
||||||
|
|
||||||
@ -72,7 +72,7 @@ def train(args: config_parser(), ratio: float, snr: float):
|
|||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
images = images.cuda(device=device)
|
images = images.cuda(device=device)
|
||||||
outputs = model(images)
|
outputs = model(images)
|
||||||
loss = criterion(outputs, images)
|
loss = criterion(outputs, images) / args.batch_size
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
run_loss += loss.item()
|
run_loss += loss.item()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user