# JSCC/train.py
# Source snapshot: 2023-12-19 22:10:35 +08:00 · 53 lines · 1.7 KiB · Python

# -*- coding: utf-8 -*-
"""
Created on Tue Dec 11:00:00 2023
@author: chun
"""
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from model import DeepJSCC
def config_parser(argv=None):
    """Build the training command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, exactly
            as before this parameter existed.

    Returns:
        argparse.Namespace with the attributes ``seed``, ``lr``,
        ``batch_size``, ``optimizer``, ``momentum``, ``weight_decay``,
        ``channel`` and ``saved``.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', default=2048, type=int, help='Random seed')
    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--batch_size', default=128, type=int, help='batch size')
    # This was declared as a bare positional, which made it mandatory and
    # silently ignored default='Adam'. nargs='?' makes it optional so the
    # default actually applies, while still accepting the positional form
    # (e.g. `python train.py SGD`) that existing invocations rely on.
    parser.add_argument('optimizer', nargs='?', default='Adam', type=str, help='optimizer')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
    # Help text was a copy-paste of "weight decay"; the default 'AWGN' names
    # the channel type.
    parser.add_argument('--channel', default='AWGN', type=str, help='channel type')
    parser.add_argument('--saved', default='./saved', type=str, help='saved_path')
    return parser.parse_args(argv)
def main():
    """Parse CLI arguments and build the CIFAR-10 train/test data loaders.

    Training itself is not wired up yet; this function only prepares the
    data pipeline.
    """
    args = config_parser()
    # --seed was parsed but never used. Seeding torch's global RNG makes the
    # shuffled DataLoader order reproducible across runs.
    torch.manual_seed(args.seed)

    # Load data: images are only converted to tensors, with no augmentation.
    transform = transforms.Compose([transforms.ToTensor(), ])
    train_dataset = datasets.CIFAR10(root='./Dataset/cifar-10-batches-py/', train=True,
                                     download=True, transform=transform)
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)
    # The test split must come from the same dataset as training. This
    # previously loaded MNIST (single-channel 28x28 digits) into the CIFAR-10
    # directory, which is incompatible with a model trained on CIFAR-10.
    test_dataset = datasets.CIFAR10(root='./Dataset/cifar-10-batches-py/', train=False,
                                    download=True, transform=transform)
    test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size)
def train():
    """Placeholder for the training loop (not implemented yet).

    Takes no arguments, performs no work and yields ``None``.
    """
    return None
if __name__ == '__main__':
    # Only run when executed as a script, not when this module is imported.
    main()