v1.0
This commit is contained in:
parent
78ff76955c
commit
a7900ec006
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,2 +1,3 @@
|
||||
test.py
|
||||
*.pyc
|
||||
*.log
|
||||
21
channel.py
21
channel.py
@ -1,9 +1,22 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
def AWGN_channel():
|
||||
pass
|
||||
def channel(channel_type='AWGN', snr=20):
    """Build a function that simulates a noisy communication channel.

    Args:
        channel_type: ``'AWGN'`` or ``'Rayleigh'``.
        snr: Signal-to-noise ratio in dB used to scale the injected noise.

    Returns:
        A callable taking a tensor ``z_hat`` whose first dimension is the
        batch and returning the channel output with the same shape.

    Raises:
        ValueError: If ``channel_type`` is not a known channel name
            (``ValueError`` is an ``Exception``, so existing handlers still
            catch it).
    """

    def AWGN_channel(z_hat: torch.Tensor):
        """Add white Gaussian noise scaled to the configured SNR."""
        if z_hat.dim() < 2:
            raise ValueError('z_hat must have a batch dimension followed by '
                             'at least one signal dimension')
        # Number of symbols per sample (everything except the batch axis).
        k = int(np.prod(z_hat.shape[1:]))
        # Per-sample signal energy, reduced over all non-batch axes so the
        # channel is not tied to 4-D inputs, then reshaped to broadcast
        # against z_hat.
        sig_pwr = z_hat.abs().square().flatten(1).sum(dim=1)
        sig_pwr = sig_pwr.view(-1, *([1] * (z_hat.dim() - 1)))
        # Average symbol power divided by the linear SNR gives noise variance.
        noi_pwr = sig_pwr / (k * 10 ** (snr / 10))
        noise = torch.randn_like(z_hat) * torch.sqrt(noi_pwr)
        return z_hat + noise

    def Rayleigh_channel(z_hat: torch.Tensor):
        """Placeholder for a Rayleigh fading channel (not implemented)."""
        # Fail loudly instead of silently returning None, which would only
        # surface later as an obscure error in whatever consumes the output.
        raise NotImplementedError('Rayleigh channel is not implemented yet')

    if channel_type == 'AWGN':
        return AWGN_channel
    elif channel_type == 'Rayleigh':
        return Rayleigh_channel
    else:
        raise ValueError('Unknown type of channel')
|
||||
|
||||
106
model.py
106
model.py
@ -9,10 +9,30 @@ import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import channel
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def _image_normalization(norm_type):
|
||||
def _inner(tensor: torch.Tensor):
|
||||
if norm_type == 'nomalization':
|
||||
return tensor / 255.0
|
||||
elif norm_type == 'denormalization':
|
||||
return (tensor * 255.0).type(torch.FloatTensor)
|
||||
else:
|
||||
raise Exception('Unknown type of normalization')
|
||||
return _inner
|
||||
|
||||
|
||||
def ratio2filtersize(x, ratio):
    """Compute the encoder output channel count for a compression ratio.

    A throw-away ``_Encoder(is_temp=True)`` is run on ``x`` to find the
    spatial size of the latent feature map; the channel count is then
    ``per-image input size * ratio / (latent height * latent width)``.

    Args:
        x: A single image tensor (3-D) or a batch of images (4-D).
        ratio: Target ratio between transmitted symbols and input size.

    Returns:
        The channel count, truncated to an ``int``.
    """
    # Give a lone image a batch axis so both input forms take the same path.
    if x.dim() == 3:
        x = x.unsqueeze(0)
    # Size of ONE image: the batch axis must not inflate the result.
    before_size = np.prod(x.size()[1:])
    encoder_temp = _Encoder(is_temp=True)
    # Only the output shape is needed, so skip building an autograd graph.
    with torch.no_grad():
        z_temp = encoder_temp(x)
    c = before_size * ratio / np.prod(z_temp.size()[-2:])
    return int(c)
