Stochastic Depth Networks will Become the New Normal


...in deep learning, that is.

Update: This post apparently made a lot of people mad. Check out my next post after this :-)

Every day, half a dozen or so new deep learning papers come out on arXiv, but very few catch my eye. Yesterday, I read "Deep Networks with Stochastic Depth". I think that, like dropout and batch normalization, this will be a game changer. For one, the results speak for themselves: in some cases, up to a 40% reduction in training time while beating the state of the art at the same time.

Figure 1. Error rate vs. Survival Probability (explained later)

Why is that a big deal? The biggest impediment to applying deep learning (or, for that matter, any S/E process) in product development is turnaround time. If I spend a week training my model and then find out it is a pile of shit because I did not initialize something well or the architecture was missing something, that's not good. For this reason, everyone I know wants the best GPUs or a spot on the biggest clusters, not just because these let them build more expressive networks but simply because they're super fast. So any technique that improves experiment turnaround time is welcome!

The idea is ridiculously simple (perhaps that's why it is effective?): randomly skip layers while training. More precisely, each residual block is dropped with some probability, leaving only its identity shortcut. As a result, the network's expected depth during training is substantially smaller than its maximum depth, which can be on the order of thousands of layers. In effect, like dropout training, this creates an ensemble of the $2^L$ possible networks for an $L$-layer deep network.
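To make the idea concrete, here is a minimal, framework-free Python sketch. It assumes, following the paper's recipe as I read it, that the "layers" being skipped are residual blocks: during training each block survives with probability $p_l$ (otherwise only its identity shortcut remains), and at test time every block runs with its residual branch scaled by $p_l$. The names `stochastic_depth_forward` and `relu`, and the toy list-of-floats activations, are hypothetical illustrations, not code from the paper.

```python
import random
from typing import Callable, List, Optional, Sequence, Tuple

Vector = List[float]
# Hypothetical stand-in for a real residual branch (e.g. conv-BN-ReLU-conv-BN).
ResidualBranch = Callable[[Vector], Vector]


def relu(v: Vector) -> Vector:
    return [max(0.0, a) for a in v]


def stochastic_depth_forward(
    x: Vector,
    branches: Sequence[ResidualBranch],
    survival: Sequence[float],
    training: bool,
    rng: Optional[random.Random] = None,
) -> Tuple[Vector, int]:
    """Run x through a stack of residual blocks with stochastic depth.

    Returns the output and how many residual branches were actually computed.
    """
    rng = rng if rng is not None else random.Random()
    computed = 0
    for f, p in zip(branches, survival):
        if training:
            # Bernoulli gate with success probability p: does this block survive this pass?
            if rng.random() < p:
                x = relu([a + b for a, b in zip(x, f(x))])
                computed += 1
            # Otherwise the block is dropped and its branch is never evaluated:
            # x passes through the identity shortcut unchanged (assuming x is
            # already non-negative, ReLU of x would just return x anyway).
        else:
            # Test time: keep every block, scale its residual branch by p.
            x = relu([a + p * b for a, b in zip(x, f(x))])
            computed += 1
    return x, computed


if __name__ == "__main__":
    add_one = lambda v: [1.0 for _ in v]  # toy branch: adds 1 to every unit
    rng = random.Random(0)
    out, n = stochastic_depth_forward([0.0, 2.0], [add_one] * 8, [0.5] * 8, True, rng)
    print(out, n)
```

Averaged over many training passes, the number of computed branches approaches $\sum_l p_l$, the expected depth. Skipping the branch computation entirely for dropped blocks, rather than computing it and multiplying by zero, is where the training-time saving would come from in a real implementation.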

I also like that this new method adds only one hyperparameter to tune (two, if you count the decay scheme): the layer survival probability. From their experiments, it appears that this hyperparameter is quite low-maintenance. Most arbitrary values you pick seem to do well unless you pick something really low (see Figure 1).
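The "decay scheme" is how the survival probability varies with depth. As I understand the paper, the two options are uniform (every block uses $p_L$) and linear decay, $p_l = 1 - \frac{l}{L}(1 - p_L)$, where the first block is almost always kept and the last block survives with probability $p_L$. The Python sketch below computes both, along with the expected number of active blocks, $\sum_l p_l$. The function names are hypothetical; 54 blocks is used because that is how many residual blocks a 110-layer CIFAR ResNet has.

```python
from typing import List


def survival_probs(num_blocks: int, p_last: float, scheme: str = "linear") -> List[float]:
    """Per-block survival probabilities p_1..p_L (hypothetical helper)."""
    if not 0.0 <= p_last <= 1.0:
        raise ValueError("p_last must be in [0, 1]")
    if scheme == "uniform":
        # Every block is kept with the same probability p_L.
        return [p_last] * num_blocks
    if scheme == "linear":
        # p_l = 1 - (l / L) * (1 - p_L) for l = 1..L
        return [1.0 - (l / num_blocks) * (1.0 - p_last) for l in range(1, num_blocks + 1)]
    raise ValueError(f"unknown scheme: {scheme}")


def expected_depth(probs: List[float]) -> float:
    # Expected number of residual blocks active in one training pass.
    return sum(probs)


if __name__ == "__main__":
    for p_last in (0.5, 0.2):
        probs = survival_probs(54, p_last)
        print(f"p_L={p_last}: last block kept {probs[-1]:.2f} of the time, "
              f"expected depth {expected_depth(probs):.2f} of 54 blocks")
```

Under linear decay the sum has a closed form, $L - \frac{(1 - p_L)(L + 1)}{2}$, so the sketch prints an expected depth of 40.25 blocks for $p_L = 0.5$ and 32.00 for $p_L = 0.2$, the setting where the deepest block is kept only 20% of the time. The average per-pass compute shrinks roughly in proportion.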

Something weird you notice, also from Figure 1, is that this training seems to do well (at least on CIFAR data) even when you keep the deepest layers only 20% of the time. Remember all the narratives we told about how depth learns hierarchical, higher-level representations? Those higher-level representations don't seem to matter so much after all.

Question: For really deep networks, can we ditch the model weights at the higher levels to keep the model footprint small enough to fit on mobile devices (in addition to tricks like binarization)?

Expect to see a flurry of papers showing results of Stochastic Depth applied to other network architectures pretty soon.


Copyright © 2021. Delip Rao