the embryo
This commit is contained in:
parent
dc08bfa17e
commit
cc6254843a
@ -0,0 +1,9 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def AWGN_channel():
|
||||||
|
pass
|
||||||
|
|
||||||
|
def Rayleigh_channel():
|
||||||
|
pass
|
||||||
21
model.py
21
model.py
@ -8,6 +8,7 @@ Created on Tue Dec 11:00:00 2023
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import channel
|
||||||
|
|
||||||
|
|
||||||
class _ConvWithPReLU(nn.Module):
|
class _ConvWithPReLU(nn.Module):
|
||||||
@ -50,11 +51,12 @@ def _NormlizationLayer(norm_type='nomalization'):
|
|||||||
def ratio2filter_size(x, ratio):
|
def ratio2filter_size(x, ratio):
|
||||||
before_size = np.prod(x.size())
|
before_size = np.prod(x.size())
|
||||||
after_size = before_size*ratio
|
after_size = before_size*ratio
|
||||||
encoder_temp = Encoder(c=after_size)
|
encoder_temp = Encoder(is_temp=True)
|
||||||
|
x_temp = encoder_temp(x)
|
||||||
|
|
||||||
|
|
||||||
class Encoder(nn.Module):
|
class Encoder(nn.Module):
|
||||||
def __init__(self, c, is_temp=False):
|
def __init__(self, c=1, is_temp=False):
|
||||||
super(Encoder, self).__init__()
|
super(Encoder, self).__init__()
|
||||||
self.is_temp = is_temp
|
self.is_temp = is_temp
|
||||||
self.imgae_normalization = _image_normalization(norm_type='nomalization')
|
self.imgae_normalization = _image_normalization(norm_type='nomalization')
|
||||||
@ -81,3 +83,18 @@ class Encoder(nn.Module):
|
|||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Decoder, self).__init__()
|
super(Decoder, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DeepJSCC(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(DeepJSCC, self).__init__()
|
||||||
|
self.encoder = Encoder()
|
||||||
|
self.decoder = Decoder()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
z = self.encoder(x)
|
||||||
|
x_hat = self.decoder(z)
|
||||||
|
return x_hat
|
||||||
|
|||||||
37
train.py
37
train.py
@ -7,15 +7,46 @@ Created on Tue Dec 11:00:00 2023
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from torchvision import transforms
|
||||||
|
from torchvision import datasets
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.optim as optim
|
||||||
|
import tqdm
|
||||||
|
from model import DeepJSCC
|
||||||
|
|
||||||
|
|
||||||
|
def config_parser():
|
||||||
|
import argparse
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--seed', default=2048, type=int, help='Random seed')
|
||||||
|
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
|
||||||
|
parser.add_argument('--batch_size', default=128, type=int, help='batch size')
|
||||||
|
parser.add_argument('optimizer', default='Adam', type=str, help='optimizer')
|
||||||
|
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
|
||||||
|
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
|
||||||
|
parser.add_argument('--channel', default='AWGN', type=str, help='weight decay')
|
||||||
|
parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
args = config_parser()
|
||||||
|
|
||||||
|
# load data
|
||||||
|
transform = transforms.Compose([transforms.ToTensor(), ])
|
||||||
|
train_dataset = datasets.CIFAR10(root='./Dataset/cifar-10-batches-py/', train=True,
|
||||||
|
download=True, transform=transform)
|
||||||
|
|
||||||
|
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)
|
||||||
|
test_dataset = datasets.MNIST(root='./Dataset/cifar-10-batches-py/', train=False,
|
||||||
|
download=True, transform=transform)
|
||||||
|
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size)
|
||||||
|
|
||||||
|
|
||||||
|
def train():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
Loading…
Reference in New Issue
Block a user