Training and Inference with Integers in Deep Neural Networks


Introduction

Deep neural networks have enjoyed much success in all manner of tasks, but these networks are commonly complex, requiring large amounts of energy-intensive memory and floating-point operations. Therefore, to use state-of-the-art networks in applications where energy is limited or hardware packaging is constrained, such as devices not connected to the power grid, their energy costs must be reduced while preserving as much performance as practical.

Most existing methods focus on reducing energy requirements during inference, that is, on compressing a trained model, rather than during training. Because training with SGD involves accumulating many small updates, it usually demands higher precision than inference. This paper proposes a framework that reduces complexity during both training and inference by using integers instead of floats. The authors address how to quantize all operations and operands, and examine the bitwidth requirements for SGD computation and accumulation. Using integers instead of floats saves energy because integer operations are cheaper than floating-point operations (see the table below). In addition, dedicated deep learning hardware that uses integer operations already exists (such as the first-generation Google TPU), so understanding how best to use integers is well motivated.

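Rough energy and area costs in 45 nm, 0.9 V <ref>Template:Cite journal</ref>
Operation | Energy (pJ), MUL | Energy (pJ), ADD | Area ($\mu m^2$), MUL | Area ($\mu m^2$), ADD
8-bit INT | 0.2 | 0.03 | 282 | 36
16-bit FP | 1.1 | 0.4 | 1640 | 1360
32-bit FP | 3.7 | 0.9 | 7700 | 4184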
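As a rough illustration of the gap, the short Python sketch below computes how many times more energy a 32-bit floating-point operation costs than its 8-bit integer counterpart. It reads the first two numbers of each table row as the multiply and add energies in pJ. Treating one MAC as one multiply plus one add is a simplifying assumption of this sketch.

```python
# Energy per operation in pJ at 45 nm and 0.9 V, copied from the table above.
ENERGY_PJ = {
    "8-bit INT": {"mul": 0.2, "add": 0.03},
    "16-bit FP": {"mul": 1.1, "add": 0.4},
    "32-bit FP": {"mul": 3.7, "add": 0.9},
}

def op_energy(fmt: str, op: str) -> float:
    """Energy of one 'mul', one 'add', or one MAC ('mac' = one mul plus one add)."""
    if op == "mac":
        return ENERGY_PJ[fmt]["mul"] + ENERGY_PJ[fmt]["add"]
    return ENERGY_PJ[fmt][op]

def ratio(op: str, slow: str = "32-bit FP", fast: str = "8-bit INT") -> float:
    """How many times more energy `slow` spends than `fast` on the same operation."""
    return op_energy(slow, op) / op_energy(fast, op)

for op in ("mul", "add", "mac"):
    print(f"{op}: {ratio(op):.1f}x")   # mul: 18.5x, add: 30.0x, mac: 20.0x
```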

The authors call the framework WAGE because they consider how best to handle the Weights, Activations, Gradients, and Errors separately.

Related Work

Weight and Activation

Existing works on binary weights and activations <ref>Template:Cite journal</ref> still use high-precision accumulation for SGD. Ternary weight networks <ref>Template:Cite journal</ref> are more expressive than binary weight networks.

Some methods quantize gradients in the backwards pass, but the weights are still stored in float32, and batch normalization is ignored.

WAGE Quantization

The core idea of the proposed method is to constrain the following to low-bitwidth integers on each layer:

• W: weight in inference
• a: activation in inference
• e: error in backpropagation
• g: weight gradient in backpropagation
Figure: Four operators $Q_W(\cdot)$, $Q_A(\cdot)$, $Q_G(\cdot)$ and $Q_E(\cdot)$ are added to the WAGE computation dataflow to reduce precision. The bitwidths of signed integers are shown below or to the right of the arrows, and activations are included in the MAC for concision.

The error and gradient are defined as:

$e^i = \frac{\partial L}{\partial a^i}, g^i = \frac{\partial L}{\partial W^i}$

where $L$ is the loss function and the superscript $i$ indexes the layer.

The bitwidths of the errors, activations, gradients, and weights are $k_E$, $k_A$, $k_G$, and $k_W$, respectively. As shown in the figure above, each quantity also has a quantization operator that counteracts the bitwidth growth caused by multiply-accumulate (MAC) operations. Note also that, because this is a layer-by-layer approach, each layer may be preceded or followed by a layer with a different precision, or even by a layer using floating-point math.
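To see why these operators are needed, consider the bit growth inside a MAC. Multiplying a signed $k_a$-bit integer by a signed $k_b$-bit integer can need up to $k_a + k_b$ bits, and summing $n$ such products can add up to $\lceil \log_2 n \rceil$ more. The Python sketch below computes this conservative (not necessarily tight) accumulator width. The function names and the example layer size are hypothetical and not taken from the paper.

```python
import math

def signed_range(bits: int) -> tuple[int, int]:
    """Smallest and largest value of a two's-complement integer with `bits` bits."""
    return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1

def mac_accumulator_bits(k_a: int, k_b: int, n_terms: int) -> int:
    """Conservative width for summing n_terms products of a signed k_a-bit
    integer and a signed k_b-bit integer (hypothetical helper)."""
    product_bits = k_a + k_b                                            # one product fits here
    growth_bits = math.ceil(math.log2(n_terms)) if n_terms > 1 else 0   # carries from the sum
    return product_bits + growth_bits

def worst_case_sum(k_a: int, k_b: int, n_terms: int) -> int:
    """Largest possible |sum|: every product is (most negative) * (most negative)."""
    lo_a, _ = signed_range(k_a)
    lo_b, _ = signed_range(k_b)
    return n_terms * lo_a * lo_b

# Hypothetical layer: 8-bit activations times 2-bit weights, fan-in 3 * 3 * 64 = 576.
bits = mac_accumulator_bits(8, 2, 576)
print(bits)                                                  # 20
print(worst_case_sum(8, 2, 576) <= signed_range(bits)[1])    # True
```

The accumulator is therefore much wider than either operand. Without a quantization operator after each MAC, the bitwidth would keep growing from layer to layer.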

Shift-Based Linear Mapping and Stochastic Mapping

The proposed method makes use of a linear mapping where continuous, unbounded values are discretized for each bitwidth $k$ with a uniform spacing of

$\sigma(k) = 2^{1-k}, \quad k \in \mathbb{N}_+$

With this, the full quantization function is

$Q(x,k) = Clip\left \{ \sigma(k) \cdot round\left [ \frac{x}{\sigma(k)} \right ], -1 + \sigma(k), 1 - \sigma(k) \right \}$
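Below is a minimal Python sketch of $\sigma(k)$ and $Q(x, k)$ for the floating-point simulation case. The text does not specify tie-breaking. This sketch uses Python's built-in `round`, which rounds halves to the nearest even integer, and that choice is an assumption.

```python
def sigma(k: int) -> float:
    """Minimum step size sigma(k) = 2^(1-k) for a positive integer bitwidth k."""
    if k < 1:
        raise ValueError("bitwidth k must be a positive integer")
    return 2.0 ** (1 - k)

def quantize(x: float, k: int) -> float:
    """Q(x, k): snap x to the grid of spacing sigma(k), then clip to
    [-1 + sigma(k), 1 - sigma(k)]. Python's round() sends ties to the even integer."""
    s = sigma(k)
    q = s * round(x / s)
    return min(max(q, -1.0 + s), 1.0 - s)

print([quantize(v, 2) for v in (-0.9, -0.2, 0.3, 0.9)])  # [-0.5, 0.0, 0.5, 0.5]
print(quantize(0.1, 8))                                  # 0.1015625
```

With $k = 2$ only the three levels $\{-0.5, 0, 0.5\}$ remain. With $k = 8$ the spacing is $2^{-7}$, giving 255 levels between $-1 + 2^{-7}$ and $1 - 2^{-7}$.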

Note that this function is only needed when simulating integer operations on floating-point hardware; on native integer hardware, this quantization happens automatically.

In addition to this quantization function, a distribution scaling factor is used in some quantization operators to preserve as much variance as possible when applying the quantization function above. The scaling factor is defined below.

$Shift(x) = 2^{round(log_2(x))}$
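The Python sketch below implements $Shift(\cdot)$ for positive arguments. The result is always a power of two, so dividing by it is exact in binary floating point and corresponds to an arithmetic bit shift on integer hardware. It also follows that $x / Shift(x)$ always lies in $[2^{-1/2}, 2^{1/2}]$.

```python
import math

def shift(x: float) -> float:
    """Shift(x) = 2^round(log2(x)): the power of two nearest to x on a log scale."""
    if x <= 0:
        raise ValueError("Shift is only defined here for positive x")
    return 2.0 ** round(math.log2(x))

print(shift(3.0), shift(5.0), shift(0.3))  # 4.0 4.0 0.25
```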

Finally, stochastic rounding is used in place of deterministic rounding for small or real-valued updates during gradient accumulation.
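A generic floor-plus-Bernoulli form of stochastic rounding is sketched below in Python. A value is rounded up with probability equal to its fractional part, so the expected result equals the input. This is what lets updates smaller than one quantization step still have an effect on average. The gradient section below gives the sign-magnitude form used in WAGE. This floor-based version yields the same distribution, but it is written for illustration and not copied from the paper.

```python
import math
import random

def stochastic_round(x: float, rng: random.Random) -> int:
    """Round x to floor(x) or floor(x) + 1, rounding up with probability
    equal to the fractional part x - floor(x), so that E[result] = x."""
    lower = math.floor(x)
    return lower + (1 if rng.random() < x - lower else 0)

rng = random.Random(0)
samples = [stochastic_round(0.3, rng) for _ in range(100_000)]
print(sum(samples) / len(samples))   # close to 0.3, whereas round(0.3) is always 0
```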

A visual representation of these operations is below.

Figure: Quantization methods used in WAGE. The notation $P$, $x$, $\lfloor \cdot \rfloor$, $\lceil \cdot \rceil$ denotes probability, vector, floor, and ceiling, respectively. $Shift(\cdot)$ refers to distribution shifting with a certain argument.

Weight Initialization

In this work, batch normalization is simplified to a constant scaling layer. This sidesteps the problem of normalizing outputs without floating-point math and removes the extra memory requirement of batch normalization. As a result, some care must be taken when initializing weights. The authors use a modified initialization method based on MSRA <ref>Template:Cite journal</ref>:

$W \thicksim U(-L, +L),L = max \left \{ \sqrt{6/n_{in}}, L_{min} \right \}, L_{min} = \beta \sigma$

Here $n_{in}$ is the layer fan-in number and $U$ denotes the uniform distribution. The original MSRA initialization method is modified by adding the condition that the distribution width should be at least $\beta \sigma$. Here $\beta$ is a constant greater than 1 and $\sigma$ is the minimum step size defined above. This prevents the weights from being all zeros after quantization when the bitwidth is low or the fan-in number is high.
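The effect of the lower bound can be checked numerically. The Python sketch below assumes $\sigma = \sigma(k_W)$. It uses a hypothetical layer with fan-in 4096, $k_W = 2$ and $\beta = 1.5$; the value of $\beta$ is an illustrative assumption. With the plain MSRA bound $\sqrt{6/4096} \approx 0.038$, every weight is smaller than half a quantization step, so $Q(\cdot, 2)$ maps all of them to zero. With the bound raised to $L_{min} = 0.75$, many weights survive quantization.

```python
import math
import random

def sigma(k: int) -> float:
    """Minimum step size sigma(k) = 2^(1-k)."""
    return 2.0 ** (1 - k)

def quantize(x: float, k: int) -> float:
    """Q(x, k): round to the sigma(k) grid, clip to [-1 + sigma(k), 1 - sigma(k)]."""
    s = sigma(k)
    return min(max(s * round(x / s), -1.0 + s), 1.0 - s)

def wage_init_limit(n_in: int, k_w: int, beta: float) -> float:
    """L = max(sqrt(6 / n_in), L_min), with L_min = beta * sigma(k_w) (assumed sigma)."""
    return max(math.sqrt(6.0 / n_in), beta * sigma(k_w))

def sample_uniform(limit: float, count: int, rng: random.Random) -> list[float]:
    """Draw `count` weights from U(-limit, +limit)."""
    return [rng.uniform(-limit, limit) for _ in range(count)]

rng = random.Random(0)
n_in, k_w, beta = 4096, 2, 1.5                                          # hypothetical values
w_msra = sample_uniform(math.sqrt(6.0 / n_in), 1000, rng)               # plain MSRA bound
w_floor = sample_uniform(wage_init_limit(n_in, k_w, beta), 1000, rng)   # bound with L_min

nonzero_without = sum(quantize(w, k_w) != 0.0 for w in w_msra)
nonzero_with = sum(quantize(w, k_w) != 0.0 for w in w_floor)
print(nonzero_without)   # 0: every |w| < sigma(2) / 2, so Q rounds it to zero
print(nonzero_with)      # a sizeable fraction of the 1000 weights survive quantization
```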

Quantization Details

Weight $Q_W(\cdot)$

$W_q = Q_W(W) = Q(W, k_W)$

The quantization operator is simply the quantization function previously introduced.

Activation $Q_A(\cdot)$

The authors note that the variance of the weights after quantization is scaled relative to their variance at initialization. To keep this effect from blowing up the network outputs, they introduce a scaling factor $\alpha$, which is a constant for each layer.

$\alpha = max \left \{ Shift(L_{min} / L), 1 \right \}$

The quantization operator is then

$a_q = Q_A(a) = Q(a/\alpha, k_A)$
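The Python sketch below computes $\alpha$ and applies $Q_A$. It rests on one interpretive assumption. If $L$ in the $\alpha$ formula were the same bound as in the initialization, then $L_{min}/L \le 1$ and $\alpha$ would always equal 1. The sketch therefore takes $L$ here to be the unmodified MSRA limit $\sqrt{6/n_{in}}$. That reading is not stated in the text. The layer sizes and $\beta = 1.5$ are hypothetical. Because $\alpha$ is a power of two, $a/\alpha$ is a right shift on integer hardware.

```python
import math

def sigma(k: int) -> float:
    """Minimum step size sigma(k) = 2^(1-k)."""
    return 2.0 ** (1 - k)

def quantize(x: float, k: int) -> float:
    """Q(x, k): round to the sigma(k) grid, clip to [-1 + sigma(k), 1 - sigma(k)]."""
    s = sigma(k)
    return min(max(s * round(x / s), -1.0 + s), 1.0 - s)

def shift(x: float) -> float:
    """Shift(x) = 2^round(log2(x)), for x > 0."""
    return 2.0 ** round(math.log2(x))

def layer_alpha(n_in: int, k_w: int, beta: float) -> float:
    """alpha = max(Shift(L_min / L), 1), with L assumed to be sqrt(6 / n_in)."""
    l_min = beta * sigma(k_w)
    l_msra = math.sqrt(6.0 / n_in)
    return max(shift(l_min / l_msra), 1.0)

def quantize_activation(a: float, alpha: float, k_a: int) -> float:
    """Q_A(a) = Q(a / alpha, k_A)."""
    return quantize(a / alpha, k_a)

alpha = layer_alpha(n_in=4096, k_w=2, beta=1.5)   # wide, low-bitwidth layer
print(alpha)                                       # 16.0
print(quantize_activation(3.2, alpha, k_a=8))      # 0.203125
print(layer_alpha(n_in=9, k_w=8, beta=1.5))        # 1.0
```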

The scaling factor approximates batch normalization.

Error $Q_E(\cdot)$

The magnitude of the error can vary greatly. A previous approach (DoReFa-Net<ref>Template:Cite journal</ref>) handles this by applying an affine transform that maps the error to the range $[-1, 1]$, quantizing it, and then applying the inverse transform. However, the authors argue that this approach still requires float32, and that what matters is the orientation of the error rather than its magnitude. Thus, they only scale the error distribution to the range $\left [ -\sqrt2, \sqrt2 \right ]$ and quantize:

$e_q = Q_E(e) = Q(e/Shift(max\{|e|\}), k_E)$

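Here $\max\{|e|\}$ is the maximum absolute value over all elements of the error tensor $e$. Note that this discards (rounds to zero) any scaled error element whose magnitude is below half the minimum step size $\sigma(k_E)$.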
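Below is a minimal Python sketch of $Q_E$ for a flattened error tensor, using a single scale for the whole tensor. $Shift(\cdot)$ rounds $\log_2$ to the nearest integer, so $\max\{|e|\}/Shift(\max\{|e|\})$ lies in $[2^{-1/2}, 2^{1/2}]$. This is where the range $[-\sqrt2, \sqrt2]$ comes from. Scaled values above $1 - \sigma(k_E)$ are then clipped by $Q$. The all-zero guard is an assumption added so that the sketch never evaluates $\log_2 0$.

```python
import math

def sigma(k: int) -> float:
    """Minimum step size sigma(k) = 2^(1-k)."""
    return 2.0 ** (1 - k)

def quantize(x: float, k: int) -> float:
    """Q(x, k): round to the sigma(k) grid, clip to [-1 + sigma(k), 1 - sigma(k)]."""
    s = sigma(k)
    return min(max(s * round(x / s), -1.0 + s), 1.0 - s)

def shift(x: float) -> float:
    """Shift(x) = 2^round(log2(x)), for x > 0."""
    return 2.0 ** round(math.log2(x))

def quantize_error(e: list[float], k_e: int) -> list[float]:
    """Q_E(e) = Q(e / Shift(max|e|), k_E), with one scale for the whole tensor."""
    peak = max(abs(v) for v in e)
    if peak == 0.0:
        return [0.0 for _ in e]          # all-zero error: nothing to rescale
    scale = shift(peak)                  # power of two, so this is a bit shift
    return [quantize(v / scale, k_e) for v in e]

e = [0.003, -0.0004, 1e-7]               # hypothetical error values
print(quantize_error(e, k_e=8))          # [0.765625, -0.1015625, 0.0]
```

The tiny third element is discarded. Because only the orientation is kept, multiplying the whole error by a power of two leaves the result unchanged.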

Gradient $Q_G(\cdot)$

Similar to the activations and errors, the gradients are rescaled:

$g_s = \eta \cdot g/Shift(max\{|g|\})$

$\eta$ is a shift-based learning rate, that is, an integer power of 2. The shifted gradients are represented in units of the minimum step size $\sigma(k_G)$. The gradients come out of a MAC operation, so their bitwidth may have grown. When reducing it again, stochastic rounding is used in place of accumulating small gradients.

$\Delta W = Q_G(g) = \sigma(k_G) \cdot sgn(g_s) \cdot \left \{ \lfloor | g_s | \rfloor + Bernoulli(|g_s| - \lfloor | g_s | \rfloor) \right \}$

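This randomly rounds each rescaled gradient value up or down to one of the two adjacent quantization levels for the given gradient bitwidth. The probability of rounding up equals the fractional part, so the rounding is unbiased in expectation. The weights are updated with the resulting discrete increments: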

$W_{t+1} = Clip \left \{ W_t - \Delta W_t, -1 + \sigma(k_G), 1 - \sigma(k_G) \right \}$
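The gradient path can be sketched in Python as follows. The gradients are rescaled by $Shift(\max\{|g|\})$ and $\eta$, stochastically rounded to whole steps of $\sigma(k_G)$, and then applied as the clipped update. The learning rate $\eta = 4$, the example gradients and weights, and the function names are all hypothetical. Both $\eta$ and $Shift(\cdot)$ are powers of two, so the rescaling amounts to bit shifts on integer hardware.

```python
import math
import random

def sigma(k: int) -> float:
    """Minimum step size sigma(k) = 2^(1-k)."""
    return 2.0 ** (1 - k)

def shift(x: float) -> float:
    """Shift(x) = 2^round(log2(x)), for x > 0."""
    return 2.0 ** round(math.log2(x))

def quantize_gradient(g: list[float], k_g: int, eta: float, rng: random.Random) -> list[float]:
    """Delta W = sigma(k_G) * sgn(g_s) * (floor|g_s| + Bernoulli(|g_s| - floor|g_s|)),
    with g_s = eta * g / Shift(max|g|)."""
    peak = max(abs(v) for v in g)
    if peak == 0.0:
        return [0.0 for _ in g]                 # no gradient, no update
    scale = shift(peak)
    delta = []
    for v in g:
        g_s = eta * v / scale                   # gradient in units of sigma(k_G)
        mag = abs(g_s)
        whole = math.floor(mag)
        steps = whole + (1 if rng.random() < mag - whole else 0)  # stochastic rounding
        sign = (g_s > 0) - (g_s < 0)            # sgn(g_s) as -1, 0 or +1
        delta.append(sigma(k_g) * sign * steps)
    return delta

def update_weights(w: list[float], delta_w: list[float], k_g: int) -> list[float]:
    """W_{t+1} = Clip(W_t - Delta W_t, -1 + sigma(k_G), 1 - sigma(k_G))."""
    s = sigma(k_g)
    return [min(max(wi - dwi, -1.0 + s), 1.0 - s) for wi, dwi in zip(w, delta_w)]

rng = random.Random(0)
g = [0.02, -0.005, 0.0001]                           # hypothetical raw gradients
delta = quantize_gradient(g, k_g=8, eta=4, rng=rng)  # eta = 4 is a hypothetical power of 2
print([d / sigma(8) for d in delta])                 # whole steps: 5 or 6, -1 or -2, 0 or 1
w_next = update_weights([0.5, -0.984375, 0.0], delta, k_g=8)
print(w_next)
```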

Miscellaneous

To train WAGE networks, the authors use plain SGD exclusively. More sophisticated optimizers such as momentum or RMSProp increase memory consumption and are complicated by the rescaling within each quantization operator.

The authors argue that the quantization and stochastic rounding act as a form of regularization.

The authors didn't use a traditional softmax with cross-entropy loss for the experiments because there does not yet exist a softmax layer for low-bit integers. Instead, they use a sum of squared error loss. This works for tasks with a small number of categories, but does not scale well.
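For reference, a sum-of-squared-errors loss over the outputs is $L = \sum_c (a_c - y_c)^2$, with error $e = 2(a - y)$. The exact normalization the authors use is not given here, but it matters little. $Q_E$ divides by $Shift(\max\{|e|\})$, which cancels any constant power-of-two factor such as the common $\frac{1}{2}$. The Python sketch below checks this on hypothetical outputs and a one-hot target.

```python
import math

def shift(x: float) -> float:
    """Shift(x) = 2^round(log2(x)), for x > 0."""
    return 2.0 ** round(math.log2(x))

def sse_loss_and_error(outputs: list[float], targets: list[float]) -> tuple[float, list[float]]:
    """Sum-of-squared-errors loss L = sum_c (a_c - y_c)^2 and the error e = dL/da = 2(a - y)."""
    loss = sum((a - y) ** 2 for a, y in zip(outputs, targets))
    error = [2.0 * (a - y) for a, y in zip(outputs, targets)]
    return loss, error

def shift_normalized(e: list[float]) -> list[float]:
    """e / Shift(max|e|): the rescaling Q_E applies before quantization."""
    scale = shift(max(abs(v) for v in e))
    return [v / scale for v in e]

outputs = [0.9, 0.2, -0.1]   # hypothetical network outputs for 3 classes
targets = [1.0, 0.0, 0.0]    # one-hot label
loss, error = sse_loss_and_error(outputs, targets)
# Halving the loss (the 1/2 * SSE convention) halves the error, but the
# power-of-two rescaling in Q_E removes that factor.
print(shift_normalized(error) == shift_normalized([v / 2 for v in error]))  # True
```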

Experiments

For all experiments, the default layer bitwidth configuration is 2-8-8-8 for the weight, activation, gradient, and error bitwidths, respectively. The weight bitwidth is set to 2 because that yields ternary weights and therefore no multiplications during inference. The authors argue that the activation and error bitwidths should be the same because their computation graphs are similar and might run on the same hardware. During training, the weights are stored with 8 bits; for inference, they are ternarized.
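With $k_W = 2$, $Q(\cdot, 2)$ produces only $\{-0.5, 0, 0.5\}$, which are the integer codes $\{-1, 0, +1\}$ times $\sigma(2) = 0.5$. A dot product with such weights therefore needs only additions and subtractions, and the common factor $\sigma(2)$ can be applied once as a shift. The Python sketch below illustrates the idea with hypothetical 8-bit activation codes. It is not the authors' kernel.

```python
def ternary_dot(activations: list[int], weights: list[int]) -> int:
    """Dot product of integer activations with ternary weight codes in {-1, 0, +1},
    using only additions and subtractions (no multiplier)."""
    acc = 0
    for a, w in zip(activations, weights):
        if w == 1:
            acc += a
        elif w == -1:
            acc -= a
        elif w != 0:
            raise ValueError("weights must be ternary codes: -1, 0 or +1")
    return acc

acts = [17, -3, 100, 42]      # hypothetical 8-bit activation codes
wts = [1, 0, -1, 1]           # ternary weight codes (times sigma(2) = 0.5 in real units)
print(ternary_dot(acts, wts))  # -41
```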

Implementation Details

MNIST: LeNet-5 variant <ref>Template:Cite journal</ref>

SVHN & CIFAR10: VGG variant <ref>Template:Cite journal</ref>

ImageNet: AlexNet variant <ref>Template:Cite book</ref>

Test or validation error rates (%) of previous works and WAGE on multiple datasets. Opt denotes the gradient descent optimizer, withM means SGD with momentum, BN denotes batch normalization, 32 bits refers to float32, and ImageNet errors are given as top-1/top-5.
Method | $k_W$ | $k_A$ | $k_G$ | $k_E$ | Opt | BN | MNIST | SVHN | CIFAR10 | ImageNet
BC | 1 | 32 | 32 | 32 | Adam | yes | 1.29 | 2.30 | 9.90 | -
BNN | 1 | 1 | 32 | 32 | Adam | yes | 0.96 | 2.53 | 10.15 | -
BWN | 1 | 32 | 32 | 32 | withM | yes | - | - | - | 43.2/20.6
XNOR | 1 | 1 | 32 | 32 | Adam | yes | - | - | - | 55.8/30.8
TWN | 2 | 32 | 32 | 32 | withM | yes | 0.65 | - | 7.44 | 34.7/13.8
TTQ | 2 | 32 | 32 | 32 | Adam | yes | - | - | 6.44 | 42.5/20.3
DoReFa | 8 | 8 | 32 | 8 | Adam | yes | - | 2.30 | - | 47.0/-
WAGE | 2 | 8 | 8 | 8 | SGD | no | 0.40 | 1.92 | 6.78 | 51.6/27.8

Training Curves and Regularization

The authors compare the 2-8-8-8 WAGE configuration introduced above, a 2-8-f-f (meaning float32) configuration, and a completely floating point version on CIFAR10. The test error is plotted against epoch. For training these networks, the learning rate is divided by 8 at the 200th epoch and again at the 250th epoch.

Figure: Training curves of WAGE variations and a vanilla CNN on CIFAR10

The 2-8-8-8 configuration converges comparably to the vanilla CNN and outperforms the 2-8-f-f variant. The authors speculate that this is because the extra discretization acts as a regularizer.

Bitwidth of Errors

The CIFAR10 test accuracy is plotted against the error bitwidth $k_E$ below.

Figure: CIFAR10 test accuracies over 10 runs for different $k_E$

Test error rates (%) on CIFAR10 with different $k_G$
$k_G$ | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12
Error (%) | 54.22 | 51.57 | 28.22 | 18.01 | 11.48 | 7.61 | 6.78 | 6.63 | 6.43 | 6.55 | 6.57

The authors also examined the effect of bitwidth on the ImageNet implementation.

Top-5 error rates (%) on ImageNet with different $k_G$ and $k_E$
Pattern | vanilla | 28ff-BN | 28ff | 28f8 | 28C8 | 288C | 2888
Top-5 error (%) | 19.29 | 20.67 | 24.14 | 23.92 | 26.88 | 28.06 | 27.82

Here, C denotes 12 bits (C is 12 in hexadecimal), f denotes float32, and BN indicates that batch normalization is added.
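The pattern strings list $k_W$, $k_A$, $k_G$ and $k_E$ in that order, one character each, matching the 2-8-8-8 convention used earlier. The small Python helper below decodes them. It is hypothetical and treats `f` as float32 (32 bits) and every other character as a single hexadecimal digit.

```python
FLOAT_BITS = 32  # 'f' in a pattern stands for float32

def parse_pattern(pattern: str) -> dict:
    """Decode a pattern such as '2888', '28C8' or '28ff-BN' (hypothetical helper).

    Characters give k_W, k_A, k_G, k_E in order; 'f' means float32 and any
    other character is read as one hexadecimal digit ('C' = 12).
    """
    code, _, suffix = pattern.partition("-")
    if len(code) != 4:
        raise ValueError(f"expected 4 bitwidth characters, got {code!r}")
    bits = [FLOAT_BITS if ch == "f" else int(ch, 16) for ch in code]
    return {
        "k_W": bits[0],
        "k_A": bits[1],
        "k_G": bits[2],
        "k_E": bits[3],
        "batch_norm": suffix == "BN",
    }

for p in ("28ff-BN", "28ff", "28f8", "28C8", "288C", "2888"):
    print(p, parse_pattern(p))
```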

Discussion

The authors have a few areas they believe this approach could be improved.

MAC Operation: The 2-8-8-8 configuration was chosen because the low weight bitwidth means there are no multiplications during inference. However, multiplications are still required during training, because computing the weight gradients multiplies errors by activations. A 2-2-8-8 configuration would remove them as well, but it is difficult to train and hurts accuracy.

Non-linear Quantization: The linear mapping used in this approach is simple, but there might be a more effective mapping. For example, a logarithmic mapping could be more effective if the weights and activations have a log-normal distribution.
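For comparison with the linear grid, one common form of logarithmic quantization rounds $|x|$ to the nearest power of two and keeps the sign. The Python sketch below is a generic illustration, not a method from the paper. The exponent range $[-8, 0]$ is an arbitrary example. Its levels are dense near zero and sparse for large magnitudes, and multiplying by a quantized value is a bit shift.

```python
import math

def log_quantize(x: float, min_exp: int, max_exp: int) -> float:
    """Map x to sign(x) * 2^e, where e = round(log2|x|) is clamped to
    [min_exp, max_exp]. Zero maps to zero. Generic sketch, not the paper's method."""
    if x == 0.0:
        return 0.0
    e = min(max(round(math.log2(abs(x))), min_exp), max_exp)
    return math.copysign(2.0 ** e, x)

print([log_quantize(v, -8, 0) for v in (0.3, 3.0, -0.01, 0.0)])  # [0.25, 1.0, -0.0078125, 0.0]
```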

Normalization: Normalization layers (softmax, batch normalization) were not used in the WAGE networks. Quantized versions of these layers are an area of future work.

Conclusion

A framework for training and inference without the use of floating-point representation is presented. Future work may further improve compression and memory requirements.