Integers: The Theory
Part 2 of a series on native integer machine learning/deep learning
Debug notes
After the previous post I jumped right into debugging.
The increasing runtime, in combination with the drastic difference between lr_shift and momentum_shift in the Iris example,
gave me reason to believe that the code has a bug with regard to scaling (shifting values).
I am making this assumption for two reasons.
First, both versions use different arithmetic operations. The f32 version uses raw arithmetic operators, e.g. + and *,
whereas the i32 version uses saturating/wrapping arithmetic, i.e. values are either clamped or wrapped on overflow.
Second, the i32 version implements gradient clipping.
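The difference between the two arithmetic styles can be illustrated with Rust's built-in integer methods (a generic illustration, not the project's actual code):

```rust
fn main() {
    let a: i32 = i32::MAX;

    // Saturating arithmetic clamps the result at the type's bounds.
    assert_eq!(a.saturating_add(1), i32::MAX);
    assert_eq!(i32::MIN.saturating_sub(1), i32::MIN);

    // Wrapping arithmetic wraps around modulo 2^32 instead.
    assert_eq!(a.wrapping_add(1), i32::MIN);

    // A raw `a + 1` would panic in debug builds (overflow check) and
    // wrap silently in release builds, which is why the i32 version
    // spells the intended behaviour out explicitly.
    println!("saturating and wrapping behave as expected");
}
```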
To get a clearer picture of the current state I added debugging statements to the Iris test, checking runtime, overflow stats and the outgoing shift of the predictions. The results were rather sobering: no wrapping, and a compute time per epoch between 2700 and 3000µs (microseconds). At that resolution the variance could be caused by a stray gamma ray from outer space, so it is not really a great measurement for tracking down bugs.
Looking at the shifts of the network revealed something different.
For every training sample we see the following in the logs:
| Epoch # | Loss | Shift In | Forward Shift | Backward Shift |
|---|---|---|---|---|
| Epoch 1 | XXXXXX | 5 | 16 | 16 |
| Epoch 2 | XXXXXX | 5 | 16 | 16 |
| Epoch 3 | XXXXXX | 5 | 16 | 16 |
| ... | | | | |
This doesn't look correct. Especially the backward shift looks completely wrong.
The layers are configured the following way
| Layer | Quant Shift | Incoming Shift | Outgoing Shift |
|---|---|---|---|
| L1 | 0 | 3 | 4 |
| L2 | 0 | 4 | 4 |
| L3 | 0 | 4 | 2 |
Our input is created with the QuantizationMethod::StandardScore, so it carries a shift of 5.
The calculation for the final shift of the forward pass accumulates the input shift plus each layer's incoming (weight) shift:

$$s_{\text{fwd}} = s_{\text{in}} + \sum_{l} s^{(l)}_{\text{incoming}}$$

In our example: $5 + 3 + 4 + 4 = 16$.

This shows that our forward pass is already computing the correct shift. The backward pass should reduce this shift again, but it doesn't. Hence the issue is inside the backward pass.
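As a sanity check, the forward shift bookkeeping can be sketched as a plain sum. The function name and the accumulation rule here are my reconstruction; they merely reproduce the value seen in the logs:

```rust
// Hypothetical bookkeeping: the forward shift accumulates the input
// shift plus each layer's incoming (weight) shift; quant shifts are 0.
fn forward_shift(input_shift: u32, incoming_shifts: &[u32]) -> u32 {
    input_shift + incoming_shifts.iter().sum::<u32>()
}

fn main() {
    // Input shift 5; incoming shifts 3, 4, 4 from the layer table.
    assert_eq!(forward_shift(5, &[3, 4, 4]), 16); // matches the logged value
}
```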
A fresh set of eyes was one of the main drivers behind finding the solution. After the adjustment we ran the experiment again.
Architecture: 4 -> 8 -> 8 -> 3 with ReLU activations and MSE loss.
We now reach higher accuracy with i32 than with its f32 counterpart. The i32 branch converges within 239 epochs, reaching an accuracy of 95%, whereas f32 never reaches this level.
| | f32 | i32 |
|---|---|---|
| Test accuracy | 90% | 95% |
| Epochs | 500 | 239 (early stopped) |
| lr_shift | 7 | 7 |
| momentum_shift | 1 | 1 |
| Batch size | 32 | 32 |
| Grad. Clip value | - | |
The reason for this "success" was twofold. First, fixing the internal shift bookkeeping: Linear<S>'s backward pass was using Linear::output_shift to transform the gradient w.r.t. the inputs, which resulted in wrong shifts for subsequent layers. Second, Linear<S>'s backward pass did not return a reduced s_g (gradient shift), but simply forwarded the incoming value.
Apart from the bookkeeping, drastically increasing the gradient clip value, so that more information can flow through the network, helped to get more stable results.
The diff for the changes inside Linear<S>::backward:

```diff
- let local_delta_w = s_g + s_x;
+ let local_delta_w = s_g - s_x;
- let local_delta_x = self.weights.quant_shift + self.output_shift;
+ let local_delta_x = self.weights.quant_shift + self.input_shift;
...
- let s_g_prev = s_g;//.saturating_sub(local_delta_x);
+ let s_g_prev = s_g.saturating_sub(local_delta_x);
```
The update to s_g in particular was something I had implemented earlier, but it got lost and I didn't rediscover it until today.
Impact on MNIST
The change clearly fixed a bug, otherwise our performance on Iris would not look like this. Quickly running a few epochs shows that at least the increasing compute-time issue from the previous post is resolved, and the runtime is now stable:
Epoch Loss Accuracy Time(s) Clamps FW Shift BW Shift
────────────────────────────────────────────────────────────────────────────────
0 1612.0000 9.4% 9.12s -1 31 5
1 1612.0000 9.4% 9.09s -1 31 5
2 1612.0000 9.4% 9.08s -1 31 5
3 1612.0000 9.4% 9.09s -1 31 5
4 1612.0000 9.4% 9.08s -1 31 5
5 1612.0000 9.4% 9.08s -1 31 5
6 1612.0000 9.4% 9.07s -1 31 5
7 1612.0000 9.4% 9.10s -1 31 5
8 1612.0000 9.4% 9.08s -1 31 5
9 1612.0000 9.4% 9.09s -1 31 5
Tip
Obviously, the network in these logs did not learn anything; instead, pay attention to the now more or less stable runtime.
The core math
Right now, we're adjusting things and seemingly fixing issues that we don't fully understand. Instead of running around like a headless chicken, we will now look closer at the theory behind the implementation of native integer training. Once we have defined all the required theory, we will implement it and hopefully resolve all the issues we have faced so far, for good.
Motivation
The fundamental problem with integer arithmetic is that multiplication accumulates scale. When you multiply two integers, the result lives in a different "unit" than either operand, and without an explicit mechanism to track and correct for that, values either explode or collapse across the layers of an artificial neural network. The structure we want is one where the scale is part of the value itself, so that every operation is well-defined in terms of what the numbers actually mean, not just what bits they hold.
The carrier set
Let $\mathbb{D} = \{\, (m, s) \mid m \in \mathbb{Z},\ s \in \mathbb{Z} \,\}$.
An element $(m, s) \in \mathbb{D}$ represents the rational value $m \cdot 2^{-s}$: $m$ is the integer mantissa and $s$ the scale exponent (the shift).
So a larger $s$ means a finer resolution, since consecutive mantissas are only $2^{-s}$ apart.
Note
The formal definition admits $s \in \mathbb{Z}$, allowing negative scale exponents, which would represent values scaled up by a power of two rather than down. In the implementation, shift values are typed as u32, restricting $s \ge 0$. This means the system can only represent dyadic rationals of magnitude at most the mantissa bound, never larger. The formal generality is retained here because the mathematical properties of the operations hold regardless, but any concrete instantiation should be understood as working in $\mathbb{D}_{\ge 0} = \{\, (m, s) \in \mathbb{D} \mid s \ge 0 \,\}$.
Non-uniqueness
It is critical to note that the representation is not unique: $(m, s)$ and $(2m, s+1)$ denote the same rational value.
For instance, $(3, 1)$ and $(6, 2)$ both represent $1.5$.
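A minimal sketch of the carrier set in Rust; the type and field names are illustrative, not the crate's actual API:

```rust
/// The pair (m, s) represents the dyadic rational m * 2^(-s).
#[derive(Debug, Clone, Copy)]
struct Dyadic {
    m: i32, // mantissa
    s: u32, // scale exponent (shift)
}

impl Dyadic {
    /// The rational value this element represents, as f64 for inspection.
    fn value(&self) -> f64 {
        self.m as f64 / (1u64 << self.s) as f64
    }
}

fn main() {
    // Non-uniqueness: (3, 1) and (6, 2) encode the same rational 1.5.
    let a = Dyadic { m: 3, s: 1 };
    let b = Dyadic { m: 6, s: 2 };
    assert_eq!(a.value(), 1.5);
    assert_eq!(a.value(), b.value());
}
```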
Scale alignment
Before we can define addition and subtraction, we need a way to bring two elements into a common representation.
Given $(m_1, s_1)$ and $(m_2, s_2)$ with, say, $s_1 > s_2$, we align both to the coarser scale $s_2$:

$$\mathrm{align}\big((m_1, s_1), (m_2, s_2)\big) = \big((m_1 \gg (s_1 - s_2),\ s_2),\ (m_2, s_2)\big)$$

The element with the finer scale has its mantissa right-shifted (divided by $2^{s_1 - s_2}$), so that afterwards both mantissas live at the same scale.
You can only add mantissas that share a scale, just as you can only add physical quantities expressed in the same unit.
Alignment is lossy in general, i.e. the bits shifted out are gone for good. This precision loss during alignment is the fundamental source of rounding error in this system, the equivalent of floating-point roundoff.
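A sketch of the alignment step, assuming we always align to the coarser scale by right-shifting the finer mantissa (function and variable names are mine):

```rust
/// Bring two (mantissa, shift) pairs to a common (coarser) scale.
/// The finer element loses its low bits: alignment is lossy.
fn align(m1: i32, s1: u32, m2: i32, s2: u32) -> (i32, i32, u32) {
    if s1 > s2 {
        (m1 >> (s1 - s2), m2, s2)
    } else {
        (m1, m2 >> (s2 - s1), s1)
    }
}

fn main() {
    // (13, 3) = 1.625 aligned with (5, 1) = 2.5.
    let (a, b, s) = align(13, 3, 5, 1);
    assert_eq!((a, b, s), (3, 5, 1)); // 13 >> 2 = 3: the 0.125 is gone for good
    // Addition is now well-defined on mantissas: (8, 1) = 4.0.
    assert_eq!(a + b, 8);
}
```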
Stochastic scaling
Stochastic scaling/requantization handles two things: rescaling (adjusting the shift to a new target) and clipping (enforcing the hardware bit-width boundaries).
Rescaling
Let $x \in \mathbb{R}$ and let $\{x\} = x - \lfloor x \rfloor$ denote its fractional part.
We define the stochastic rounding function

$$\mathrm{SR}(x) = \begin{cases} \lfloor x \rfloor & \text{with probability } 1 - \{x\} \\ \lfloor x \rfloor + 1 & \text{with probability } \{x\} \end{cases}$$

where the random draw happens independently for every rounding, so that $\mathbb{E}[\mathrm{SR}(x)] = x$.
The requantization operator from shift $s$ to target shift $s'$ is

$$Q_{s \to s'}(m, s) = \big(\mathrm{clip}(\mathrm{SR}(m \cdot 2^{s' - s}),\ q_{\min},\ q_{\max}),\ s'\big)$$

where $\mathrm{clip}(x, a, b) = \min(\max(x, a), b)$ and $q_{\min}, q_{\max}$ are the bounds of the target integer type.
Let $b$ be the bit width of the target integer type.
We can derive the values of $q_{\min}$ and $q_{\max}$ from $b$.
For signed two's complement integer targets (e.g. INT8, where $b = 8$): $q_{\min} = -2^{b-1}$, $q_{\max} = 2^{b-1} - 1$.
For unsigned integer targets (e.g. UINT8): $q_{\min} = 0$, $q_{\max} = 2^{b} - 1$.
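The bound derivation and the rescaling step can be sketched as follows. Stochastic rounding is elided here: when the shifted-out bits are zero the result is exact, which is the only case this sketch exercises. All names are illustrative:

```rust
/// Bounds for a signed two's-complement target of bit width b.
fn signed_bounds(b: u32) -> (i64, i64) {
    (-(1i64 << (b - 1)), (1i64 << (b - 1)) - 1)
}

/// Bounds for an unsigned target of bit width b.
fn unsigned_bounds(b: u32) -> (i64, i64) {
    (0, (1i64 << b) - 1)
}

/// Requantize mantissa m from shift s to shift s_new, clipping to [lo, hi].
fn requantize(m: i64, s: u32, s_new: u32, lo: i64, hi: i64) -> i64 {
    let scaled = if s_new >= s {
        m << (s_new - s) // finer target: exact
    } else {
        m >> (s - s_new) // coarser target: low bits are dropped
    };
    scaled.clamp(lo, hi)
}

fn main() {
    assert_eq!(signed_bounds(8), (-128, 127)); // INT8
    assert_eq!(unsigned_bounds(8), (0, 255));  // UINT8

    let (lo, hi) = signed_bounds(8);
    // 400 * 2^-4 rescaled to shift 2: 400 >> 2 = 100, within bounds.
    assert_eq!(requantize(400, 4, 2, lo, hi), 100);
    // 4000 * 2^-4 rescaled to shift 2: 1000, clipped to 127.
    assert_eq!(requantize(4000, 4, 2, lo, hi), 127);
}
```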
Arithmetic operations
We can now define the four operations. All four preserve the invariant that the result is again an element of $\mathbb{D}$.

Addition and subtraction require alignment first, then operate directly on mantissas:

$$(m_1, s) \pm (m_2, s) = (m_1 \pm m_2,\ s)$$

Multiplication does not require alignment because the scales compose additively. The exact product of two dyadic numbers has mantissa $m_1 m_2$ and scale $s_1 + s_2$:

$$(m_1, s_1) \cdot (m_2, s_2) = (m_1 m_2,\ s_1 + s_2)$$

Since the product mantissa needs up to twice the bit width, it is right-shifted by some $r \ge 0$ before being stored. Note that the scale of the result is adjusted by $r$ as well, giving $(m_1 m_2 \gg r,\ s_1 + s_2 - r)$, so the represented value stays (approximately) the same.

Division has the opposite problem. Integer division truncates, so a precision shift $p \ge 0$ is applied to the dividend first:

$$(m_1, s_1) \,/\, (m_2, s_2) = \big((m_1 \ll p) / m_2,\ s_1 - s_2 + p\big)$$

Again the scale is adjusted to compensate, so the represented value is preserved up to truncation error.
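The two scale-composing operations can be sketched directly on (mantissa, shift) pairs; the parameters `r` and `p` are the rescale and precision shifts described above, and the names are mine:

```rust
/// Exact product: mantissas multiply, scales add. A rescale step r
/// brings the widened mantissa back into range; the scale is reduced
/// by r to compensate.
fn mul(m1: i64, s1: u32, m2: i64, s2: u32, r: u32) -> (i64, u32) {
    ((m1 * m2) >> r, s1 + s2 - r)
}

/// Division with precision shift p: the dividend is pre-shifted left
/// so that truncating integer division keeps p extra fractional bits.
fn div(m1: i64, s1: u32, m2: i64, s2: u32, p: u32) -> (i64, u32) {
    ((m1 << p) / m2, s1 - s2 + p)
}

fn main() {
    // 1.5 * 2.5: (3, 1) * (5, 1) with no rescale -> (15, 2) = 3.75.
    assert_eq!(mul(3, 1, 5, 1, 0), (15, 2));
    // Same product rescaled by r = 1: (7, 1) = 3.5, one bit of precision lost.
    assert_eq!(mul(3, 1, 5, 1, 1), (7, 1));
    // (4, 2) / (3, 0) = 1.0 / 3.0 with p = 8 extra bits:
    // ((4 << 8) / 3, 10) = (341, 10), i.e. 341/1024 ≈ 0.333.
    assert_eq!(div(4, 2, 3, 0, 8), (341, 10));
}
```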
A note on approximation
An important consequence of the above is that every multiplication and division forces an explicit choice of how much precision to keep, i.e. which shift to requantize the result to.
The lr_shift, quant_shift and grad_shift parameters are all instances of this single decision appearing in different contexts.
This provides all necessary building blocks for the forward pass of a neural network using elements of $\mathbb{D}$ exclusively.
Gradients
But this framework introduces a new problem.
Due to the randomness of $\mathrm{SR}$ and the hard cut-off of $\mathrm{clip}$, the requantization operator $Q$ is not differentiable; its derivative is zero almost everywhere, so gradients cannot flow through it during backpropagation.
Straight-Through Estimator
To overcome this we will formalize a Straight-Through Estimator (STE). This STE essentially defines a "fake" derivative for the backward pass that ignores the discrete steps and pretends the function was continuous.
Let $q = \mathrm{clip}(\mathrm{SR}(m \cdot 2^{s' - s}),\ q_{\min},\ q_{\max})$ be the output of the requantization operator.
During backpropagation, we receive the gradient of the loss w.r.t. $q$, i.e. $\partial L / \partial q$, and need to produce the gradient w.r.t. the input.
We break $Q$ into its two components, rounding and clipping, and define an STE for each.
- The STE for $\mathrm{SR}$:

$$\frac{\partial\, \mathrm{SR}(x)}{\partial x} \approx 1$$

In continuous space, the rounding function is an approximation of the identity function (scaled by the bit-shift). The STE assumes the gradient passes straight through the rounding operation unmodified.
- The STE for $\mathrm{clip}$:

$$\frac{\partial\, \mathrm{clip}(x, q_{\min}, q_{\max})}{\partial x} \approx \mathbb{1}[q_{\min} \le x \le q_{\max}]$$

For the clipping function, the gradient passes through identically if the value was within the representable bounds during the forward pass.
If the value was clipped to $q_{\min}$ or $q_{\max}$, the output no longer reacts to small changes of the input, so the gradient is zero.
Using the indicator function $\mathbb{1}[\cdot]$, both cases collapse into the single expression above.
Note that we evaluate the bounds check using the pre-clipping value $\mathrm{SR}(m \cdot 2^{s' - s})$ recorded during the forward pass, not the clipped output.
Combining these, the gradient simply flows straight through the rounding operator but gets masked by the clipping bounds.
Let $g = \partial L / \partial q$ be the incoming gradient and $\hat{x} = \mathrm{SR}(m \cdot 2^{s' - s})$ the pre-clipping value from the forward pass.
The gradient passed backward is then

$$\frac{\partial L}{\partial m} \approx g \cdot 2^{s' - s} \cdot \mathbb{1}[q_{\min} \le \hat{x} \le q_{\max}]$$
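A minimal sketch of the masking step, with gradient mantissas as plain i64 values; the scale-factor correction is omitted for clarity and the names are illustrative:

```rust
/// STE backward through requantization: the gradient passes straight
/// through the rounding step and is zeroed wherever the pre-clip
/// forward value fell outside the representable bounds.
fn ste_backward(grad_out: &[i64], pre_clip: &[i64], lo: i64, hi: i64) -> Vec<i64> {
    grad_out
        .iter()
        .zip(pre_clip)
        .map(|(&g, &x)| if (lo..=hi).contains(&x) { g } else { 0 })
        .collect()
}

fn main() {
    // INT8 bounds: the forward value 200 was clipped to 127, so its
    // gradient is masked; the others pass through unchanged.
    let pre_clip = [100, 200, -50];
    let grads = [7, 7, 7];
    assert_eq!(ste_backward(&grads, &pre_clip, -128, 127), vec![7, 0, 7]);
}
```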
Implementation
Now that we have outlined the theory required to make integer-native training work, we must align our implementation with it. I will cover this in the next post.