diff --git a/channel.py b/channel.py
index e69de29..7009645 100644
--- a/channel.py
+++ b/channel.py
@@ -0,0 +1,9 @@
+import torch
+import torch.nn as nn
+
+
+def AWGN_channel():
+    pass
+
+def Rayleigh_channel():
+    pass
\ No newline at end of file
diff --git a/model.py b/model.py
index 2a444c8..1e9ecc1 100644
--- a/model.py
+++ b/model.py
@@ -8,6 +8,7 @@ Created on Tue Dec 11:00:00 2023
 import torch
 import torch.nn as nn
 import numpy as np
+import channel
 
 
 class _ConvWithPReLU(nn.Module):
@@ -50,11 +51,12 @@ def _NormlizationLayer(norm_type='nomalization'):
 def ratio2filter_size(x, ratio):
     before_size = np.prod(x.size())
     after_size = before_size*ratio
-    encoder_temp = Encoder(c=after_size)
+    encoder_temp = Encoder(is_temp=True)
+    x_temp = encoder_temp(x)
 
 
 class Encoder(nn.Module):
-    def __init__(self, c, is_temp=False):
+    def __init__(self, c=1, is_temp=False):
         super(Encoder, self).__init__()
         self.is_temp = is_temp
         self.imgae_normalization = _image_normalization(norm_type='nomalization')
@@ -81,3 +83,18 @@ class Encoder(nn.Module):
 class Decoder(nn.Module):
     def __init__(self):
         super(Decoder, self).__init__()
+
+    def forward(self, x):
+        pass
+
+
+class DeepJSCC(nn.Module):
+    def __init__(self):
+        super(DeepJSCC, self).__init__()
+        self.encoder = Encoder()
+        self.decoder = Decoder()
+
+    def forward(self, x):
+        z = self.encoder(x)
+        x_hat = self.decoder(z)
+        return x_hat
diff --git a/train.py b/train.py
index fa323c5..93029d7 100644
--- a/train.py
+++ b/train.py
@@ -7,15 +7,46 @@ Created on Tue Dec 11:00:00 2023
 
 import torch
 import torch.nn as nn
+from torchvision import transforms
+from torchvision import datasets
+from torch.utils.data import DataLoader
+import torch.nn.functional as F
+import torch.optim as optim
+import tqdm
+from model import DeepJSCC
 
 
+def config_parser():
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--seed', default=2048, type=int, help='Random seed')
+    parser.add_argument('--lr', default=0.1, type=float,
+                        help='learning rate')
+    parser.add_argument('--batch_size', default=128, type=int, help='batch size')
+    parser.add_argument('--optimizer', default='Adam', type=str, help='optimizer')
+    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
+    parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
+    parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
+    parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
+    return parser.parse_args()
+
 
 def main():
+    args = config_parser()
+
+    # load data
+    transform = transforms.Compose([transforms.ToTensor(), ])
+    train_dataset = datasets.CIFAR10(root='./Dataset/cifar-10-batches-py/', train=True,
+                                     download=True, transform=transform)
+
+    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)
+    test_dataset = datasets.CIFAR10(root='./Dataset/cifar-10-batches-py/', train=False,
+                                    download=True, transform=transform)
+    test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size)
+
+
+def train():
     pass
 
 
-
-
-
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()