Controlling stochastic gradient descent using stochastic approximation for robust distributed optimization

  • *Corresponding author: Adit Jain

Dedicated to Professor George Yin on the occasion of his 70th birthday. This research was funded by the Army Research Office under grant W911NF-24-1-0083 and the National Science Foundation under grant CCF-2312198.

  • This paper deals with the problem of controlling stochastic gradient descent performed by multiple learners, each of which aims to estimate $ \arg\min f $ using noisy gradients obtained by querying a stochastic oracle. Each query incurs a learning cost, and the noisy gradient responses have varying degrees of noise variance, whose bound is assumed to evolve in a Markovian fashion. For a single learner, the decision problem is to choose when to query the oracle such that the learning cost is minimized. A constrained Markov decision process (CMDP) is formulated to solve this decision problem. Structural results are proven for the optimal policy of the CMDP, which is shown to be a threshold policy that is decreasing in the queue state. For multiple learners, a constrained switching control game is formulated for scheduling and controlling $ N $ learners that query the same oracle one at a time. The structural results are extended to the optimal policy achieving the Nash equilibrium. These structural results are used to propose a stochastic approximation algorithm that searches for the optimal policy; it tracks the parameters of the optimal policy using a sigmoidal approximation and does not require knowledge of the underlying transition probabilities. The paper also briefly discusses applications in federated learning and numerically demonstrates the convergence properties of the proposed algorithm.

    Mathematics Subject Classification: Primary: 62L20, 91A15, 90C40; Secondary: 60J20.
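To make the single-learner setting concrete, the following Python sketch runs SGD against a simulated stochastic oracle whose noise level is modulated by a two-state Markov chain. Everything specific here is a hypothetical choice for illustration: the quadratic objective with minimizer `X_STAR`, the noise levels `NOISE_STD`, the transition matrix `P_ORACLE`, and the query rule (query only when the current noise-variance bound is within the tolerance $ \sigma $). The sketch assumes the learner observes the oracle state before deciding, omits the learner's queue, and uses a simple heuristic rule rather than the optimal CMDP policy.

```python
import numpy as np

# Hypothetical objective f(x) = 0.5 * ||x - X_STAR||^2, so grad f(x) = x - X_STAR.
X_STAR = np.array([1.0, -2.0])

# Hypothetical two-state oracle: state 0 is low-noise, state 1 is high-noise.
NOISE_STD = np.array([0.1, 1.0])
P_ORACLE = np.array([[0.9, 0.1],   # Markov transition matrix of the oracle state o_k
                     [0.3, 0.7]])


def oracle(q, o, rng):
    """Return a noisy gradient r_k of f at query q_k and the noise-variance bound in state o."""
    r = (q - X_STAR) + NOISE_STD[o] * rng.standard_normal(q.shape)
    return r, NOISE_STD[o] ** 2


def run_sgd(num_steps=2000, mu=0.05, tolerance=0.5, seed=0):
    """SGD in which the learner queries only when the oracle's variance bound is within tolerance."""
    rng = np.random.default_rng(seed)
    x_hat = np.zeros(2)  # learner's estimate of the minimizer
    o = 0                # initial oracle state
    num_queries = 0      # each query would incur the learning cost l
    for _ in range(num_steps):
        # Heuristic query rule (an assumption, not the optimal CMDP policy).
        if NOISE_STD[o] ** 2 <= tolerance:
            r, _ = oracle(x_hat, o, rng)
            x_hat = x_hat - mu * r  # SGD step with step size mu
            num_queries += 1
        # The oracle state evolves as a Markov chain, independently of the action.
        o = rng.choice(2, p=P_ORACLE[o])
    return x_hat, num_queries
```

Calling `x_hat, n = run_sgd()` should return an estimate close to `X_STAR`. Since the stationary probability of the low-noise state under this `P_ORACLE` is 0.75, the learner queries in roughly three quarters of the steps in expectation; the trade-off between such query costs and estimation quality is what the CMDP formalizes.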


  • Figure 1.  Pictorial representation of the two system models considered in the paper: (a) shows a single learner with queue length $ b_k $ querying a stochastic oracle in state $ o_k $. The learner sends the query $ q_k $ and receives a noisy gradient $ r_k $ of $ f $ at the query point, together with a bound $ I^2_k $ on the noise variance. (b) illustrates a multi-learner setting in which multiple learners query the same oracle, but only one learner can be scheduled to query the oracle at a time. The querying protocol is similar to (a), but each learner has its own queue, and the oracle may be in a different state with respect to each learner.

    Figure 2.  Approximate threshold parameters (denoted by $ \theta $) of the constant-step-size SASPS algorithm estimating the true parameters (denoted by $ \phi $) for different oracle states.

    Figure 3.  Approximate threshold parameters (denoted by $ \theta $) of the constant-step-size SASPS algorithm tracking changes in the true parameters (denoted by $ \phi $) for different oracle states. A change in the transition probabilities is introduced at iteration $ 3000 $ (dashed line).

    Figure 4.  Approximate threshold parameters (denoted by $ \theta $) of the decreasing-step-size SASPS-N algorithm tracking the true parameters (denoted by blue scatter points $ \phi $) for different learners (rows) and oracle states (columns).
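Figures 2 to 4 concern a stochastic approximation search over threshold parameters. The Python sketch below illustrates the general idea under stated assumptions. It pairs a sigmoidal approximation of a threshold policy (`sigmoid_policy`, with a hypothetical temperature `tau`) with an SPSA-style gradient-free update using a constant step size `gamma`, which tracks a threshold vector `phi` (one entry per oracle state) that jumps at iteration 3000, loosely mirroring Figure 3. The cost `noisy_cost` is a hypothetical quadratic stand-in for a simulated CMDP cost, and SPSA is used here as a generic estimator; this is not claimed to be the paper's SASPS algorithm.

```python
import numpy as np


def sigmoid_policy(b, theta, tau=0.5):
    """Probability of choosing 'Learning' in queue state b under a sigmoidal
    approximation of a threshold policy with threshold theta (hypothetical form)."""
    return 1.0 / (1.0 + np.exp(-(b - theta) / tau))


def noisy_cost(theta, phi, rng, noise_std=0.1):
    """Hypothetical stand-in for a simulated policy cost; minimized at theta = phi."""
    return float(np.sum((theta - phi) ** 2) + noise_std * rng.standard_normal())


def spsa_track(phi_before, phi_after, change_at=3000, num_iters=6000,
               gamma=0.01, delta=0.1, seed=0):
    """Constant-step SPSA iterates tracking a threshold vector that changes at change_at."""
    rng = np.random.default_rng(seed)
    theta = np.zeros_like(phi_before, dtype=float)
    for m in range(num_iters):
        phi = phi_before if m < change_at else phi_after
        # Simultaneous perturbation: one random +/-1 direction for all components.
        d = rng.choice([-1.0, 1.0], size=theta.shape)
        c_plus = noisy_cost(theta + delta * d, phi, rng)
        c_minus = noisy_cost(theta - delta * d, phi, rng)
        grad_est = (c_plus - c_minus) / (2.0 * delta) / d
        # A constant step size keeps the iterate responsive to changes in phi.
        theta = theta - gamma * grad_est
    return theta
```

With a constant `gamma` the iterates keep fluctuating around `phi` but re-converge after the jump, which is the tracking behaviour shown in Figure 3. A decreasing step size (for example proportional to $ 1/(m+1) $) would instead drive the fluctuations to zero for a fixed target, at the price of reacting ever more slowly to changes.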

    Table 1.  Summary of the mathematical notation used in the text

    Symbol Description
    $ k $ Time index for stochastic gradient descent (SGD)
    $ m $ Time index for the stochastic approximation for structured policy search (SASPS) algorithm
    $ \hat{x} $ Learner's estimate of the minimizer
    $ s $ State variable comprising the oracle state, learner state, and arrival state
    $ q $ Query posed to the oracle by the learner
    $ u $ Action of the learner (Learning or No Learning)
    $ r $ Noisy gradient of the function, evaluated by the oracle at the query point
    $ \sigma_k $ Variance of the noise added at time $ k $
    $ \sigma $ Learner's tolerance for the noise variance
    $ \Theta $ True threshold parameter of the optimal policy
    $ \tilde{\Theta} $ Threshold parameter of SASPS that tracks the true threshold parameter
    $ \mu $, $ \gamma $ Step sizes of SGD and SASPS, respectively
    $ \beta $ Discount factor
    $ l $ Learning cost
    $ c $ Queue cost
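The discount factor and the two cost symbols in Table 1 can be combined into discounted trajectory costs. The Python sketch below is a minimal illustration that assumes a per-step learning cost $ l $ whenever the learner chooses Learning and a per-step queue cost $ c\,b_k $, both discounted by $ \beta^k $. The function names and the Lagrangian weighting with a hypothetical multiplier `lam` are illustrative; treating the queue cost as the constrained quantity of the CMDP is an assumption that the table alone does not confirm.

```python
import numpy as np


def discounted_costs(actions, queue, beta, l, c):
    """Discounted learning cost and queue cost of one trajectory.

    actions[k] is 1 if the learner queries the oracle (Learning) at time k, else 0;
    queue[k] is the queue length b_k. The per-step forms l * u_k and c * b_k are assumptions.
    """
    actions = np.asarray(actions, dtype=float)
    queue = np.asarray(queue, dtype=float)
    discount = beta ** np.arange(len(actions))  # beta^k for k = 0, 1, ...
    learning_cost = l * float(np.sum(discount * actions))
    queue_cost = c * float(np.sum(discount * queue))
    return learning_cost, queue_cost


def lagrangian(actions, queue, beta, l, c, lam):
    """Hypothetical Lagrangian: learning cost plus lam times the (assumed) constrained queue cost."""
    learning_cost, queue_cost = discounted_costs(actions, queue, beta, l, c)
    return learning_cost + lam * queue_cost
```

For example, with actions `[1, 0, 1]`, queue lengths `[2, 1, 1]`, $ \beta = 0.5 $, $ l = 1 $ and $ c = 0.1 $, the discounted learning cost is $ 1 + 0.25 = 1.25 $ and the discounted queue cost is $ 0.1 \times (2 + 0.5 + 0.25) = 0.275 $.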
