Training And Inference with Integers in Deep Neural Networks
Introduction
Deep neural networks have enjoyed much success in all manner of tasks, but these networks are often complicated, requiring large amounts of energy-intensive memory and floating-point operations. Therefore, to use state-of-the-art networks in applications where energy is limited or hardware packaging is constrained, such as anything not connected to the power grid, the energy costs must be reduced while preserving as much performance as practical.
Most existing methods focus on reducing the energy requirements during inference rather than training, typically by compressing a model for inference. Since training with SGD requires accumulation, training usually demands higher precision than inference. This paper proposes a framework to reduce complexity during both training and inference through the use of integers instead of floats. The authors address how to quantize all operations and operands, and examine the bitwidth requirements for SGD computation and accumulation. Using integers instead of floats saves energy because integer operations are more efficient than floating-point operations (see the table below). Also, dedicated deep learning hardware that uses integer operations already exists (such as the first generation of the Google TPU), so understanding the best way to use integers is well motivated.
Operation | Energy (pJ): MUL | Energy (pJ): ADD | Area ([math]\displaystyle{ \mu m^2 }[/math]): MUL | Area ([math]\displaystyle{ \mu m^2 }[/math]): ADD |
---|---|---|---|---|
8-bit INT | 0.2 | 0.03 | 282 | 36 |
16-bit FP | 1.1 | 0.4 | 1640 | 1360 |
32-bit FP | 3.7 | 0.9 | 7700 | 4184 |
The authors call the framework WAGE because they consider how best to handle the Weights, Activations, Gradients, and Errors separately.
Related Work
Weight and Activation
Existing works on binary weights and activations <ref>Template:Cite journal</ref> still use high-precision accumulation for SGD. Ternary weight networks<ref>Template:Cite journal</ref> offer more expression than binary weight networks.
Gradient Computation and Accumulation
Some methods quantize gradients in the backwards pass, but the weights are still stored in float32, and batch normalization is ignored.
WAGE Quantization
The core idea of the proposed method is to constrain the following to low-bitwidth integers on each layer:
- W: weight in inference
- a: activation in inference
- e: error in backpropagation
- g: gradient in backpropagation
The error and gradient are defined as:
[math]\displaystyle{ e^i = \frac{\partial L}{\partial a^i}, g^i = \frac{\partial L}{\partial W^i} }[/math]
where L is the loss function.
The precisions in bits of the errors, activations, gradients, and weights are [math]\displaystyle{ k_E }[/math], [math]\displaystyle{ k_A }[/math], [math]\displaystyle{ k_G }[/math], and [math]\displaystyle{ k_W }[/math] respectively. As shown in the above figure, each quantity also has a quantization operator to reduce the bitwidth increase caused by multiply-accumulate (MAC) operations. Also, note that since this is a layer-by-layer approach, each layer may be followed or preceded by a layer with a different precision, or even by a layer using floating-point math.
Shift-Based Linear Mapping and Stochastic Mapping
The proposed method makes use of a linear mapping where continuous, unbounded values are discretized for each bitwidth [math]\displaystyle{ k }[/math] with a uniform spacing of
[math]\displaystyle{ \sigma(k) = 2^{1-k}, k \in \mathbb{N}_+ }[/math]
With this, the full quantization function is
[math]\displaystyle{ Q(x,k) = Clip\left \{ \sigma(k) \cdot round\left [ \frac{x}{\sigma(k)} \right ], -1 + \sigma(k), 1 - \sigma(k) \right \} }[/math]
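As a minimal sketch of the two formulas above, the following NumPy code simulates the quantizer on floating-point hardware. The names `sigma` and `quantize` are illustrative and do not come from the paper, and `np.round` rounds ties to the nearest even integer, which is an assumption about the rounding mode.

```python
import numpy as np

def sigma(k: int) -> float:
    """Minimum step size for bitwidth k: 2^(1-k)."""
    return 2.0 ** (1 - k)

def quantize(x, k: int):
    """Q(x, k): round to the nearest multiple of sigma(k), then clip to
    [-1 + sigma(k), 1 - sigma(k)]."""
    s = sigma(k)
    x = np.asarray(x, dtype=np.float64)
    return np.clip(s * np.round(x / s), -1.0 + s, 1.0 - s)

x = np.array([-1.3, -0.4, 0.1, 0.3, 0.9])
print(quantize(x, 2))  # every value lands on the grid {-0.5, 0, 0.5}
print(quantize(x, 8))  # finer grid with step 2^-7
```

For k = 2 the only representable values are -0.5, 0, and 0.5, which is why a 2-bit weight is ternary.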
Note that this function is only used when simulating integer operations on floating-point hardware; on native integer hardware, this is done automatically.
In addition to this quantization function, a distribution scaling factor is used in some quantization operators to preserve as much variance as possible when applying the quantization function above. The scaling factor is defined below.
[math]\displaystyle{ Shift(x) = 2^{round(log_2(x))} }[/math]
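A minimal sketch of this scaling factor, assuming a strictly positive input. Because the result is always a power of two, scaling by it corresponds to a bit shift on integer hardware. The name `shift` is illustrative.

```python
import numpy as np

def shift(x):
    """Shift(x) = 2^round(log2(x)): the power of two closest to x on a log scale.
    Assumes x > 0."""
    return 2.0 ** np.round(np.log2(x))

print(shift(0.3))  # log2(0.3) is about -1.74, which rounds to -2, giving 0.25
print(shift(5.0))  # log2(5.0) is about 2.32, which rounds to 2, giving 4.0
```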
Finally, stochastic rounding is substituted for small or real-valued updates during gradient accumulation.
A visual representation of these operations is below.
Weight Initialization
In this work, batch normalization is simplified to a constant scaling layer in order to sidestep the problem of normalizing outputs without floating-point math, and to remove the extra memory requirement of batch normalization. As such, some care must be taken when initializing weights. The authors use a modified initialization method based on MSRA <ref>Template:Cite journal</ref>.
[math]\displaystyle{ W \thicksim U(-L, +L),L = max \left \{ \sqrt{6/n_{in}}, L_{min} \right \}, L_{min} = \beta \sigma }[/math]
[math]\displaystyle{ n_{in} }[/math] is the layer fan-in number and [math]\displaystyle{ U }[/math] denotes the uniform distribution. The original MSRA initialization method is modified by adding the condition that the distribution width should be at least [math]\displaystyle{ \beta \sigma }[/math], where [math]\displaystyle{ \beta }[/math] is a constant greater than 1 and [math]\displaystyle{ \sigma }[/math] is the minimum step size defined above. This prevents the weights from all being initialized to zero when the bitwidth is low or the fan-in number is high.
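The effect of the lower bound can be sketched as follows. This is an illustration under stated assumptions, not the authors' code: the step size is taken at the weight bitwidth, `beta=1.5` is an arbitrary value satisfying the requirement that it be greater than 1, and the helper names are hypothetical.

```python
import numpy as np

def quantize(x, k):
    """Q(x, k) from the text, simulated in floating point."""
    s = 2.0 ** (1 - k)
    return np.clip(s * np.round(np.asarray(x, dtype=np.float64) / s), -1 + s, 1 - s)

def init_weights(shape, n_in, k_w, beta=1.5, rng=None):
    """Draw W ~ U(-L, L) with L = max(sqrt(6 / n_in), beta * sigma(k_w))."""
    rng = np.random.default_rng() if rng is None else rng
    l_min = beta * 2.0 ** (1 - k_w)              # assumed: sigma at the weight bitwidth
    limit = max(float(np.sqrt(6.0 / n_in)), l_min)
    return rng.uniform(-limit, limit, size=shape), limit

rng = np.random.default_rng(0)
n_in, k_w = 1024, 2

# Unmodified limit: sqrt(6/1024) is about 0.077, far below half a step (0.25),
# so every weight rounds to zero.
naive = rng.uniform(-np.sqrt(6.0 / n_in), np.sqrt(6.0 / n_in), size=(64, n_in))
print(np.count_nonzero(quantize(naive, k_w)))

# Modified limit: L = 0.75, so a sizeable share of weights survive quantization.
W, L = init_weights((64, n_in), n_in, k_w, rng=rng)
print(L, np.count_nonzero(quantize(W, k_w)))
```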
Quantization Details
Weight [math]\displaystyle{ Q_W(\cdot) }[/math]
[math]\displaystyle{ W_q = Q_W(W) = Q(W, k_W) }[/math]
The quantization operator is simply the quantization function previously introduced.
Activation [math]\displaystyle{ Q_A(\cdot) }[/math]
The authors say that the variance of the weights passed through this function will be scaled compared to the variance of the weights as initialized. To prevent this effect from blowing up the network outputs, they introduce a scaling factor [math]\displaystyle{ \alpha }[/math]. Notice that it is constant for each layer.
[math]\displaystyle{ \alpha = max \left \{ Shift(L_{min} / L), 1 \right \} }[/math]
The quantization operator is then
[math]\displaystyle{ a_q = Q_A(a) = Q(a/\alpha, k_A) }[/math]
The scaling factor approximates batch normalization.
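A sketch of the activation operator follows. Read literally with [math]\displaystyle{ L }[/math] from the initialization formula, [math]\displaystyle{ L_{min}/L }[/math] can never exceed 1 and [math]\displaystyle{ \alpha }[/math] would always equal 1, so this sketch assumes that [math]\displaystyle{ L }[/math] here means the unclamped limit [math]\displaystyle{ \sqrt{6/n_{in}} }[/math]. That reading, the value `beta=1.5`, and all function names are assumptions.

```python
import numpy as np

def shift(x):
    return 2.0 ** np.round(np.log2(x))

def quantize(x, k):
    s = 2.0 ** (1 - k)
    return np.clip(s * np.round(np.asarray(x, dtype=np.float64) / s), -1 + s, 1 - s)

def activation_scale(n_in, k_w, beta=1.5):
    """alpha = max(Shift(L_min / L), 1), a per-layer constant fixed by the architecture."""
    l_min = beta * 2.0 ** (1 - k_w)
    l_float = np.sqrt(6.0 / n_in)      # assumption: the unclamped initialization limit
    return max(float(shift(l_min / l_float)), 1.0)

def quantize_activation(a, alpha, k_a):
    """Q_A(a) = Q(a / alpha, k_A)."""
    return quantize(np.asarray(a, dtype=np.float64) / alpha, k_a)

alpha = activation_scale(n_in=1024, k_w=2)   # ratio of about 9.8 becomes the power of two 8
a_q = quantize_activation([4.0, -2.0, 0.1], alpha, k_a=8)
print(alpha, a_q)
```

Since `alpha` is a power of two, the division is a fixed right shift for each layer.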
Error [math]\displaystyle{ Q_E(\cdot) }[/math]
The authors note that the magnitude of the error can vary greatly, and that a previous approach (DoReFa-Net<ref>Template:Cite journal</ref>) solves the issue by using an affine transform to map the error to the range [math]\displaystyle{ [-1, 1] }[/math], applying quantization, and then applying the inverse transform. However, the authors claim that this approach still requires float32, and that the magnitude of the error is unimportant: what matters is the orientation of the error. Thus, they only scale the error distribution to the range [math]\displaystyle{ \left [ -\sqrt2, \sqrt2 \right ] }[/math] and quantize:
[math]\displaystyle{ e_q = Q_E(e) = Q(e/Shift(max\{|e|\}), k_E) }[/math]
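Here [math]\displaystyle{ max\{|e|\} }[/math] is the largest absolute value among all elements of the layer's error, so a single scale is shared by the whole layer. Note that this discards any error elements smaller than the minimum step size.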
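The error operator can be sketched as below, assuming a non-zero error tensor (an all-zero tensor would make the logarithm undefined). Because Shift rounds in the log domain, the largest scaled magnitude falls between [math]\displaystyle{ 1/\sqrt2 }[/math] and [math]\displaystyle{ \sqrt2 }[/math], and anything above [math]\displaystyle{ 1 - \sigma(k_E) }[/math] is then clipped. Function names are illustrative.

```python
import numpy as np

def shift(x):
    return 2.0 ** np.round(np.log2(x))

def quantize(x, k):
    s = 2.0 ** (1 - k)
    return np.clip(s * np.round(np.asarray(x, dtype=np.float64) / s), -1 + s, 1 - s)

def quantize_error(e, k_e):
    """Q_E(e) = Q(e / Shift(max|e|), k_E); assumes e has a non-zero element."""
    e = np.asarray(e, dtype=np.float64)
    return quantize(e / shift(np.max(np.abs(e))), k_e)

e = np.array([3e-4, -1.2e-3, 2e-6])
e_q = quantize_error(e, k_e=8)
# The scale is 2^-10. The largest element is clipped to -(1 - 2^-7), and the
# smallest falls below half a step and becomes exactly zero.
print(e_q)
```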
Gradient [math]\displaystyle{ Q_G(\cdot) }[/math]
Similar to the activations and errors, the gradients are rescaled:
[math]\displaystyle{ g_s = \eta \cdot g/Shift(max\{|g|\}) }[/math]
[math]\displaystyle{ \eta }[/math] is a shift-based learning rate, constrained to an integer power of 2. The shifted gradients are represented in units of the minimum step size [math]\displaystyle{ \sigma(k) }[/math]. When reducing the bitwidth of the gradients (recall that the gradients come out of a MAC operation, so their bitwidth may have increased), stochastic rounding is used as a substitute for small gradient accumulation.
[math]\displaystyle{ \Delta W = Q_G(g) = \sigma(k_G) \cdot sgn(g_s) \cdot \left \{ \lfloor | g_s | \rfloor + Bernoulli(|g_s| - \lfloor | g_s | \rfloor) \right \} }[/math]
This randomly rounds the result of the MAC operation up or down to the nearest quantization for the given gradient bitwidth. The weights are updated with the resulting discrete increments:
[math]\displaystyle{ W_{t+1} = Clip \left \{ W_t - \Delta W_t, -1 + \sigma(k_G), 1 - \sigma(k_G) \right \} }[/math]
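The gradient operator and weight update above can be sketched as follows. This is a floating-point simulation with illustrative names; the learning rate `eta=1.0` (that is, 2^0) and the sample values are arbitrary, and the gradient is assumed to be non-zero.

```python
import numpy as np

def sigma(k):
    return 2.0 ** (1 - k)

def shift(x):
    return 2.0 ** np.round(np.log2(x))

def quantize_gradient(g, k_g, eta, rng):
    """Q_G(g): rescale, then stochastically round |g_s| to an integer number of steps."""
    g = np.asarray(g, dtype=np.float64)
    g_s = eta * g / shift(np.max(np.abs(g)))      # in units of sigma(k_g)
    mag = np.abs(g_s)
    whole = np.floor(mag)
    round_up = rng.random(g_s.shape) < (mag - whole)   # Bernoulli(fractional part)
    return sigma(k_g) * np.sign(g_s) * (whole + round_up)

def update_weights(w, delta_w, k_g):
    """W_{t+1} = Clip(W_t - delta_W, -1 + sigma(k_G), 1 - sigma(k_G))."""
    s = sigma(k_g)
    return np.clip(w - delta_w, -1 + s, 1 - s)

rng = np.random.default_rng(0)
g = np.array([0.3, -1.0, 0.05])
delta_w = quantize_gradient(g, k_g=8, eta=1.0, rng=rng)
w_next = update_weights(np.array([0.5, -0.5, 0.0]), delta_w, k_g=8)
print(delta_w / sigma(8))   # whole numbers of steps; the middle entry is always -1
print(w_next)
```

Averaged over many draws, each update equals [math]\displaystyle{ \sigma(k_G) \cdot g_s }[/math], so small gradients still move the weights in expectation even though any single update is a whole number of steps.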
Miscellaneous
To train WAGE networks, the authors used pure SGD exclusively because more complicated techniques such as Momentum or RMSProp increase memory consumption and are complicated by the rescaling that happens within each quantization operator.
The quantization and stochastic rounding are a form of regularization.
The authors didn't use a traditional softmax with cross-entropy loss for the experiments because there does not yet exist a softmax layer for low-bit integers. Instead, they use a sum of squared error loss. This works for tasks with a small number of categories, but does not scale well.
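To illustrate why a sum of squared error loss is easier to handle than softmax with cross-entropy, here is a small sketch. One-hot targets and the function name are assumptions, since the text does not specify the target encoding.

```python
import numpy as np

def sse_loss_and_error(a, label):
    """Sum of squared error against a one-hot target (assumed encoding).
    Returns the loss and the output-layer error e = dL/da = 2 * (a - y)."""
    a = np.asarray(a, dtype=np.float64)
    y = np.zeros_like(a)
    y[label] = 1.0
    diff = a - y
    return float(np.sum(diff ** 2)), 2.0 * diff

loss, e = sse_loss_and_error([0.5, 0.0, -0.25], label=0)
print(loss, e)
```

The error needs only a subtraction and a doubling, with none of the exponentials or divisions that a softmax would require.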
Experiments
For all experiments, the default layer bitwidth configuration is 2-8-8-8 for Weights, Activations, Gradients, and Error bits. The weight bitwidth is set to 2 because that results in ternary weights, and therefore no multiplication during inference. The authors argue that the bitwidths for activations and errors should be the same because the computation graph for each is similar and might use the same hardware. During training, the weight bitwidth is 8. For inference, the weights are ternarized.
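The claim that ternary weights need no multiplication can be illustrated with a small sketch, in which a matrix-vector product is computed using only additions and subtractions. The function name and sample values are hypothetical.

```python
import numpy as np

def ternary_matvec(w_sign, x):
    """Compute w_sign @ x for w_sign with entries in {-1, 0, +1} using only
    additions and subtractions of the selected inputs."""
    out = np.zeros(w_sign.shape[0], dtype=np.int64)
    for i, row in enumerate(w_sign):
        out[i] = x[row == 1].sum() - x[row == -1].sum()
    return out

w_sign = np.array([[1, 0, -1, 1],
                   [0, -1, -1, 0]])
x = np.array([3, 5, 2, 7])        # integer activations
y = ternary_matvec(w_sign, x)     # matches w_sign @ x
print(y)
```

With a weight bitwidth of 2 the quantized weights are -0.5, 0, and 0.5, that is, 0.5 times a sign. The common factor of 0.5 is a one-bit shift rather than a multiplication.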
Implementation Details
MNIST: Network is LeNet-5 variant <ref>Template:Cite journal</ref>
SVHN & CIFAR10: VGG variant <ref>Template:Cite journal</ref>
ImageNet: AlexNet variant <ref>Template:Cite book</ref>
Method | [math]\displaystyle{ k_W }[/math] | [math]\displaystyle{ k_A }[/math] | [math]\displaystyle{ k_G }[/math] | [math]\displaystyle{ k_E }[/math] | Opt | BN | MNIST | SVHN | CIFAR10 | ImageNet |
---|---|---|---|---|---|---|---|---|---|---|
BC | 1 | 32 | 32 | 32 | Adam | yes | 1.29 | 2.30 | 9.90 | |
BNN | 1 | 1 | 32 | 32 | Adam | yes | 0.96 | 2.53 | 10.15 | |
BWN | 1 | 32 | 32 | 32 | withM | yes | | | | 43.2/20.6 |
XNOR | 1 | 1 | 32 | 32 | Adam | yes | | | | 55.8/30.8 |
TWN | 2 | 32 | 32 | 32 | withM | yes | 0.65 | | 7.44 | 34.7/13.8 |
TTQ | 2 | 32 | 32 | 32 | Adam | yes | | | 6.44 | 42.5/20.3 |
DoReFa | 8 | 8 | 32 | 8 | Adam | yes | | 2.30 | | 47.0/ |
TernGrad | 32 | 32 | 2 | 32 | Adam | yes | | | 14.36 | 42.4/19.5 |
WAGE | 2 | 8 | 8 | 8 | SGD | no | 0.40 | 1.92 | 6.78 | 51.6/27.8 |
Training Curves and Regularization
The authors compare the 2-8-8-8 WAGE configuration introduced above, a 2-8-f-f (meaning float32) configuration, and a completely floating point version on CIFAR10. The test error is plotted against epoch. For training these networks, the learning rate is divided by 8 at the 200th epoch and again at the 250th epoch.
The 2-8-8-8 configuration converges comparably to the vanilla CNN and outperforms the 2-8-f-f variant. The authors speculate that this is because the extra discretization acts as a regularizer.
Bitwidth of Errors
The CIFAR10 test accuracy is plotted against bitwidth below.
Bitwidth of Gradients
[math]\displaystyle{ k_G }[/math] | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 |
---|---|---|---|---|---|---|---|---|---|---|---|
error | 54.22 | 51.57 | 28.22 | 18.01 | 11.48 | 7.61 | 6.78 | 6.63 | 6.43 | 6.55 | 6.57 |
The authors also examined the effect of bitwidth on the ImageNet implementation.
Pattern | vanilla | 28ff-BN | 28ff | 28f8 | 28C8 | 288C | 2888 |
---|---|---|---|---|---|---|---|
error | 19.29 | 20.67 | 24.14 | 23.92 | 26.88 | 28.06 | 27.82 |
Here, C denotes 12 bits (the hexadecimal digit C), and BN refers to batch normalization being added.
Discussion
The authors have a few areas they believe this approach could be improved.
MAC Operation: The 2-8-8-8 configuration was chosen because the low weight bitwidth means there is no multiplication during inference. However, this does not remove the need for multiplication during training. A 2-2-8-8 configuration avoids multiplication during training as well, but it is difficult to train and detrimental to accuracy.
Non-linear Quantization: The linear mapping used in this approach is simple, but there might be a more effective mapping. For example, a logarithmic mapping could be more effective if the weights and activations have a log-normal distribution.
Normalization: Normalization layers (softmax, batch normalization) were not used in this paper. Quantized versions are an area of future work.
Conclusion
A framework for training and inference without the use of floating-point representation is presented. Future work may further improve compression and memory requirements.