the embryo

This commit is contained in:
chun 2023-12-19 22:10:35 +08:00
parent dc08bfa17e
commit cc6254843a
3 changed files with 63 additions and 6 deletions

View File

@ -0,0 +1,9 @@
import torch
import torch.nn as nn
def AWGN_channel():
pass
def Rayleigh_channel():
pass

View File

@ -8,6 +8,7 @@ Created on Tue Dec 11:00:00 2023
import torch
import torch.nn as nn
import numpy as np
import channel
class _ConvWithPReLU(nn.Module):
@ -50,11 +51,12 @@ def _NormlizationLayer(norm_type='nomalization'):
def ratio2filter_size(x, ratio):
before_size = np.prod(x.size())
after_size = before_size*ratio
encoder_temp = Encoder(c=after_size)
encoder_temp = Encoder(is_temp=True)
x_temp = encoder_temp(x)
class Encoder(nn.Module):
def __init__(self, c, is_temp=False):
def __init__(self, c=1, is_temp=False):
super(Encoder, self).__init__()
self.is_temp = is_temp
self.imgae_normalization = _image_normalization(norm_type='nomalization')
@ -81,3 +83,18 @@ class Encoder(nn.Module):
class Decoder(nn.Module):
def __init__(self):
super(Decoder, self).__init__()
def forward(self, x):
pass
class DeepJSCC(nn.Module):
def __init__(self):
super(DeepJSCC, self).__init__()
self.encoder = Encoder()
self.decoder = Decoder()
def forward(self, x):
z = self.encoder(x)
x_hat = self.decoder(z)
return x_hat

View File

@ -7,15 +7,46 @@ Created on Tue Dec 11:00:00 2023
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from model import DeepJSCC
def config_parser():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=2048, type=int, help='Random seed')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--batch_size', default=128, type=int, help='batch size')
parser.add_argument('optimizer', default='Adam', type=str, help='optimizer')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
parser.add_argument('--channel', default='AWGN', type=str, help='weight decay')
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
return parser.parse_args()
def main():
args = config_parser()
# load data
transform = transforms.Compose([transforms.ToTensor(), ])
train_dataset = datasets.CIFAR10(root='./Dataset/cifar-10-batches-py/', train=True,
download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)
test_dataset = datasets.MNIST(root='./Dataset/cifar-10-batches-py/', train=False,
download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size)
def train():
pass
if __name__ == '__main__':
main()