A blog post by Mitchell Wortsman, Alvaro Herrasti, Sarah Pratt, Ali Farhadi and Mohammad Rastegari from the Allen Institute for Artificial Intelligence, the University of Washington, and XNOR.AI.

In this post we discuss the most interesting contributions from our recent paper on Discovering Neural Wirings (to appear at NeurIPS 2019). Traditionally, the connectivity patterns of Artificial Neural Networks (ANNs) are manually defined or largely constrained. In contrast, we relax the typical notion of layers to allow for a much larger space of possible wirings. The wiring of our ANN is not fixed during training – as we learn the network parameters we also learn the connectivity.

In our pursuit we arrive at the following conclusion: it is possible to train a model that is small during inference but still overparameterized during training. By applying our method to discover sparse ANNs we bridge the gap between neural architecture search and sparse neural network learning.

Move the slider below to see how the wiring changes when a small network is trained on MNIST for a few epochs (viz code here).


Why Wiring?

Before the advent of modern ANNs, researchers would manually engineer good features (high dimensional vector representations). Good features may now be learned with ANNs, but the architecture of the ANN must be specified instead. Accordingly, a myriad of recent work in Neural Architecture Search (NAS) has focused on learning the architecture of an ANN. However, NAS still searches among a set of manually designed building blocks and so ANN connectivity remains largely constrained. In contrast, RandWire explores a diverse set of connectivity patterns by considering ANNs which are randomly wired. Although randomly wired ANNs are competitive with NAS, their connectivity is fixed during training.

We propose a method for jointly learning the parameters and wiring of an ANN during training. We demonstrate that our method of Discovering Neural Wirings (DNW) outperforms many manually-designed and randomly wired ANNs.

ANNs are inspired by the biological neural networks of the animal brain. Despite the countless fundamental differences between these two systems, a biological inspiration may still prove useful. A recent Nature Communications article (aptly titled A critique of pure learning and what artificial neural networks can learn from animal brains) argues that the connectivity of an animal brain enables rapid learning. Accordingly, the article suggests “wiring topology and network architecture as a target for optimization in artificial systems.” We hope that this work provides a useful step in this direction.

Concurrent work on Weight Agnostic Neural Networks also emphasizes the importance of ANN wiring. They demonstrate that a given wiring for an ANN can effectively solve some simple tasks without any training – the solution is encoded in the connectivity.

Static Neural Graphs (SNGs): A Convenient Abstraction for a Feed Forward ANN

We now describe a convenient abstraction for a feed-forward ANN – a Static Neural Graph (SNG). Our goal is then to learn the optimal edge set of the SNG. We skim over some low level details below and invite you to reference the paper, though this abstraction should feel familiar.

An SNG is a directed acyclic graph $\mathcal{G} = (\mathcal{V}, \mathcal{E})$ which consists of nodes $\mathcal{V}$ and edges $\mathcal{E} \subseteq \mathcal{V} \times \mathcal{V}$. Additionally, each node $v$ has output $\mathcal{Z}_v$ and input $\mathcal{I}_v$. Input data flows into the network through a designated set of input nodes $\mathcal{V}_0$, and the input to node $v$ is a weighted sum of its parents' outputs

$$\mathcal{I}_v = \sum_{(u, v) \in \mathcal{E}} w_{uv} \mathcal{Z}_u.$$

The output of each node is computed via a parameterized function $f_{\theta_v}$,

$$\mathcal{Z}_v = f_{\theta_v}\left(\mathcal{I}_v\right),$$

where the edge weights $w_{uv}$ and the per-node parameters $\theta_v$ are learnable network parameters. The output of the network is then computed via a designated set of output nodes $\mathcal{V}_E$.

In this work we are designing models for Computer Vision, and so each node resembles a single channel (as illustrated below). Accordingly, $f_{\theta_v}$ performs a convolution over a 2-dimensional matrix (followed by BatchNorm and ReLU).
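
As a concrete sketch of the per-node computation above, consider the following PyTorch module. The name NodeOp and the single-channel 3x3 convolution are our own simplification for illustration; the released code batches all nodes of a graph together for efficiency.

import torch
import torch.nn as nn


class NodeOp(nn.Module):
    """One SNG node: a weighted sum of parent outputs, then conv + BatchNorm + ReLU."""

    def __init__(self, num_parents):
        super().__init__()
        # One scalar edge weight w_uv per parent node u.
        self.edge_weights = nn.Parameter(0.1 * torch.randn(num_parents))
        # f_theta_v: a 3x3 convolution over a single channel.
        self.conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(1)

    def forward(self, parent_outputs):
        # parent_outputs: list of (B, 1, H, W) tensors Z_u from the parent nodes.
        stacked = torch.stack(parent_outputs, dim=0)       # (P, B, 1, H, W)
        w = self.edge_weights.view(-1, 1, 1, 1, 1)
        node_input = (w * stacked).sum(dim=0)              # I_v = sum_u w_uv * Z_u
        return torch.relu(self.bn(self.conv(node_input)))  # Z_v = f_theta_v(I_v)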

An Algorithm for Discovering Neural Wirings

How do we learn the optimal edge set during training? We follow recent work (such as The Lottery Ticket Hypothesis) and attribute the “importance” of a parameter to its magnitude. Accordingly, the link between node $u$ and node $v$ is considered important if the magnitude $|w_{uv}|$ is large.

At each training iteration the edge set $\mathcal{E}$ is chosen by taking the $k$ highest magnitude weights:

$$\mathcal{E} = \left\{ (u, v) \, : \, |w_{uv}| \geq \tau \ \text{and} \ u < v \right\}$$

where $\tau$ is chosen so that there are exactly $k$ edges and the constraint $u < v$ (for a fixed ordering of the nodes) ensures that the graph is acyclic.
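
A minimal sketch of this selection step (our own illustration, not the released code), using torch.topk over the candidate edge weights:

import torch


def choose_edges(weights, k):
    """Keep the k highest-magnitude candidate edges.

    weights: an (N, N) matrix of candidate edge weights w_uv; only the upper
    triangle (u < v) is eligible, which keeps the graph acyclic.
    Returns a boolean (N, N) mask representing the chosen edge set E.
    """
    eligible = torch.ones_like(weights).triu(diagonal=1).bool()
    magnitudes = weights.abs() * eligible            # zero out ineligible entries
    top_idx = magnitudes.flatten().topk(k).indices   # indices of the k largest |w_uv|
    mask = torch.zeros_like(eligible)
    mask.view(-1)[top_idx] = True
    return mask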

All that remains is to choose a weight update for the edge weights $w_{uv}$. Recall that most models are trained with backpropagation, where gradients from a loss term $\mathcal{L}$ are passed backwards through the network. Using the chain rule, gradients may be computed with respect to each network parameter. The parameters are then often updated via stochastic gradient descent (SGD) with some learning rate $\alpha$. Conveniently, standard backprop will automatically compute the quantity $\frac{\partial \mathcal{L}}{\partial \mathcal{I}_v}$ for every node $v$.

Informally, the quantity $-\frac{\partial \mathcal{L}}{\partial \mathcal{I}_v}$ describes how the network wants the input to node $v$ to change so that the loss will decrease. Our rule is therefore to strengthen the connection between $u$ and $v$ when $\mathcal{Z}_u$ aligns with $-\frac{\partial \mathcal{L}}{\partial \mathcal{I}_v}$. In other words, if node $u$ can take $\mathcal{I}_v$ where the loss wants it to go, we should increase the edge weight from $u$ to $v$. We therefore modify $w_{uv}$ via our update rule

$$w_{uv} \leftarrow w_{uv} - \alpha \left\langle \frac{\partial \mathcal{L}}{\partial \mathcal{I}_v}, \mathcal{Z}_u \right\rangle$$

where $\langle \cdot , \cdot \rangle$ denotes an inner product (these quantities are implicitly treated as vectors).

In practice $w_{uv}$ changes only a small amount at each training iteration. However, if $\mathcal{Z}_u$ consistently aligns with $-\frac{\partial \mathcal{L}}{\partial \mathcal{I}_v}$ then $w_{uv}$ will strengthen to a point where edge $(u, v)$ replaces a weaker edge. We show below that when swapping does occur, it is beneficial under some assumptions.

When training, the rest of the network is updated via backprop as usual. In fact, you may notice that the update rule exactly resembles SGD when $(u, v) \in \mathcal{E}$, since for an edge in the graph $\frac{\partial \mathcal{L}}{\partial w_{uv}} = \left\langle \frac{\partial \mathcal{L}}{\partial \mathcal{I}_v}, \mathcal{Z}_u \right\rangle$. And so the algorithm may be interpreted equivalently as allowing the gradient to flow to, but not through, a set of “potential” edges. In practice we include momentum and weight decay terms, as is standard practice with SGD (weight decay should eventually remove dead ends).

As we show in the paper, this update rule is equivalent to using a straight-through estimator.
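
As a concrete sketch (our own illustration with hypothetical names, not the released code), the update above can be applied to every candidate edge weight, whether or not the edge is currently in $\mathcal{E}$:

import torch


def update_edge_weights(weights, input_grads, outputs, lr):
    """Apply w_uv <- w_uv - lr * <dL/dI_v, Z_u> to every candidate edge.

    weights:     (N, N) candidate edge weights w_uv
    input_grads: (N, D) tensor whose row v holds dL/dI_v, flattened
    outputs:     (N, D) tensor whose row u holds Z_u, flattened
    """
    with torch.no_grad():
        # alignment[u, v] = <Z_u, dL/dI_v>
        alignment = outputs @ input_grads.t()
        weights -= lr * alignment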

Putting it Together

Wirings at Scale

We employ the following two strategies for discovering wirings at scale:

  1. We chain together multiple graphs $\mathcal{G}_1, \dots, \mathcal{G}_k$, where the output of $\mathcal{G}_i$ is the input to $\mathcal{G}_{i+1}$. The input nodes of each graph perform a strided convolution and the spatial resolution is fixed throughout the remaining nodes in the graph.

  2. The depth of each graph is limited to $d$ by partitioning the nodes into blocks $\mathcal{B}_1, \dots, \mathcal{B}_d$. We then only allow connections between nodes $u \in \mathcal{B}_i$ and $v \in \mathcal{B}_j$ if $i < j$ (see the sketch after this list).
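
As an illustration of the second constraint (our own sketch; the block sizes are arbitrary), the allowed edges form a block upper-triangular mask:

import torch


def block_mask(block_sizes):
    """Boolean (N, N) mask: node u in block i may feed node v in block j only if i < j."""
    block_of = torch.cat([torch.full((n,), i, dtype=torch.long) for i, n in enumerate(block_sizes)])
    return block_of.unsqueeze(1) < block_of.unsqueeze(0)


# e.g. three blocks of four nodes each: a 12 x 12 mask with depth limited to 3
mask = block_mask([4, 4, 4])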

For an even comparison, we consider the exact same structure and number of edges as MobileNet V1 if it were interpreted as a chain of graphs. By learning the connectivity we boost the ImageNet accuracy by ~10% in the low compute setting.

| Model | ImageNet Top-1 Accuracy |
| --- | --- |
| Original MobileNet V1 (x 0.25) | 50.6% |
| Random Graph MobileNet V1 (x 0.225) | 53.3% |
| Discovered MobileNet V1 (x 0.225) | 60.9% |

Sparse Networks? Lottery Tickets? Overparameterization?

The past few years have witnessed illuminating work in the field of sparse ANNs. In The Lottery Ticket Hypothesis, Frankle and Carbin demonstrate that dense ANNs contain subnetworks that can be effectively trained in isolation. However, their process for uncovering these so-called winning tickets is expensive as it first requires a dense network to be trained. In Sparse Networks from Scratch, Dettmers and Zettlemoyer introduce sparse learning – training ANNs only once while maintaining sparse weights throughout.

Our work aims to unify the problem of neural architecture search with sparse neural network learning. As NAS becomes less restrictive and more fine-grained, finding a good architecture is akin to finding a sparse subnetwork of the complete graph.

Accordingly, we may use our algorithm for Discovering Neural Wirings and apply it to the task of training other sparse ANNs. Our method requires no fine-tuning or retraining to discover a sparse subnetwork. This perspective was guided by Dettmers and Zettlemoyer, though we would like to highlight some differences. Their work enables faster training, whereas our backward pass is still dense. Moreover, their work allows a redistribution of parameters across layers whereas we consider a fixed sparsity per layer. Finally, their training is more memory efficient as they actually send unused weights to zero while we continue to update them in the backward pass.

We leave the biases and BatchNorm parameters dense and use a ResNet-50 v1.5. This mirrors the experimental set-up from Appendix C of Sparse Networks from Scratch. The figure below illustrates how top-1 accuracy varies with the sparsity of the convolutional filters and the linear layer's weights (a sparsity of 0% corresponds to the dense network). The figure also shows an alternative setting where the first convolutional layer (with < 10k parameters, 0.04% of the total network) is left dense.

To generate the figure above we train only at sparsities that are multiples of 10% and interpolate the rest. All models and numbers may be found at our GitHub, though we provide the relevant ImageNet Top-1 Accuracy metrics for ResNet-50 v1.5 below.

| Model | 10% of Weights | 20% of Weights |
| --- | --- | --- |
| First Layer Dense (ours) | 75.0% | 76.6% |
| All Layers Sparse (ours) | 74.0% | 76.2% |
| Sparse Networks from Scratch (Dettmers & Zettlemoyer) | 72.9% | 74.9% |

We would like to highlight an interesting conclusion we may draw from this result: it is possible to realize the benefits of overparameterization during training even when the resulting model is sparse. Though we only ever use a small percentage of the weights during the forward pass, our network has good odds of winning the initialization lottery.

The implementation for training sparse ANNs with our algorithm is quite simple. We implicitly treat each parameter as an edge, and so all convolutions are replaced with the following PyTorch code:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd


class ChooseEdges(autograd.Function):
    @staticmethod
    def forward(ctx, weight, prune_rate):
        output = weight.clone()
        _, idx = weight.flatten().abs().sort()
        p = int(prune_rate * weight.numel())
        # flat_oup and output access the same memory.
        flat_oup = output.flatten()
        flat_oup[idx[:p]] = 0
        return output

    @staticmethod
    def backward(ctx, grad_output):
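        # Straight-through estimator: the gradient passes to every weight,
        # including those zeroed out in the forward pass.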
        return grad_output, None


class SparseConv(nn.Conv2d):
    def set_prune_rate(self, prune_rate):
        self.prune_rate = prune_rate

    def forward(self, x):
        w = ChooseEdges.apply(self.weight, self.prune_rate)
        x = F.conv2d(
            x, w,
            self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        return x
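
For example, a drop-in usage might look like the following (the layer shapes here are arbitrary):

conv = SparseConv(64, 128, kernel_size=3, padding=1, bias=False)
conv.set_prune_rate(0.9)  # keep only 10% of the weights in each forward pass
out = conv(torch.randn(8, 64, 32, 32))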

The figure below illustrates how this code works. It also emphasizes that we are using a straight-through estimator (see paper for more discussion).

Discovering Neural Wirings for Dynamic Neural Graphs (DNGs)

We may also apply our algorithm to processes on graphs, where nodes receive input and produce output at all times (and the graph is not restricted to be a DAG). In the discrete time setting, we consider times $t \in \{0, 1, \dots, L\}$ and let the input $\mathcal{I}_v(t)$ and output $\mathcal{Z}_v(t)$ vary with $t$ (we skim over some low level details – i.e. the initial conditions – and invite you to reference the paper). The state of node $v$ at time $t + 1$ is then given by

$$\mathcal{Z}_v(t + 1) = f_{\theta_v}\left( \mathcal{I}_v(t) \right), \qquad \mathcal{I}_v(t) = \sum_{(u, v) \in \mathcal{E}} w_{uv} \mathcal{Z}_u(t).$$

We may also consider the setting where $t$ takes on a continuous range of values (as in Neural Ordinary Differential Equations). The ANN evolves according to the following dynamics:

$$\frac{d \mathcal{Z}_v(t)}{dt} = f_{\theta_v}\left( \mathcal{I}_v(t) \right).$$
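
For intuition, a single discrete time step over all nodes might be sketched as follows (our own simplification; node_fn stands in for applying each $f_{\theta_v}$):

import torch


def step(Z, weights, mask, node_fn):
    """One discrete time step of a dynamic neural graph.

    Z:       (N, D) node states Z_v(t), one row per node
    weights: (N, N) candidate edge weights w_uv
    mask:    (N, N) boolean mask of the chosen edge set E
    node_fn: callable applying the per-node function row-wise
    Returns the states Z(t + 1).
    """
    I = (weights * mask).t() @ Z  # I_v(t) = sum_u w_uv * Z_u(t)
    return node_fn(I)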

We apply our algorithm for Discovering Neural Wirings to a tiny (41k parameter) classifier in both the static and dynamic setting.

| Model | Accuracy (CIFAR-10) |
| --- | --- |
| Static (Random Graph) | 76.1 ± 0.5 |
| Static (ours) | 80.9 ± 0.6 |
| Discrete Time (Random Graph) | 77.3 ± 0.7 |
| Discrete Time (ours) | 82.3 ± 0.6 |
| Continuous (Random Graph) | 78.5 ± 1.2 |
| Continuous (ours) | 83.1 ± 0.3 |

Proofs

Here we briefly show that when swapping edges does occur, it is beneficial under some assumptions. Consider a Static Neural Graph where edge $(i, v)$ replaces edge $(j, v)$ after the gradient update for the mini-batch. We may show that when the learning rate $\alpha$ is sufficiently small and the node states are fixed, the loss will decrease for the current mini-batch. We skim over some details here and invite you to reference the paper; for example, we must also assume that the loss is Lipschitz continuous. The paper also includes a proof of a more general case.

We let $\tilde{w}_{uv}$ denote the weight after the gradient update. And so by our update rule we let

$$\tilde{w}_{uv} = w_{uv} + \alpha \left\langle g_v, \mathcal{Z}_u \right\rangle$$

where $g_v = -\frac{\partial \mathcal{L}}{\partial \mathcal{I}_v}$. If the learning rate is sufficiently small then $\tilde{w}_{uv}$ will be close to $w_{uv}$, and so we assume that $\mathrm{sign}(\tilde{w}_{uv}) = \mathrm{sign}(w_{uv})$. Since edge $(i, v)$ replaces edge $(j, v)$ we know that $|\tilde{w}_{iv}| \geq |\tilde{w}_{jv}|$ and $|w_{jv}| \geq |w_{iv}|$, as we choose edges by taking the highest magnitude weights.

Let $\tilde{\mathcal{I}}_v$ be the new input to node $v$ if swapping is allowed. Likewise, let $\hat{\mathcal{I}}_v$ be the new input to node $v$ if swapping is not allowed. It suffices to show that $\left\langle g_v, \tilde{\mathcal{I}}_v - \hat{\mathcal{I}}_v \right\rangle \geq 0$, as we observe that $\mathcal{L}\big(\tilde{\mathcal{I}}_v\big) \approx \mathcal{L}\big(\hat{\mathcal{I}}_v\big) - \left\langle g_v, \tilde{\mathcal{I}}_v - \hat{\mathcal{I}}_v \right\rangle$ via a Taylor approximation.

The two inputs differ only in the swapped edge: $\tilde{\mathcal{I}}_v$ may be written as $\mathcal{I}_v^{\mathrm{rest}} + \tilde{w}_{iv} \mathcal{Z}_i$ and $\hat{\mathcal{I}}_v$ may be written as $\mathcal{I}_v^{\mathrm{rest}} + \tilde{w}_{jv} \mathcal{Z}_j$, and so

$$\left\langle g_v, \tilde{\mathcal{I}}_v - \hat{\mathcal{I}}_v \right\rangle = \tilde{w}_{iv} \left\langle g_v, \mathcal{Z}_i \right\rangle - \tilde{w}_{jv} \left\langle g_v, \mathcal{Z}_j \right\rangle = \frac{1}{\alpha} \Big( \tilde{w}_{iv} \left( \tilde{w}_{iv} - w_{iv} \right) - \tilde{w}_{jv} \left( \tilde{w}_{jv} - w_{jv} \right) \Big),$$

where the last equality uses the update rule. Since the signs are preserved, each term satisfies $\tilde{w}_{uv} \left( \tilde{w}_{uv} - w_{uv} \right) = |\tilde{w}_{uv}| \left( |\tilde{w}_{uv}| - |w_{uv}| \right)$. If the magnitude of $w_{iv}$ increases while the magnitude of $w_{jv}$ decreases then we are done, as the first term will be positive while the second term will be negative. We now examine the case where both magnitudes either increase or decrease.

Write $\delta_i = |\tilde{w}_{iv}| - |w_{iv}|$ and $\delta_j = |\tilde{w}_{jv}| - |w_{jv}|$, so we are left to show that $|\tilde{w}_{iv}| \, \delta_i \geq |\tilde{w}_{jv}| \, \delta_j$. Since $|\tilde{w}_{iv}| \geq |\tilde{w}_{jv}|$ and $|w_{iv}| \leq |w_{jv}|$ we have $\delta_i \geq \delta_j$.

If both magnitudes increase then $\delta_i \geq \delta_j \geq 0$, and multiplying the larger increment by the larger weight gives $|\tilde{w}_{iv}| \, \delta_i \geq |\tilde{w}_{jv}| \, \delta_j$ immediately.

If both magnitudes decrease then $|\delta_j| \geq |\delta_i|$, and combining $|w_{iv}| \leq |w_{jv}|$ with the definitions of $\delta_i$ and $\delta_j$ gives $|\tilde{w}_{iv}| \leq |\tilde{w}_{jv}| + |\delta_j| - |\delta_i|$. Rearranging, we find that

$$|\tilde{w}_{jv}| \, |\delta_j| - |\tilde{w}_{iv}| \, |\delta_i| \geq \left( |\tilde{w}_{jv}| - |\delta_i| \right) \left( |\delta_j| - |\delta_i| \right) \geq 0,$$

where the right hand side is nonnegative because $|\delta_j| \geq |\delta_i|$ and, when the learning rate is small enough, $|\tilde{w}_{jv}| \geq |\delta_i|$. In either case the swap does not increase the loss on the current mini-batch.

Citing

If you found this work to be helpful please consider citing:

@article{Wortsman2019DiscoveringNW,
  title={Discovering Neural Wirings},
  author={Mitchell Wortsman and Ali Farhadi and Mohammad Rastegari},
  journal={ArXiv},
  year={2019},
  volume={abs/1906.00586}
}

Acknowledgements

We sincerely thank Tim Dettmers for his assistance and guidance in the experiments regarding sparse networks.

Comments and FAQ

To comment, raise an issue on our GitHub repo here.