Pytorch Deep Neural Networks

pytorch >>> Build your own models quickly!


> Included network models:

  • Deep Belief Network (DBN)
  • Deep Autoencoder (DAE)
  • Stacked Autoencoder (sAE)
  • Stacked Sparse Autoencoder (sSAE)
  • Stacked Denoising Autoencoder (sDAE)
  • Convolutional Neural Network (CNN)
  • Visual Geometry Group (VGG)
  • Residual Network (ResNet)

See README.md for detailed descriptions of the models.

> Video tutorial:

Likes, coins, and favorites are appreciated, 23333. Click here to watch on Bilibili.

> Getting started:

  • First, set the package path as described in path.txt
  • New to PyTorch: have a look at the official tutorials, the network models, and the codes
  • To understand this package: look at the simple AE, which runs without depending on any other files
  • To run the code: run the files in the example folder

> Supported tasks:

task == 'cls' for classification tasks; task == 'prd' for prediction (regression) tasks. The flag is passed in the same parameter dictionary used to instantiate a model, as sketched below.
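A minimal sketch; the dictionary follows the Step 2 example further down:

# Select the task type when constructing a model (see Step 2 below):
parameter = {'task': 'cls'}    # classification: accuracy is reported
# parameter = {'task': 'prd'}  # prediction / regression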

> Loading datasets:

1. Create a 'ReadData' instance to read the dataset; see gene_dynamic_data.py for details.

  • Automatic file loading: point path at the root directory, create train and test folders under it, and name the files with _x / _y to distinguish inputs from outputs; supported file extensions are csv, txt, dat, xls, xlsx, and mat
  • Data preprocessing: set prep = ['prep_x', 'prep_y'] in the class initializer; available prep methods are 'st' (standardization), 'mm' (min-max normalization), and 'oh' (one-hot encoding)
  • Building dynamic data: set the sliding-window length 'dynamic' and the step size 'stride'

2. Before feeding the network, the dataset usually still needs to be converted into a DataLoader for mini-batch training; see load.py for details. A sketch of the whole flow follows below.
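A hedged sketch of the intended flow. The ReadData arguments mirror the bullet points above, but the exact constructor signature and the train_X/train_Y attribute names are assumptions; check gene_dynamic_data.py and load.py for the real API:

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical ReadData usage mirroring the options above; the real class
# lives in gene_dynamic_data.py and its signature may differ:
#   rd = ReadData(path='../data/my_dataset',  # root with train/test folders
#                 prep=['st', 'oh'],          # standardize x, one-hot encode y
#                 dynamic=40, stride=1)       # sliding window: length 40, step 1
#   train_x, train_y = rd.train_X, rd.train_Y # attribute names assumed

# Stand-in arrays so the DataLoader step below actually runs:
train_x = np.random.randn(1000, 40).astype(np.float32)
train_y = np.eye(10, dtype=np.float32)[np.random.randint(0, 10, 1000)]

# Wrap the arrays in a DataLoader for mini-batch training
# (load.py provides the package's own helper for this step):
train_set = TensorDataset(torch.as_tensor(train_x), torch.as_tensor(train_y))
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)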

> Quick CNN modeling:

List

A single list describes the CNN's structure. For example, [[3, 8], ['M', 2], ['R', [6, (6,6)], '|', [3, 1, 1]]] means:

1. 3@8×8: 3 convolution kernels of size 8×8
2. MaxPool: a pooling layer with a 2×2 kernel (by default stride = kernel_size)
3. Residual block: the main path has 6 kernels of size 6×6, and the shortcut has 3 kernels of size 1×1

The list syntax has many other flexible usages, e.g.:

  • '/2' means stride = 2
  • '+1' means padding = 1
  • '#2' means dilation = 2
  • '2*' repeats the next element twice

See README.md for more usages.
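Putting the modifiers together, a made-up illustration (the exact composition rules, including where '2*' sits relative to the element it repeats, are documented in README.md):

# Hypothetical structure list combining the modifiers above:
conv_struct = [[32, 3, '+1'],        # 32 kernels of 3x3, padding = 1
               ['M', 2],             # 2x2 max pooling
               '2*', [64, 3, '/2'],  # repeat the next element twice, stride = 2
               [64, 3, '#2']]        # 64 kernels of 3x3, dilation = 2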

DataFrame

Internally, the package converts the List into a DataFrame to further build the model. The DataFrame has 6 columns: 'Conv', '*', 'Pool', 'Res', 'Loop', 'Out', denoting the convolution structure (which may itself be a list), the number of convolution repeats, the pooling structure, the residual structure, the number of repeats of the whole block, and the output size, respectively.
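The table printed under "Results" below is exactly such a DataFrame; a hand-built pandas equivalent, with the cell contents copied from that console output:

import pandas as pd

# Hand-built copy of the 'Structure' table shown in the Results section.
structure = pd.DataFrame({
    'Conv': [[1, 3, 8], [3, 6, (6, 6)]],  # conv structure: in_ch, out_ch, kernel
    '*':    [1, 1],                       # conv repeat count
    'Pool': [['Max', 2, 2], '-'],         # pooling structure
    'Res':  ['-', '-'],                   # residual structure
    'Loop': [1, 1],                       # whole-block repeat count
    'Out':  [[3, 10, 10], [6, 5, 5]],     # output size after each block
})
print(structure)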

Parameter

Once the model is built, the network's trainable parameters and their shapes are printed automatically to the Console.
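An equivalent dump can be produced with standard PyTorch alone; note that state_dict() also covers the BatchNorm buffers (running_mean, running_var, num_batches_tracked) that appear in the console output below:

import torch.nn as nn

# Stand-in module for demonstration; any nn.Module works the same way.
model = nn.Sequential(nn.Conv2d(1, 3, 8), nn.BatchNorm2d(3))
for name, tensor in model.state_dict().items():
    print(f'{name}:\t{tuple(tensor.shape)}')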

> Training and testing:

Step 1. Define the model class

class CNN(Module, Conv_Module):
    def __init__(self, **kwargs):
        self.name = 'CNN'

        # Initialize both base classes with the same keyword arguments
        Module.__init__(self, **kwargs)
        Conv_Module.__init__(self, **kwargs)

        # Conv blocks built from 'conv_struct'; FC head built from 'struct'
        self.layers = self.Convolutional()
        self.fc = self.Sequential()
        self.opt()  # package helper: configure the optimizer / loss

    def forward(self, x, y=None):
        for layer in self.layers:
            x = layer(x)
        x = x.view(x.size(0), -1)  # flatten feature maps for the FC head
        x = self.fc(x)
        return x

Step 2. Instantiate the model

parameter = {'img_size': [1, 28, 28],   # input: 1 channel, 28x28 (MNIST)
             'conv_struct': [[3, 8], ['M', 2], [6, (6,6)]],  # structure list (see above)
             'conv_func': 'r',          # conv activation: ReLU
             'struct': [-1, 10],        # FC head: flattened size (inferred) -> 10 classes
             'hidden_func': ['g', 'a'], # hidden-layer activation codes
             'output_func': 'a',        # output activation: Affine
             'dropout': 0.0,            # no dropout
             'task': 'cls'}             # classification task

model = CNN(**parameter)

Step 3. Load the dataset

model.load_mnist('../data', 128)  # MNIST from '../data', batch size = 128

Step 4. Train and test the model

for epoch in range(1, 3 + 1):  # train for 3 epochs
    model.batch_training(epoch)
    model.test(epoch)

> Results:

model.result()

Console:

Structure:
             Conv  * Res         Pool Loop          Out
0       [1, 3, 8]  1   -  [Max, 2, 2]    1  [3, 10, 10]
1  [3, 6, (6, 6)]  1   -            -    1    [6, 5, 5]

CNN(
  (L): MSELoss()
  (layers): Sequential(
    (0): ConvBlock(
      (conv1): Conv2d(1, 3, kernel_size=(8, 8), stride=(1, 1))
      (bn1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act_layer): ReLU(inplace)
      (pool_layer): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): ConvBlock(
      (conv1): Conv2d(3, 6, kernel_size=(6, 6), stride=(1, 1))
      (bn1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act_layer): ReLU(inplace)
    )
  )
  (fc): Sequential(
    (0): Linear(in_features=150, out_features=10, bias=True)
    (1): Affine()
  )
)
CNN's Parameters(
  layers.0.conv1.weight:        torch.Size([3, 1, 8, 8])
  layers.0.conv1.bias:  torch.Size([3])
  layers.0.bn1.weight:  torch.Size([3])
  layers.0.bn1.bias:    torch.Size([3])
  layers.0.bn1.running_mean:    torch.Size([3])
  layers.0.bn1.running_var:     torch.Size([3])
  layers.0.bn1.num_batches_tracked:     torch.Size([])
  layers.1.conv1.weight:        torch.Size([6, 3, 6, 6])
  layers.1.conv1.bias:  torch.Size([6])
  layers.1.bn1.weight:  torch.Size([6])
  layers.1.bn1.bias:    torch.Size([6])
  layers.1.bn1.running_mean:    torch.Size([6])
  layers.1.bn1.running_var:     torch.Size([6])
  layers.1.bn1.num_batches_tracked:     torch.Size([])
  fc.0.weight:  torch.Size([10, 150])
  fc.0.bias:    torch.Size([10])
)
Epoch: 1 - 469/469 | loss = 0.0259
    >>> Train: loss = 0.0375   accuracy = 0.9340   
    >>> Test: loss = 0.0225   accuracy = 0.9389   
Epoch: 2 - 469/469 | loss = 0.0200
    >>> Train: loss = 0.0215   accuracy = 0.9484   
    >>> Test: loss = 0.0191   accuracy = 0.9525   
Epoch: 3 - 469/469 | loss = 0.0204
    >>> Train: loss = 0.0192   accuracy = 0.9537   
    >>> Test: loss = 0.0179   accuracy = 0.9571   

My blog

ResearchGate, Zhihu, CSDN. QQ group: 640571839

Paper

We hope you will support our work; questions and discussion are welcome!

[1] Z. Pan, H. Chen, Y. Wang, B. Huang, and W. Gui, "A new perspective on AE- and VAE-based process monitoring," TechRxiv, Apr. 2022, doi: 10.36227/techrxiv.19617534.
[2] Z. Pan, Y. Wang, K. Wang, G. Ran, H. Chen, and W. Gui, "Layer-wise contribution-filtered propagation for deep learning-based fault isolation," Int. J. Robust Nonlinear Control, Jul. 2022, doi: 10.1002/rnc.6328.
[3] Z. Pan, Y. Wang, K. Wang, H. Chen, C. Yang, and W. Gui, "Imputation of missing values in time series using an adaptive-learned median-filled deep autoencoder," IEEE Trans. Cybern., 2022, doi: 10.1109/TCYB.2022.3167995.
[4] Y. Wang, Z. Pan, X. Yuan, C. Yang, and W. Gui, "A novel deep learning based fault diagnosis approach for chemical process with extended deep belief network," ISA Trans., vol. 96, pp. 457–467, 2020.
[5] Z. Pan, Y. Wang, X. Yuan, C. Yang, and W. Gui, "A classification-driven neuron-grouped SAE for feature representation and its application to fault classification in chemical processes," Knowl.-Based Syst., vol. 230, p. 107350, 2021.
[6] H. Chen, B. Jiang, S. X. Ding, and B. Huang, "Data-driven fault diagnosis for traction systems in high-speed trains: A survey, challenges, and perspectives," IEEE Trans. Intell. Transp. Syst., 2020, doi: 10.1109/TITS.2020.3029946.
[7] H. Chen and B. Jiang, "A review of fault detection and diagnosis for the traction system in high-speed trains," IEEE Trans. Intell. Transp. Syst., vol. 21, no. 2, pp. 450–465, Feb. 2020.
