Article Contents

# Classification with Runge-Kutta networks and feature space augmentation

• * Corresponding author: Axel Kröner
The second author is supported by DAAD project 57570343.
• In this paper we combine the Runge-Kutta network approach of [Benning et al., J. Comput. Dyn., 6, 2019] with the technique of augmenting the input space introduced in [Dupont et al., NeurIPS, 2019] to obtain network architectures with better numerical performance for deep neural networks in point and image classification problems. The approach is illustrated with several examples implemented in PyTorch.
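The augmentation idea from [Dupont et al., NeurIPS, 2019] amounts to embedding each $d$-dimensional input into a wider $\hat{d}$-dimensional feature space by appending zeros, so the learned flow has extra dimensions in which trajectories can move past each other. A minimal sketch in plain Python (the helper name `augment` is our own, not from the paper's code):

```python
# Feature-space augmentation as in the ANODE approach: lift each
# d-dimensional input to width d_hat > d by zero-padding, giving the
# network's flow additional room. Illustrative sketch only.

def augment(sample, d_hat):
    """Zero-pad a feature vector to width d_hat."""
    d = len(sample)
    if d_hat < d:
        raise ValueError("augmented width must be at least the input width")
    return list(sample) + [0.0] * (d_hat - d)

# A 2D point, e.g. from the donut dataset, lifted to width 3:
point = [0.8, -0.6]
augmented = augment(point, 3)   # [0.8, -0.6, 0.0]
```

In the paper's experiments this corresponds to running the same network with $\hat{d} = 3$ instead of $\hat{d} = 2$ on 2D data (Figures 3 and 4), or $\hat{d} = 30^2$ instead of $28^2$ on MNIST (Table 5).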

Mathematics Subject Classification: 65L06, 68T07.


• Figure 1.  Butcher tableaux: (from left to right) general form, forward Euler, and classic RK4.
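The classic RK4 tableau in Figure 1 defines the update each RK4Net layer applies, $x \mapsto x + h\sum_i b_i k_i$, with the vector field $f$ replaced by a learned layer function. A self-contained sketch of the underlying scalar RK4 step (function name is our own; the paper's PyTorch implementation is in [6]):

```python
import math

# One explicit Runge-Kutta step built from the classic RK4 Butcher
# tableau: stages k1..k4, weights b = (1/6, 2/6, 2/6, 1/6).
def rk4_step(f, x, h):
    """Classic fourth-order Runge-Kutta step for x' = f(x)."""
    k1 = f(x)
    k2 = f(x + 0.5 * h * k1)
    k3 = f(x + 0.5 * h * k2)
    k4 = f(x + h * k3)
    return x + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)

# Sanity check on x' = x with exact solution e^t:
x = 1.0
for _ in range(10):        # 10 steps of size 0.1 over [0, 1]
    x = rk4_step(lambda y: y, x, 0.1)
# x now approximates math.e to roughly O(h^4) accuracy
```

Setting $k_2 = k_3 = k_4 = 0$ weights aside, replacing the update by $x + h\,k_1$ recovers the forward Euler tableau, i.e. the EulerNet/ResNet-style layer compared against RK4Net in the tables below.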

Figure 2.  Two dimensional datasets for binary point classification with 1500 samples each: donut 1D and donut 2D (top), squares and spiral (bottom).

Figure 3.  Classification of donut_2D with RK4Net of width $\hat{d} = 2$, corresponding to the NODE approach (top), and $\hat{d} = 3$, i.e., with space augmentation characterizing the ANODE approach (bottom), both of depth $L = 100$ with $\tanh$ activation. The plots show (from left to right) the trajectories of the features, starting at the small dot and terminating at the large dot, their final transformation in the output layer, and the resulting prediction with background coloured according to the network's classification.

Figure 4.  Classification of squares with RK4Net of width $\hat{d} = 2$, corresponding to the NODE approach (top), and $\hat{d} = 3$, i.e., with space augmentation characterizing the ANODE approach (bottom), both of depth $L = 100$ with $\tanh$ activation. The plots show (from left to right) the trajectories of the features, starting at the small dot and terminating at the large dot, their final transformation in the output layer, and the resulting prediction with background coloured according to the network's classification.

Figure 5.  Feature transformation of spiral with StandardNet (top) and RK4Net (bottom) of width $\hat{d} = 16$, depth $L = 20$ and $\tanh$ activation. (From left to right) features in input layer, hidden layers and output layer.

Figure 6.  Prediction of donut_1D with RK4Net of width $\hat{d} = 16$, depth $L = 20$ and $\tanh$ activation.

Figure 7.  Accuracy (left) and cost (right) over the course of epochs on donut_1D with RK4Net of width $\hat{d} = 16$ and depth $L = 20$. Solid lines represent metrics on validation and dotted lines on training data.

Figure 8.  Donut and squares datasets of different dimensionality and with varying number of classes used for comparing performance of networks between binary and multiclass classification (first column), as well as 2D and 3D input space (second and third column).

Figure 9.  Repetitions with random initializations for RK4Net with width $\hat{d} = 16$, depth $L = 100$ and $\tanh$ activation, on donut 2D & 6C. The plots show (upper row) the feature transformation in the output layer reduced by PCA to 3D, and (lower row) the resulting prediction underlaid with a coloured background according to the network's classification.

Figure 10.  Classification of donut 2D & 6C with network width $\hat{d} = 16$, depth $L = 5$ for StandardNet and $L = 100$ for EulerNet and RK4Net, and $\tanh$ activation. The plots show (from left to right) the feature transformation in the output layer reduced by PCA to 3D and 2D, and the resulting prediction. Two dimensional plots are underlaid with a coloured background according to the network's classification.

Figure 11.  Classification of squares 2D & 4C with network width $\hat{d} = 16$, depth $L = 5$ for StandardNet and $L = 100$ for EulerNet and RK4Net, and $\tanh$ activation. The plots show (from left to right) the feature transformation in the output layer reduced by PCA to 3D and 2D, and the resulting prediction. Two dimensional plots are underlaid with a coloured background according to the network's classification.

Figure 12.  Validation accuracy (left) and cost (right) over the course of epochs on donut 3D & 6C with network width $\hat{d} = 16$, depth $L = 5$ for StandardNet and $L = 100$ for EulerNet and RK4Net, and $\tanh$ activation. Solid line represents the mean and shaded area the standard deviation over repetitions.

Figure 13.  Validation accuracy (left) and cost (right) over the course of epochs on squares 3D & 4C with network width $\hat{d} = 16$, depth $L = 5$ for StandardNet and $L = 100$ for EulerNet and RK4Net, and $\tanh$ activation. Solid line represents the mean and shaded area the standard deviation over repetitions.

Figure 14.  Example MNIST images with true label and the prediction produced by RK4Net with width $\hat{d} = 30^2$, depth $L = 100$ and $\tanh$ activation.

Figure 15.  Example Fashion-MNIST images with true label and the prediction produced by RK4Net with width $\hat{d} = 30^2$, depth $L = 100$ and $\tanh$ activation.

Figure 16.  Feature transformation in the output layer of StandardNet (left) and RK4Net (right) of Fashion-MNIST images reduced by PCA to 3D. Each color represents one article class.

Figure 17.  Accuracy (left) and cost (right) over the course of epochs on MNIST with network width $\hat{d} = 30^2$, depth $L = 5$ for StandardNet and $L = 100$ for EulerNet and RK4Net, and $\tanh$ activation. Solid lines represent metrics on validation and dotted lines on training data.

Figure 18.  Accuracy (left) and cost (right) over the course of epochs on Fashion-MNIST with network width $\hat{d} = 30^2$, depth $L = 5$ for StandardNet and $L = 100$ for EulerNet and RK4Net, and $\tanh$ activation. Solid lines represent metrics on validation and dotted lines on training data.

Table 1.  Mean of training (upper row) and validation (lower row) accuracy (%) over four repetitions on spiral with network width $\hat{d} = 16$ and $\tanh$ activation.

| depth $L$ | 1 | 3 | 5 | 10 | 20 | 40 | 100 |
|---|---|---|---|---|---|---|---|
| StandardNet | 92.73 / 91.88 | 92.87 / 92.50 | 98.12 / 98.10 | 97.52 / 97.45 | 67.62 / 66.87 | 51.08 / 48.92 | 50.67 / 49.33 |
| RK4Net | 75.60 / 75.12 | 91.42 / 90.68 | 97.90 / 97.33 | 99.77 / 99.47 | 99.93 / 99.70 | 99.73 / 99.50 | 99.95 / 99.75 |

Each cell reads training / validation accuracy.

Table 2.  Mean of training (upper row) and validation (lower row) cost ($\times 10^{-1}$) over four repetitions on spiral with network width $\hat{d} = 16$ and $\tanh$ activation.

| depth $L$ | 1 | 3 | 5 | 10 | 20 | 40 | 100 |
|---|---|---|---|---|---|---|---|
| StandardNet | 2.23 / 2.33 | 1.38 / 1.53 | 0.66 / 0.67 | 0.77 / 0.77 | 6.09 / 6.13 | 6.93 / 6.94 | 6.93 / 6.93 |
| RK4Net | 4.32 / 4.39 | 2.68 / 2.69 | 0.98 / 1.06 | 0.16 / 0.28 | 0.04 / 0.13 | 0.10 / 0.12 | 0.01 / 0.12 |

Each cell reads training / validation cost.

Table 3.  Variability of accuracy (%) and cost ($\times 10^{-1}$) over four repetitions for RK4Net with width $\hat{d} = 16$, depth $L = 100$ and $\tanh$ activation, on donut 2D & 6C.

|  | training accuracy | validation accuracy | training cost | validation cost |
|---|---|---|---|---|
| mean | 77.13 | 74.92 | 5.13 | 5.59 |
| standard deviation | 0.76 | 0.89 | 0.08 | 0.16 |

Table 4.  Mean of validation accuracy (%, upper row) and cost ($\times 10^{-1}$, lower row) over four repetitions with network width $\hat{d} = 16$, depth $L = 5$ for StandardNet and $L = 100$ for EulerNet and RK4Net, and $\tanh$ activation.

|  | donut 3D & 2C | donut 3D & 3C | donut 2D & 6C | donut 3D & 6C | squares 2D & 4C | squares 3D & 4C |
|---|---|---|---|---|---|---|
| StandardNet | 92.37 / 1.71 | 87.75 / 2.85 | 75.12 / 5.60 | 73.00 / 5.86 | 94.12 / 1.57 | 89.68 / 3.03 |
| EulerNet | 91.88 / 1.84 | 88.30 / 2.75 | 74.87 / 5.56 | 74.63 / 5.67 | 93.35 / 1.66 | 89.48 / 2.71 |
| RK4Net | 92.73 / 1.72 | 87.13 / 2.95 | 74.92 / 5.59 | 74.88 / 5.73 | 93.20 / 1.64 | 89.37 / 2.81 |

Each cell reads validation accuracy / validation cost.

Table 5.  Mean and standard deviation of accuracy (%) and cost ($\times 10^{-1}$) over four repetitions for non-augmented (upper row) and augmented (lower row) RK4Net with depth $L = 100$ and $\tanh$ activation on MNIST.

| width $\hat{d}$ | training accuracy | validation accuracy | training cost | validation cost |
|---|---|---|---|---|
| $28^2$ (non-augmented) | $97.70 \pm 2.80$ | $87.27 \pm 2.93$ | $0.78 \pm 0.95$ | $7.71 \pm 1.62$ |
| $30^2$ (augmented) | $99.77 \pm 0.40$ | $90.40 \pm 1.08$ | $0.10 \pm 0.17$ | $5.36 \pm 0.46$ |

Table 6.  Mean and standard deviation of validation accuracy (%, upper row) and cost ($\times 10^{-1}$, lower row) over four repetitions with network width $\hat{d} = 30^2$, depth $L = 5$ for StandardNet and $L = 100$ for EulerNet and RK4Net, and $\tanh$ activation.

|  | MNIST | Fashion-MNIST |
|---|---|---|
| StandardNet | $85.67 \pm 0.78$ / $8.95 \pm 0.92$ | $61.23 \pm 6.00$ / $11.41 \pm 1.37$ |
| EulerNet | $90.98 \pm 0.48$ / $5.71 \pm 0.36$ | $77.62 \pm 2.57$ / $9.52 \pm 1.87$ |
| RK4Net | $90.40 \pm 1.08$ / $5.36 \pm 0.46$ | $79.13 \pm 1.57$ / $8.24 \pm 0.60$ |

Each cell reads validation accuracy / validation cost.
[1] M. Benning, E. Celledoni, M. J. Ehrhardt, B. Owren and C.-B. Schönlieb, Deep learning as optimal control problems: Models and numerical methods, J. Comput. Dyn., 6 (2019), 171-198. doi: 10.3934/jcd.2019009.
[2] E. Celledoni, M. J. Ehrhardt, C. Etmann, R. I. McLachlan, B. Owren, C.-B. Schönlieb and F. Sherry, Structure-preserving deep learning, European J. Appl. Math., 32 (2021), 888-936. doi: 10.1017/S0956792521000139.
[3] R. T. Q. Chen, Y. Rubanova, J. Bettencourt and D. K. Duvenaud, Neural ordinary differential equations, Advances in Neural Information Processing Systems, 31, Curran Associates, Inc., 2018.
[4] E. Dupont, A. Doucet and Y. W. Teh, Augmented neural ODEs, Adv. Neural Inf. Process. Syst., 32 (2019).
[5] W. E, A proposal on machine learning via dynamical systems, Commun. Math. Stat., 5 (2017), 1-11. doi: 10.1007/s40304-017-0103-z.
[6] E. Giesecke, Augmented-RK-Nets, 2021. Available from: https://github.com/ElisaGiesecke/augmented-RK-Nets.
[7] I. Goodfellow, Y. Bengio and A. Courville, Deep Learning, Adaptive Computation and Machine Learning, MIT Press, Cambridge, MA, 2016.
[8] W. W. Hager, Runge-Kutta methods in optimal control and the transformed adjoint system, Numer. Math., 87 (2000), 247-282. doi: 10.1007/s002110000178.
[9] K. He, X. Zhang, S. Ren and J. Sun, Deep residual learning for image recognition, IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Las Vegas, NV, 2016. doi: 10.1109/CVPR.2016.90.
[10] C. F. Higham and D. J. Higham, Deep learning: An introduction for applied mathematicians, SIAM Rev., 61 (2019), 860-891. doi: 10.1137/18M1165748.
[11] M. Raissi, P. Perdikaris and G. E. Karniadakis, Physics-informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear partial differential equations, J. Comput. Phys., 378 (2019), 686-707. doi: 10.1016/j.jcp.2018.10.045.
[12] D. Ruiz-Balet and E. Zuazua, Neural ODE control for classification, approximation and transport, preprint, arXiv:2104.05278.
[13] J. M. Sanz-Serna, Symplectic Runge-Kutta and related methods: Recent results, Phys. D, 60 (1992), 293-302. doi: 10.1016/0167-2789(92)90245-I.
[14] J. M. Sanz-Serna, Symplectic Runge-Kutta schemes for adjoint equations, automatic differentiation, optimal control, and more, SIAM Rev., 58 (2016), 3-33. doi: 10.1137/151002769.
