add imagenet

This commit is contained in:
chun 2023-12-24 17:15:36 +08:00
parent 89659b8292
commit 57feb8ef75

View File

@ -127,6 +127,8 @@ def train(args: config_parser(), ratio: float, snr: float):
test_mse += loss.item()
model.train()
epoch_loop.set_postfix(loss=run_loss/len(train_loader), test_mse=test_mse/len(test_loader))
print("epoch: {}, loss: {:.4f}, test_mse: {:.4f}".format(
epoch, run_loss/len(train_loader), test_mse/len(test_loader)))
save_model(model, args.saved, args.saved +
'/{}_{}_{:.2f}_{:.2f}_{}.pth'.format(args.dataset, args.epochs, ratio, snr, c))