|
||||
|
||||
|
||||
class _ConvWithPReLU(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
|
||||
super(_ConvWithPReLU, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
|
||||
self.prelu = nn.PReLU()
|
||||
@ -24,9 +44,10 @@ class _ConvWithPReLU(nn.Module):
|
||||
|
||||
|
||||
class _TransConvWithPReLU(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, activate=nn.PReLU(), padding=0):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, activate=nn.PReLU(), padding=0, output_padding=0):
|
||||
super(_TransConvWithPReLU, self).__init__()
|
||||
self.transconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
|
||||
self.transconv = nn.ConvTranspose2d(
|
||||
in_channels, out_channels, kernel_size, stride, padding, output_padding)
|
||||
self.activate = activate
|
||||
|
||||
def forward(self, x):
|
||||
@ -35,37 +56,30 @@ class _TransConvWithPReLU(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def _image_normalization(tensor, norm_type):
|
||||
if norm_type == 'nomalization':
|
||||
return tensor / 255.0
|
||||
elif norm_type == 'denormalization':
|
||||
return tensor * 255.0
|
||||
else:
|
||||
raise Exception('Unknown type of normalization')
|
||||
|
||||
|
||||
def _NormlizationLayer(norm_type='nomalization'):
|
||||
pass
|
||||
|
||||
|
||||
def ratio2filter_size(x, ratio):
|
||||
before_size = np.prod(x.size())
|
||||
after_size = before_size*ratio
|
||||
encoder_temp = Encoder(is_temp=True)
|
||||
x_temp = encoder_temp(x)
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
class _Encoder(nn.Module):
|
||||
def __init__(self, c=1, is_temp=False):
|
||||
super(Encoder, self).__init__()
|
||||
super(_Encoder, self).__init__()
|
||||
self.is_temp = is_temp
|
||||
self.imgae_normalization = _image_normalization(norm_type='nomalization')
|
||||
self.conv1 = _ConvWithPReLU(in_channels=3, out_channels=16, kernel_size=5, stride=2)
|
||||
self.conv2 = _ConvWithPReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2)
|
||||
self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, stride=1)
|
||||
self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, stride=1)
|
||||
self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=c, kernel_size=5, stride=1)
|
||||
self.norm = _NormlizationLayer(norm_type='nomalization')
|
||||
self.conv3 = _ConvWithPReLU(in_channels=32, out_channels=32,
|
||||
kernel_size=5, padding=2) # padding size could be changed here
|
||||
self.conv4 = _ConvWithPReLU(in_channels=32, out_channels=32, kernel_size=5, padding=2)
|
||||
self.conv5 = _ConvWithPReLU(in_channels=32, out_channels=c, kernel_size=5, padding=2)
|
||||
self.norm = self._normlizationLayer()
|
||||
|
||||
@staticmethod
|
||||
def _normlizationLayer(P=1):
|
||||
def _inner(z_hat: torch.Tensor):
|
||||
batch_size = z_hat.size()[0]
|
||||
k = np.prod(z_hat.size()[1:])
|
||||
k = torch.tensor(k)
|
||||
z_temp = z_hat.reshape(batch_size, 1, 1, -1)
|
||||
z_trans = z_hat.reshape(batch_size, 1, -1, 1)
|
||||
tensor = torch.sqrt(P * k) * z_hat / (z_temp @ z_trans)
|
||||
return tensor
|
||||
return _inner
|
||||
|
||||
def forward(self, x):
|
||||
x = self.imgae_normalization(x)
|
||||
@ -75,26 +89,46 @@ class Encoder(nn.Module):
|
||||
x = self.conv4(x)
|
||||
if not self.is_temp:
|
||||
x = self.conv5(x)
|
||||
|
||||
z = self.norm(x)
|
||||
del x
|
||||
return z
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self):
|
||||
super(Decoder, self).__init__()
|
||||
class _Decoder(nn.Module):
    """Decoder: five transposed-convolution layers followed by a 0-255 rescale.

    Args:
        c: Number of channels of the received latent tensor.
    """

    def __init__(self, c=1):
        super(_Decoder, self).__init__()
        self.imgae_normalization = _image_normalization(norm_type='denormalization')
        # Each layer is given its own PReLU. Relying on the default argument
        # of _TransConvWithPReLU would make tconv1-tconv4 share one PReLU
        # instance (and its learnable slope), because a default value is
        # evaluated only once, when the function is defined.
        self.tconv1 = _TransConvWithPReLU(
            in_channels=c, out_channels=32, kernel_size=5, stride=1,
            activate=nn.PReLU(), padding=2)
        self.tconv2 = _TransConvWithPReLU(
            in_channels=32, out_channels=32, kernel_size=5, stride=1,
            activate=nn.PReLU(), padding=2)
        self.tconv3 = _TransConvWithPReLU(
            in_channels=32, out_channels=32, kernel_size=5, stride=1,
            activate=nn.PReLU(), padding=2)
        # NOTE(review): tconv4 and tconv5 use kernel_size=6, whereas the
        # paper uses 5 -- this may need revisiting.
        self.tconv4 = _TransConvWithPReLU(
            in_channels=32, out_channels=16, kernel_size=6, stride=2,
            activate=nn.PReLU())
        self.tconv5 = _TransConvWithPReLU(
            in_channels=16, out_channels=3, kernel_size=6, stride=2,
            activate=nn.Sigmoid())

    def forward(self, x):
        """Run the five layers in order, then multiply the output by 255."""
        x = self.tconv1(x)
        x = self.tconv2(x)
        x = self.tconv3(x)
        x = self.tconv4(x)
        x = self.tconv5(x)
        x = self.imgae_normalization(x)
        return x
|
||||
|
||||
|
||||
class DeepJSCC(nn.Module):
    """End-to-end model: encoder, then simulated channel, then decoder.

    Args:
        c: Channel count of the latent tensor passed between encoder and
            decoder.
        channel_type: Channel name forwarded to ``channel.channel``.
        snr: Signal-to-noise ratio forwarded to ``channel.channel``.
    """

    def __init__(self, c, channel_type='AWGN', snr=20):
        super().__init__()
        self.encoder = _Encoder(c=c)
        self.channel = channel.channel(channel_type, snr)
        self.decoder = _Decoder(c=c)

    def forward(self, x):
        """Encode ``x``, pass it through the channel and decode the result."""
        latent = self.encoder(x)
        received = self.channel(latent)
        return self.decoder(received)
