Classification with Runge-Kutta networks and feature space augmentation

  • * Corresponding author: Axel Kröner
The second author is supported by DAAD project 57570343
  • In this paper we combine an approach based on Runge-Kutta nets considered in [Benning et al., J. Comput. Dynamics, 6, 2019] with a technique for augmenting the input space from [Dupont et al., NeurIPS, 2019] to obtain network architectures that show improved numerical performance for deep neural networks in point and image classification problems. The approach is illustrated with several examples implemented in PyTorch.

    Mathematics Subject Classification: 65L06, 68T07.
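
The combination described in the abstract can be illustrated by a minimal sketch in plain Python rather than PyTorch. It treats each layer as one explicit Runge-Kutta step for the dynamics $ \dot{y} = \tanh(Ky + b) $ and augments the input by appending zeros. This is not the authors' implementation: the function names (`layer_rhs`, `rk_step`, `augment`, `forward`), the zero-padding, and the choice to keep the parameters $ K, b $ fixed across the stages of one layer are assumptions made for illustration, and the output layer and training are omitted.

```python
import math

# Butcher tableaux as (A, b): stage coefficients and output weights.
# The nodes c are omitted because the sketch uses autonomous layer dynamics.
EULER = ([[0.0]], [1.0])
RK4 = (
    [[0.0, 0.0, 0.0, 0.0],
     [0.5, 0.0, 0.0, 0.0],
     [0.0, 0.5, 0.0, 0.0],
     [0.0, 0.0, 1.0, 0.0]],
    [1 / 6, 1 / 3, 1 / 3, 1 / 6],
)


def layer_rhs(y, K, bias):
    """Right-hand side f(y) = tanh(K y + bias) of the layer dynamics."""
    return [
        math.tanh(sum(k_ij * y_j for k_ij, y_j in zip(row, y)) + b_i)
        for row, b_i in zip(K, bias)
    ]


def rk_step(y, h, K, bias, tableau):
    """One explicit Runge-Kutta step of size h, i.e. one network layer."""
    A, weights = tableau
    stages = []
    for a_row in A:
        # Explicit scheme: stage i only uses the already computed stages j < i.
        y_stage = [
            y_m + h * sum(a_row[j] * stages[j][m] for j in range(len(stages)))
            for m, y_m in enumerate(y)
        ]
        stages.append(layer_rhs(y_stage, K, bias))
    return [
        y_m + h * sum(w * k[m] for w, k in zip(weights, stages))
        for m, y_m in enumerate(y)
    ]


def augment(x, width):
    """Embed the input in a wider feature space by appending zeros (assumed)."""
    return list(x) + [0.0] * (width - len(x))


def forward(x, params, h, tableau, width):
    """Propagate one sample through all layers; params holds (K, bias) per layer."""
    y = augment(x, width)
    for K, bias in params:
        y = rk_step(y, h, K, bias, tableau)
    return y


# Hypothetical usage: a 2D point lifted to width 3 and passed through 100 layers
# that share one arbitrary, untrained parameter set.
K = [[0.0, -1.0, 0.5], [1.0, 0.0, 0.0], [0.3, 0.0, -0.2]]
bias = [0.1, 0.0, -0.1]
features = forward([0.5, -0.25], [(K, bias)] * 100, 0.01, RK4, width=3)
```

Passing `EULER` instead of `RK4` gives the forward Euler variant with the same layer dynamics, and setting `width` equal to the input dimension switches the augmentation off.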

  • Figure 1.  Butcher tableaus: (from left to right) general form, forward Euler and classic RK4.

    Figure 2.  Two-dimensional datasets for binary point classification with 1500 samples each: donut 1D and donut 2D (top), squares and spiral (bottom).

    Figure 3.  Classification of donut_2D with RK4Net of width $ \hat{d} = 2 $ corresponding to the NODE-approach (top) and $ \hat{d} = 3 $, i.e. with space augmentation characterizing the ANODE-approach (bottom), and of the same depth $ L = 100 $ and $ \tanh $ activation. The plots show (from left to right) the trajectories of the features starting at the small dot and terminating at the large dot, their final transformation in the output layer and the resulting prediction with coloured background according to the network's classification.

    Figure 4.  Classification of squares with RK4Net of width $ \hat{d} = 2 $ corresponding to the NODE-approach (top) and $ \hat{d} = 3 $, i.e. with space augmentation characterizing the ANODE-approach (bottom), and of the same depth $ L = 100 $ and $ \tanh $ activation. The plots show (from left to right) the trajectories of the features starting at the small dot and terminating at the large dot, their final transformation in the output layer and the resulting prediction with coloured background according to the network's classification.

    Figure 5.  Feature transformation of spiral with StandardNet (top) and RK4Net (bottom) of width $ \hat{d} = 16 $, depth $ L = 20 $ and $ \tanh $ activation. (From left to right) features in input layer, hidden layers and output layer.

    Figure 6.  Prediction of donut_1D with RK4Net of width $ \hat{d} = 16 $, depth $ L = 20 $ and $ \tanh $ activation.

    Figure 7.  Accuracy (left) and cost (right) over the course of epochs on donut_1D with RK4Net of width $ \hat{d} = 16 $ and depth $ L = 20 $. Solid lines represent metrics on validation and dotted lines on training data.

    Figure 8.  Donut and squares datasets of different dimensionality and with varying number of classes used for comparing performance of networks between binary and multiclass classification (first column), as well as 2D and 3D input space (second and third column).

    Figure 9.  Repetitions with random initializations for RK4Net with width $ \hat{d} = 16 $, depth $ L = 100 $ and $ \tanh $ activation, on donut 2D & 6C. The plots show (upper row) the feature transformation in the output layer reduced by PCA to 3D, and (lower row) the resulting prediction underlaid with a coloured background according to the network's classification.

    Figure 10.  Classification of donut 2D & 6C with network width $ \hat{d} = 16 $, depth $ L = 5 $ for StandardNet and $ L = 100 $ for EulerNet and RK4Net, and $ \tanh $ activation. The plots show (from left to right) the feature transformation in the output layer reduced by PCA to 3D and 2D, and the resulting prediction. Two-dimensional plots are underlaid with a coloured background according to the network's classification.

    Figure 11.  Classification of squares 2D & 4C with network width $ \hat{d} = 16 $, depth $ L = 5 $ for StandardNet and $ L = 100 $ for EulerNet and RK4Net, and $ \tanh $ activation. The plots show (from left to right) the feature transformation in the output layer reduced by PCA to 3D and 2D, and the resulting prediction. Two-dimensional plots are underlaid with a coloured background according to the network's classification.

    Figure 12.  Validation accuracy (left) and cost (right) over the course of epochs on donut 3D & 6C with network width $ \hat{d} = 16 $, depth $ L = 5 $ for StandardNet and $ L = 100 $ for EulerNet and RK4Net, and $ \tanh $ activation. Solid line represents the mean and shaded area the standard deviation over repetitions.

    Figure 13.  Validation accuracy (left) and cost (right) over the course of epochs on squares 3D & 4C with network width $ \hat{d} = 16 $, depth $ L = 5 $ for StandardNet and $ L = 100 $ for EulerNet and RK4Net, and $ \tanh $ activation. Solid line represents the mean and shaded area the standard deviation over repetitions.

    Figure 14.  Exemplary images of MNIST with true label and prediction produced by RK4Net with width $ \hat{d} = 30^2 $, depth $ L = 100 $ and $ \tanh $ activation.

    Figure 15.  Exemplary images of Fashion-MNIST with true label and prediction produced by RK4Net with width $ \hat{d} = 30^2 $, depth $ L = 100 $ and $ \tanh $ activation.

    Figure 16.  Feature transformation in the output layer of StandardNet (left) and RK4Net (right) of Fashion-MNIST images reduced by PCA to 3D. Each color represents one article class.

    Figure 17.  Accuracy (left) and cost (right) over the course of epochs on MNIST with network width $ \hat{d} = 30^2 $, depth $ L = 5 $ for StandardNet and $ L = 100 $ for EulerNet and RK4Net, and $ \tanh $ activation. Solid lines represent metrics on validation and dotted lines on training data.

    Figure 18.  Accuracy (left) and cost (right) over the course of epochs on Fashion-MNIST with network width $ \hat{d} = 30^2 $, depth $ L = 5 $ for StandardNet and $ L = 100 $ for EulerNet and RK4Net, and $ \tanh $ activation. Solid lines represent metrics on validation and dotted lines on training data.

    Table 1.  Mean of training (upper row) and validation (lower row) accuracy (%) over four repetitions on spiral with network width $ \hat{d} = 16 $ and $ \tanh $ activation.

    depth L                    1      3      5      10     20     40     100
    StandardNet (training)     92.73  92.87  98.12  97.52  67.62  51.08  50.67
    StandardNet (validation)   91.88  92.50  98.10  97.45  66.87  48.92  49.33
    RK4Net (training)          75.60  91.42  97.90  99.77  99.93  99.73  99.95
    RK4Net (validation)        75.12  90.68  97.33  99.47  99.70  99.50  99.75

    Table 2.  Mean of training (upper row) and validation (lower row) cost ($ \times 10^{-1} $) over four repetitions on spiral with network width $ \hat{d} = 16 $ and $ \tanh $ activation.

    depth L                    1      3      5      10     20     40     100
    StandardNet (training)     2.23   1.38   0.66   0.77   6.09   6.93   6.93
    StandardNet (validation)   2.33   1.53   0.67   0.77   6.13   6.94   6.93
    RK4Net (training)          4.32   2.68   0.98   0.16   0.04   0.10   0.01
    RK4Net (validation)        4.39   2.69   1.06   0.28   0.13   0.12   0.12

    Table 3.  Variability of accuracy (%) and cost ($ \times 10^{-1} $) over four repetitions for RK4Net with width $ \hat{d} = 16 $, depth $ L = 100 $ and $ \tanh $ activation, on donut 2D & 6C.

    training accuracy validation accuracy training cost validation cost
    mean 77.13 74.92 5.13 5.59
    standard deviation 0.76 0.89 0.08 0.16

    Table 4.  Mean of validation accuracy (%, upper row) and cost ($ \times 10^{-1} $, lower row) over four repetitions with network width $ \hat{d} = 16 $, depth $ L = 5 $ for StandardNet and $ L = 100 $ for EulerNet and RK4Net, and $ \tanh $ activation.

                             donut 3D & 2C   donut 3D & 3C   donut 2D & 6C   donut 3D & 6C   squares 2D & 4C   squares 3D & 4C
    StandardNet (accuracy)   92.37           87.75           75.12           73.00           94.12             89.68
    StandardNet (cost)       1.71            2.85            5.60            5.86            1.57              3.03
    EulerNet (accuracy)      91.88           88.30           74.87           74.63           93.35             89.48
    EulerNet (cost)          1.84            2.75            5.56            5.67            1.66              2.71
    RK4Net (accuracy)        92.73           87.13           74.92           74.88           93.20             89.37
    RK4Net (cost)            1.72            2.95            5.59            5.73            1.64              2.81

    Table 5.  Mean and standard deviation of accuracy (%) and cost ($ \times 10^{-1} $) over four repetitions for non-augmented (upper row) and augmented (lower row) RK4Net with depth $ L = 100 $ and $ \tanh $ activation on MNIST.

    width $\hat{d}$ training accuracy validation accuracy training cost validation cost
    $28^2$ $97.70 \pm 2.80$ $87.27 \pm 2.93$ $0.78 \pm 0.95$ $7.71 \pm 1.62$
    $30^2$ $99.77 \pm 0.40$ $90.40 \pm 1.08$ $0.10 \pm 0.17$ $5.36 \pm 0.46$

    Table 6.  Mean and standard deviation of validation accuracy (%, upper row) and cost ($ \times 10^{-1} $, lower row) over four repetitions with network width $ \hat{d} = 30^2 $, depth $ L = 5 $ for StandardNet and $ L = 100 $ for EulerNet and RK4Net, and $ \tanh $ activation.

                             MNIST               Fashion-MNIST
    StandardNet (accuracy)   $85.67 \pm 0.78$    $61.23 \pm 6.00$
    StandardNet (cost)       $8.95 \pm 0.92$     $11.41 \pm 1.37$
    EulerNet (accuracy)      $90.98 \pm 0.48$    $77.62 \pm 2.57$
    EulerNet (cost)          $5.71 \pm 0.36$     $9.52 \pm 1.87$
    RK4Net (accuracy)        $90.40 \pm 1.08$    $79.13 \pm 1.57$
    RK4Net (cost)            $5.36 \pm 0.46$     $8.24 \pm 0.60$
  • [1] M. Benning, E. Celledoni, M. J. Ehrhardt, B. Owren and C.-B. Schönlieb, Deep learning as optimal control problems: Models and numerical methods, J. Comput. Dyn., 6 (2019), 171-198.  doi: 10.3934/jcd.2019009.
    [2] E. Celledoni, M. J. Ehrhardt, C. Etmann, R. I. McLachlan, B. Owren, C.-B. Schönlieb and F. Sherry, Structure-preserving deep learning, European J. Appl. Math., 32 (2021), 888-936.  doi: 10.1017/S0956792521000139.
    [3] R. T. Q. Chen, Y. Rubanova, J. Bettencourt and D. K. Duvenaud, Neural Ordinary Differential Equations, Advances in Neural Information Processing Systems, 31, Curran Associates, Inc., 2018.
    [4] E. Dupont, A. Doucet and Y. W. Teh, Augmented neural ODEs, Adv. Neural Inf. Process. Syst., 32 (2019).
    [5] W. E, A proposal on machine learning via dynamical systems, Commun. Math. Stat., 5 (2017), 1-11.  doi: 10.1007/s40304-017-0103-z.
    [6] E. Giesecke, Augmented-RK-Nets, 2021. Available from: https://github.com/ElisaGiesecke/augmented-RK-Nets.
    [7] I. Goodfellow, Y. Bengio and A. Courville, Deep Learning, Adaptive Computation and Machine Learning, MIT Press, Cambridge, MA, 2016.
    [8] W. W. Hager, Runge-Kutta methods in optimal control and the transformed adjoint system, Numer. Math., 87 (2000), 247-282.  doi: 10.1007/s002110000178.
    [9] K. He, X. Zhang, S. Ren and J. Sun, Deep residual learning for image recognition, IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Las Vegas, NV, 2016. doi: 10.1109/CVPR.2016.90.
    [10] C. F. Higham and D. J. Higham, Deep learning: An introduction for applied mathematicians, SIAM Rev., 61 (2019), 860-891.  doi: 10.1137/18M1165748.
    [11] M. Raissi, P. Perdikaris and G. E. Karniadakis, Physics-informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear partial differential equations, J. Comput. Phys., 378 (2019), 686-707.  doi: 10.1016/j.jcp.2018.10.045.
    [12] D. Ruiz-Balet and E. Zuazua, Neural ODE control for classification, approximation and transport, preprint, arXiv: 2104.05278.
    [13] J. M. Sanz-Serna, Symplectic Runge-Kutta and related methods: Recent results, Phys. D, 60 (1992), 293-302.  doi: 10.1016/0167-2789(92)90245-I.
    [14] J. M. Sanz-Serna, Symplectic Runge-Kutta schemes for adjoint equations, automatic differentiation, optimal control, and more, SIAM Rev., 58 (2016), 3-33.  doi: 10.1137/151002769.