What?
We like Bayesian inference over neural networks. This, however, is very computationally costly. Some have tried to do Bayesian inference on only a subspace of the parameter set (Dusenberry et al. 2020, Efficient and Scalable Bayesian Neural Nets with Rank-1 Factors), while others have tried to model the noise by adding stochasticity to the nodes rather than to the weights (Kingma et al. 2015, among others).
This paper focuses on node-based stochasticity.
How?
The authors define a node-based fully connected NN as
\[\begin{align} \mathbf{f}_{\mathcal{Z}}^{0}(\mathbf{x}) &=\mathbf{x} \\ \mathbf{h}_{\mathcal{Z}}^{\ell}(\mathbf{x}) &=(\mathbf{W}^\ell(\mathbf{f}_\mathcal{Z}^{\ell-1}(\mathbf{x})\circ\mathbf{z}^\ell)+\mathbf{b}^\ell)\circ\mathbf{s}^\ell \\ \mathbf{f}_{\mathcal{Z}}^\ell(\mathbf{x}) &=\sigma^\ell\big(\mathbf{h}_{\mathcal{Z}}^\ell(\mathbf{x})\big),\quad\forall\ell=1,\ldots,L \\ \mathbf{f}_{\mathcal{Z}}(\mathbf{x}) &=\mathbf{f}_{\mathcal{Z}}^L(\mathbf{x}), \end{align}\]where $\mathbf{s}^\ell$ and $\mathbf{z}^\ell$ are random variables collected in $\mathcal{Z}=\{\mathbf{z}^\ell, \mathbf{s}^\ell\}_{\ell=1}^L$ and the deterministic weights and biases are collected in $\theta=\{\mathbf{W}^\ell, \mathbf{b}^\ell\}_{\ell=1}^L$.
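To make the definition concrete, here is a minimal NumPy sketch of one such layer and a two-layer forward pass. The dimensions, noise scale, and the choice of multiplicative noise centred at 1 for $\mathbf{z}^\ell$ and $\mathbf{s}^\ell$ are illustrative assumptions, not the authors' implementation.

```python
import numpy as np

def node_based_layer(f_prev, W, b, z, s, activation=np.tanh):
    # h^l = (W^l (f^{l-1} o z^l) + b^l) o s^l ;  f^l = sigma^l(h^l)
    h = (W @ (f_prev * z) + b) * s
    return activation(h)

rng = np.random.default_rng(0)
x = rng.normal(size=4)                                  # f^0(x) = x
W1, b1 = rng.normal(size=(8, 4)), np.zeros(8)
W2, b2 = rng.normal(size=(2, 8)), np.zeros(2)

# One sample of Z = {z^l, s^l}: multiplicative noise centred at 1 (illustrative choice).
z1, s1 = 1 + 0.1 * rng.normal(size=4), 1 + 0.1 * rng.normal(size=8)
z2, s2 = 1 + 0.1 * rng.normal(size=8), 1 + 0.1 * rng.normal(size=2)

f1 = node_based_layer(x, W1, b1, z1, s1)
f2 = node_based_layer(f1, W2, b2, z2, s2, activation=lambda v: v)  # linear output layer
```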
The authors define an independent prior $p(\theta,\mathcal{Z})=p(\theta)p(\mathcal{Z})$ and then perform variational inference, maximising the objective
\[\begin{aligned}\mathcal{L}(\hat{\theta},\phi)=\mathbb{E}_{q_\phi(\mathcal{Z})}\bigg[\log p(\mathcal{D}|\hat{\theta},\mathcal{Z})\bigg] -\operatorname{KL}\bigg[q_\phi(\mathcal{Z}) || p(\mathcal{Z})\bigg]+\log p(\hat{\theta}).\end{aligned}\]In essence, we find a MAP solution for the more numerous weights $\theta$, while inferring the posterior distribution over the latent variables $\mathcal{Z}$.
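As a sketch of how this objective can be estimated in practice, the snippet below computes a single Monte Carlo estimate of $\mathcal{L}$, assuming (purely for illustration) a fully factorised Gaussian $q_\phi(\mathcal{Z})$ and Gaussian prior $p(\mathcal{Z})$ so that the KL term has a closed form; the function names are hypothetical.

```python
import numpy as np

def gaussian_kl(mu_q, sig_q, mu_p, sig_p):
    # KL[ N(mu_q, sig_q^2) || N(mu_p, sig_p^2) ], summed over all node variables
    return np.sum(np.log(sig_p / sig_q)
                  + (sig_q**2 + (mu_q - mu_p)**2) / (2 * sig_p**2) - 0.5)

def objective(log_lik_samples, mu_q, sig_q, mu_p, sig_p, log_prior_theta):
    # E_q[log p(D | theta, Z)] - KL[q_phi(Z) || p(Z)] + log p(theta),
    # where log_lik_samples holds log p(D | theta, Z_k) for samples Z_k ~ q_phi
    return np.mean(log_lik_samples) - gaussian_kl(mu_q, sig_q, mu_p, sig_p) + log_prior_theta
```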
This is the same setup as in Dusenberry et al. 2020, Efficient and Scalable Bayesian Neural Nets with Rank-1 Factors.
In order to model covariate shift in the data (such as image noise, hardware differences between X-ray machines, etc.), they add a shifting function $\mathbf g^0(\mathbf x)$ such that the corrupted data point is $\mathbf{x}^c = \mathbf x + \mathbf g^0(\mathbf x)$. This corruption can then be propagated through the layers such that
\[\begin{aligned} \overbrace{\mathbf{g}^\ell(\mathbf{x})}^{\mathrm{shift}}& =\overbrace{\mathbf{f}^\ell(\mathbf{x}^c)}^{\text{corrupted output}}-\overbrace{\mathbf{f}^\ell(\mathbf{x})}^{\text{clean output}} \\ &\approx\mathbf{J}_\sigma\left[\mathbf{h}^\ell(\mathbf{x})\right]\left(\mathbf{W}^\ell\mathbf{g}^{\ell-1}(\mathbf{x})\right), \end{aligned}\]where $\mathbf{J}_\sigma[\mathbf{h}^\ell(\mathbf{x})]$ is the Jacobian of the activation $\sigma^\ell$ evaluated at the clean pre-activation, i.e. a first-order Taylor expansion of the shift. The question then is: can we reduce the shift at the output layer $L$? The answer, according to the authors, is yes: train a node-based Bayesian network and inject similar multiplicative noise during training, but at all the nodes of the network instead of only at the input layer.
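The linearisation above is just a first-order Taylor expansion. As a sanity check (not from the paper), the toy snippet below compares the exact and approximated shift for a single tanh layer without node noise; for an element-wise activation the Jacobian is diagonal.

```python
import numpy as np

rng = np.random.default_rng(1)
W, b = rng.normal(size=(5, 3)), rng.normal(size=5)
x = rng.normal(size=3)
g0 = 1e-3 * rng.normal(size=3)                          # small input shift g^0(x)

h = W @ x + b
g1_exact = np.tanh(W @ (x + g0) + b) - np.tanh(h)       # f(x^c) - f(x)
g1_approx = (1 - np.tanh(h)**2) * (W @ g0)              # J_sigma[h(x)] (W g^0)
print(np.max(np.abs(g1_exact - g1_approx)))             # ~O(|g0|^2), i.e. tiny
```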