|
||||
|
||||
56
train.py
56
train.py
@ -1,6 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Created on Tue Dec 11:00:00 2023
|
||||
Created on Tue Dec 17:00:00 2023
|
||||
|
||||
@author: chun
|
||||
"""
|
||||
@ -10,10 +10,11 @@ import torch.nn as nn
|
||||
from torchvision import transforms
|
||||
from torchvision import datasets
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
import tqdm
|
||||
from model import DeepJSCC
|
||||
from model import DeepJSCC, ratio2filtersize
|
||||
from torch.nn.parallel import DataParallel
|
||||
from channel import channel
|
||||
|
||||
|
||||
def config_parser():
|
||||
@ -21,31 +22,64 @@ def config_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--seed', default=2048, type=int, help='Random seed')
|
||||
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
|
||||
parser.add_argument('--batch_size', default=128, type=int, help='batch size')
|
||||
parser.add_argument('optimizer', default='Adam', type=str, help='optimizer')
|
||||
parser.add_argument('--epochs', default=100, type=int, help='number of epochs')
|
||||
parser.add_argument('--batch_size', default=64, type=int, help='batch size')
|
||||
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
|
||||
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
|
||||
parser.add_argument('--channel', default='AWGN', type=str, help='weight decay')
|
||||
parser.add_argument('--weight_decay', default=1e-3, type=float, help='weight decay')
|
||||
parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
|
||||
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
|
||||
parser.add_argument('--snr_list', default=range(1, 19, 3), type=list, help='snr_list')
|
||||
parser.add_argument('--ratio_list', default=[1/6, 1/12], type=list, help='ratio_list')
|
||||
parser.add_argument('--early_stop', default=True, type=bool, help='early_stop')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
    """Parse the command-line options and launch a single training run."""
    cli_args = config_parser()

    print("Training Start")
    # The sweep over args.ratio_list and args.snr_list is currently disabled;
    # only one fixed configuration (ratio 1/6, SNR 20) is trained.
    # for ratio in args.ratio_list:
    #     for snr in args.snr_list:
    #         train(args, ratio, snr)
    train(cli_args, 1/6, 20)
|
||||
|
||||
|
||||
def train(args: argparse.Namespace, ratio: float, snr: float):
    """Train one DeepJSCC model on CIFAR-10 and save its weights.

    Args:
        args: Parsed command-line options; this function reads ``channel``,
            ``batch_size``, ``lr``, ``weight_decay``, ``epochs`` and ``saved``.
        ratio: Compression ratio used to derive the latent channel count.
        snr: Channel signal-to-noise ratio passed to the model.
    """
    import os  # local import: only needed for preparing the output directory

    print("training with ratio: {}, snr: {}, channel: {}".format(ratio, snr, args.channel))

    # Create the output directory up front so a long run cannot fail at the
    # very end just because the folder does not exist.
    os.makedirs(args.saved, exist_ok=True)

    # load data
    transform = transforms.Compose([transforms.ToTensor(), ])
    train_dataset = datasets.CIFAR10(root='./Dataset/', train=True,
                                     download=True, transform=transform)
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)
    # The held-out split must come from the same dataset as the training
    # split. It is not consumed by the loop below yet.
    test_dataset = datasets.CIFAR10(root='./Dataset/', train=False,
                                    download=True, transform=transform)
    test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size)

    # Probe with the first training image to size the latent representation.
    image_first = train_dataset[0][0]
    c = ratio2filtersize(image_first, ratio)
    model = DeepJSCC(c=c, channel_type=args.channel, snr=snr)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # `tqdm` is imported as a module in this file, so the progress-bar class
    # is `tqdm.tqdm`; args.epochs is an int, so iterate over range().
    epoch_loop = tqdm.tqdm(range(args.epochs), total=args.epochs, leave=False)
    for epoch in epoch_loop:
        run_loss = 0.0
        for images, _ in tqdm.tqdm(train_loader, leave=False):
            # ToTensor() yields values in [0, 1], but the encoder divides its
            # input by 255 and the decoder multiplies its output by 255, so
            # rescale the batch to 0-255 to keep inputs, outputs and the
            # loss target on the same scale.
            images = images * 255.0
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, images)
            loss.backward()
            optimizer.step()
            run_loss += loss.item()
        epoch_loop.set_description(f'Epoch [{epoch}/{args.epochs}]')
        epoch_loop.set_postfix(loss=run_loss)

    save_model(model, os.path.join(args.saved, 'model_{}_{}.pth'.format(ratio, snr)))
|
||||
|
||||
|
||||
def train():
|
||||
pass
|
||||
def save_model(model, path):
    """Write the model's ``state_dict`` to ``path``.

    Args:
        model: Module whose parameters are saved.
        path: Destination file; missing parent directories are created.
    """
    import os  # local import: only needed for directory creation

    parent = os.path.dirname(path)
    # A bare file name has no parent component; makedirs('') would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model.state_dict(), path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Loading…
Reference in New Issue
Block a user