
Integers: The Theory

Part 2 of a series on native integer machine learning/deep learning

Debug notes

After the previous post I jumped right into debugging.

The increasing runtime, combined with the drastic difference between lr_shift and momentum_shift in the Iris example, gave me reason to believe that the code has a bug with regard to scaling (shifting values).

I am making this assumption for two reasons.

First, both versions use different arithmetic operations. The f32 version uses raw arithmetic operators, e.g. + and *, whereas the i32 version uses saturating/wrapping arithmetic, i.e. values are either clipped or wrapped on overflow.

Second, the i32 version implements gradient clipping.

To get a clearer picture of the current state I added debugging statements to the Iris test, checking runtime, overflow stats and the outgoing shift of the predictions. The results were rather sobering: no wrapping, and a compute time per epoch between 2700 and 3000µs (microseconds). That variance could just as well be caused by a gamma ray from outer space, so it is not a great measurement for tracking down bugs.

Looking at the shifts of the network revealed something different.

For every training sample, the logs show

Epoch #     Loss   Shift In    Forward Shift   Backward Shift
Epoch 1     XXXXXX      5           16              16
Epoch 2     XXXXXX      5           16              16
Epoch 3     XXXXXX      5           16              16
...

This doesn't look correct. The backward shift in particular looks completely wrong.

The layers are configured as follows

Layer   Quant Shift   Incoming Shift   Outgoing Shift
L1      0             3                4
L2      0             4                4
L3      0             4                2

Our input is created with the QuantizationMethod::StandardScore, so it carries a shift of 5.

The final shift of the forward pass is therefore

$s_o = s_i + \sum_{k=1}^{N} \left( s_{i_k} + s_{q_k} \right)$

with $s_i$ the original input shift, $s_{i_k}$ the incoming shift of the $k$-th layer and $s_{q_k}$ the quantization shift of layer $k$.

In our example: $s_o = 5 + (3+0) + (4+0) + (4+0) = 5 + 3 + 4 + 4 = 16$

This shows that the forward pass is already computing the correct shift. The backward pass should reduce this shift again, but it doesn't, so the issue must be inside the backward pass.
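As a sanity check, the forward-shift accumulation can be sketched in a few lines of Rust. The struct and field names here are illustrative, not the actual implementation:

```rust
// Hypothetical per-layer shift configuration; names are illustrative.
struct LayerShifts {
    quant_shift: u32,
    incoming_shift: u32,
}

/// Accumulate the outgoing shift of the forward pass:
/// s_o = s_i + sum_k (s_{i_k} + s_{q_k})
fn forward_shift(input_shift: u32, layers: &[LayerShifts]) -> u32 {
    layers
        .iter()
        .fold(input_shift, |acc, l| acc + l.incoming_shift + l.quant_shift)
}

fn main() {
    // The Iris configuration from the layer table: L1..L3.
    let layers = [
        LayerShifts { quant_shift: 0, incoming_shift: 3 },
        LayerShifts { quant_shift: 0, incoming_shift: 4 },
        LayerShifts { quant_shift: 0, incoming_shift: 4 },
    ];
    // Input shift 5 accumulates to an outgoing forward shift of 16,
    // matching the logged value.
    assert_eq!(forward_shift(5, &layers), 16);
    println!("forward shift: {}", forward_shift(5, &layers));
}
```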

A fresh set of eyes was the main driver behind the fix. After the adjustment we ran the experiment again.

Architecture: 4 -> 8 -> 8 -> 3 with ReLU activations and MSE loss.

We now reach higher accuracy with i32 than with its f32 counterpart. The i32 branch converges within 239 epochs, reaching an accuracy of 95%, whereas f32 never reaches this level.

                     f32     i32
Test accuracy        90%     95%
Epochs               500     239 (early stopped)
lr_shift             7       7
momentum_shift       1       1
Batch size           32      32
Grad. clip value     -       213

There were two reasons for this "success". The first was fixing the internal shift bookkeeping: Linear<S>'s backward pass was using Linear::output_shift to transform the gradient w.r.t. the inputs, which resulted in wrong shifts for all subsequent layers. In addition, Linear<S>'s backward pass did not return a reduced s_g (gradient shift), but simply forwarded the incoming value. The second, apart from the bookkeeping, was drastically increasing the gradient clip value, which lets more information bits flow through the network and helped to get more stable results.

The diff for the changes inside Linear<S>::backward

-        let local_delta_w = s_g + s_x;
+        let local_delta_w = s_g - s_x;
-        let local_delta_x = self.weights.quant_shift + self.output_shift;
+        let local_delta_x = self.weights.quant_shift + self.input_shift;
...
-        let s_g_prev = s_g;//.saturating_sub(local_delta_x);
+        let s_g_prev = s_g.saturating_sub(local_delta_x);

Especially the update to s_g was something I had implemented earlier, but forgot and didn't rediscover until today.

Impact on MNIST

The change clearly fixed a bug, otherwise our performance on Iris would not look like this. Quickly running a few epochs shows that at least the increasing compute time issue from the previous post is resolved, and we now see a more stable value there

 Epoch         Loss     Accuracy    Time(s)   Clamps FW Shift BW Shift
────────────────────────────────────────────────────────────────────────────────
     0    1612.0000         9.4%       9.12s       -1    31    5
     1    1612.0000         9.4%       9.09s       -1    31    5
     2    1612.0000         9.4%       9.08s       -1    31    5
     3    1612.0000         9.4%       9.09s       -1    31    5
     4    1612.0000         9.4%       9.08s       -1    31    5
     5    1612.0000         9.4%       9.08s       -1    31    5
     6    1612.0000         9.4%       9.07s       -1    31    5
     7    1612.0000         9.4%       9.10s       -1    31    5
     8    1612.0000         9.4%       9.08s       -1    31    5
     9    1612.0000         9.4%       9.09s       -1    31    5

Tip

Obviously, the network in these logs did not learn anything; instead, pay attention to the now more or less stable runtime.

The core math

Right now, we're adjusting things and seemingly fixing issues that we don't fully understand. Instead of running around like a headless chicken, we will now look closer at the theory behind the implementation of native integer training. Once we have defined all the required theory, we will implement it and hopefully resolve all the issues we have faced so far, for good.

Motivation

The fundamental problem with integer arithmetic is that multiplication accumulates scale. When you multiply two integers, the result lives in a different "unit" than either operand, and without an explicit mechanism to track and correct for that, values either explode or collapse across the layers of an artificial neural network. The structure we want is one where the scale is part of the value itself, so that every operation is well-defined in terms of what the numbers actually mean, not just what bits they hold.

The carrier set

Let $\mathbb{Z}$ be the integers. Define the carrier set as the Cartesian product

$D = \mathbb{Z} \times \mathbb{Z}$

An element $x = (v, s) \in D$ consists of a mantissa $v \in \mathbb{Z}$ and a scale exponent $s \in \mathbb{Z}$. The rational number it encodes is given by the interpretation map

$[\![x]\!] = v \cdot 2^{-s}$

So a larger $s$ means the mantissa is "finer-grained": the unit each integer step represents is smaller. A smaller $s$ (or a negative $s$) means the value lives at a coarser scale.

Note

The formal definition admits $s \in \mathbb{Z}$, allowing negative scale exponents, which would represent values scaled up by a power of two rather than down. In the implementation, shift values are typed as u32, restricting $s \ge 0$. This means the system can only represent dyadic rationals of magnitude at most $|v|$, never larger. The formal generality is retained here because the mathematical properties of the operations hold regardless, but any concrete instantiation should be understood as working in $\mathbb{Z} \times \mathbb{N}_0$.

Non-uniqueness

It is critical to note that $D$ is a set of representations, not a set of distinct values. Two elements $x_1 = (v_1, s_1)$ and $x_2 = (v_2, s_2)$ are said to be equivalent, written $x_1 \sim x_2$, if and only if they encode the same rational number

$x_1 \sim x_2 \iff v_1 \cdot 2^{-s_1} = v_2 \cdot 2^{-s_2}$

For instance, $(3, 2)$, $(6, 3)$ and $(12, 4)$ all represent the rational $\frac{3}{4}$. This non-uniqueness is not a flaw; it is precisely what gives the system its flexibility. The scale can be chosen to match the magnitude of the data at hand.
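A minimal sketch of the carrier set in Rust, with hypothetical names (`Dyadic`, `decode`, `equivalent`); the equivalence test cross-multiplies mantissas so it never leaves the integers:

```rust
/// A dyadic tuple (v, s) encoding the rational v * 2^{-s}; type and
/// method names are illustrative, not the actual implementation.
#[derive(Clone, Copy, Debug)]
struct Dyadic {
    v: i64,
    s: u32,
}

impl Dyadic {
    /// The interpretation map [[x]] = v * 2^{-s}, evaluated in f64
    /// purely for inspection.
    fn decode(self) -> f64 {
        self.v as f64 / (1u64 << self.s) as f64
    }

    /// Equivalence v1 * 2^{-s1} = v2 * 2^{-s2}, checked by
    /// cross-multiplying: v1 * 2^{s2} == v2 * 2^{s1}.
    fn equivalent(self, other: Dyadic) -> bool {
        (self.v << other.s) == (other.v << self.s)
    }
}

fn main() {
    let a = Dyadic { v: 3, s: 2 };
    let b = Dyadic { v: 6, s: 3 };
    let c = Dyadic { v: 12, s: 4 };
    // All three are representations of the same rational 3/4.
    assert!(a.equivalent(b) && b.equivalent(c) && a.equivalent(c));
    assert_eq!(a.decode(), 0.75);
}
```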

Scale alignment

Before we can define addition and subtraction, we need a way to bring two elements into a common representation.

Given $x_1 = (v_1, s_1)$ and $x_2 = (v_2, s_2)$, define the target scale $s = \min(s_1, s_2)$, which is the coarser of the two scales. The alignment of $x_i$ to scale $s$ is

$\operatorname{align}(x_i, s) = \left( \left\lfloor v_i \cdot 2^{-(s_i - s)} \right\rceil, s \right)$

where $\lfloor \cdot \rceil$ denotes a stochastic down-scaling^1 operation and $\lfloor v_i \cdot 2^{-(s_i - s)} \rceil$ the mantissa aligned to the target scale.

The element with the finer scale has its mantissa right-shifted (divided by $2^k$, with $k = s_i - s$) to match the coarser unit. This is the exact analogue of lining up decimal points before adding:

You can only add 1.27+1.3 once they share the same number of decimal places.

Alignment is lossy in general, i.e. the bits shifted out are gone for good. This precision loss during alignment is the fundamental source of rounding error in this system, the equivalent of floating-point roundoff.
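Alignment under these definitions can be sketched as follows; a plain truncating shift stands in for the stochastic down-scaling, and the function name is illustrative:

```rust
/// Align (v_i, s_i) to a coarser target scale s <= s_i by right-shifting
/// the mantissa. Truncation stands in for the stochastic down-scaling.
fn align(v: i64, s_i: u32, s: u32) -> (i64, u32) {
    assert!(s <= s_i, "target must be the coarser (smaller) scale");
    (v >> (s_i - s), s)
}

fn main() {
    // (13, 4) encodes 13/16 = 0.8125. Aligned to scale 2 it becomes
    // (3, 2) = 0.75: the two low bits are shifted out and gone for good.
    assert_eq!(align(13, 4, 2), (3, 2));
    // An element already at the target scale passes through unchanged.
    assert_eq!(align(3, 2, 2), (3, 2));
}
```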

Stochastic scaling

Stochastic scaling/requantization handles two things: rescaling (adjusting the shift to a new target) and clipping (enforcing the hardware bit-width boundaries).

Rescaling

Let $x = (v_x, s_x) \in D$ be an accumulated tuple, and let $s_t \in \mathbb{Z}$ be the target shift such that $k = s_x - s_t \ge 0$.

We define the stochastic rounding function $SR : \mathbb{Z} \times \mathbb{N}_0 \to \mathbb{Z}$ as:

$SR(v, k) = \left\lfloor \frac{v}{2^k} \right\rfloor + \mathbb{I}\left( (v \bmod 2^k) > U \right)$

where $U \sim \mathcal{U}\{0, 2^k - 1\}$ and $\mathbb{I}(\cdot)$ is the indicator function.
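A sketch of $SR$ in Rust; the tiny xorshift PRNG and the function name are stand-ins, not the actual implementation:

```rust
/// Stochastic rounding: SR(v, k) = floor(v / 2^k) + I((v mod 2^k) > U),
/// with U drawn uniformly from {0, ..., 2^k - 1}.
fn stochastic_round(v: i64, k: u32, rng_state: &mut u64) -> i64 {
    if k == 0 {
        return v; // nothing is shifted out, rounding is exact
    }
    // xorshift64: cheap pseudo-randomness, good enough for a sketch
    *rng_state ^= *rng_state << 13;
    *rng_state ^= *rng_state >> 7;
    *rng_state ^= *rng_state << 17;
    let u = (*rng_state % (1u64 << k)) as i64; // U ~ Uniform{0, 2^k - 1}
    let floor = v >> k; // arithmetic shift = floor division by 2^k
    let rem = v - (floor << k); // v mod 2^k, always in [0, 2^k)
    floor + i64::from(rem > u)
}

fn main() {
    let mut state = 0x9E37_79B9_7F4A_7C15u64;
    // 13 / 2^2 = 3.25: the floor is 3, and we round up to 4 with
    // probability 1/4, so the expected value is preserved.
    let mut ups = 0;
    for _ in 0..10_000 {
        if stochastic_round(13, 2, &mut state) == 4 {
            ups += 1;
        }
    }
    println!("rounded 13 >> 2 up {ups} times out of 10000");
}
```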

The requantization operator $R : D \times \mathbb{Z} \to D$ is then defined as

$R(x, s_t) = \left( \operatorname{clip}\left( SR(v_x, s_x - s_t), Q_{\min}, Q_{\max} \right), s_t \right)$

where $Q_{\min}$ and $Q_{\max}$ represent the representable bounds of the target integer precision. The function clip is defined as follows.

Let $v \in \mathbb{Z}$ be an integer value, and let $Q_{\min}, Q_{\max} \in \mathbb{Z}$ represent the lower and upper bounds of the target precision, where $Q_{\min} < Q_{\max}$.

$\operatorname{clip}(v, Q_{\min}, Q_{\max}) = \begin{cases} Q_{\min} & \text{if } v < Q_{\min} \\ v & \text{if } Q_{\min} \le v \le Q_{\max} \\ Q_{\max} & \text{if } v > Q_{\max} \end{cases}$

We can derive the values of $Q_{\min}, Q_{\max} \in \mathbb{Z}$ from the target bit-width $b \in \mathbb{N}^+$.

For signed two's-complement integer targets (e.g. INT8, where $b = 8$):

$Q_{\min} = -2^{b-1} \qquad Q_{\max} = 2^{b-1} - 1$

For unsigned integer targets (e.g. UINT8):

$Q_{\min} = 0 \qquad Q_{\max} = 2^b - 1$
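The clip function and the bound derivation translate almost one-to-one into Rust (function names are illustrative):

```rust
/// Saturating clip of v into [q_min, q_max].
fn clip(v: i64, q_min: i64, q_max: i64) -> i64 {
    v.max(q_min).min(q_max)
}

/// Bounds of a signed two's-complement b-bit integer target.
fn signed_bounds(b: u32) -> (i64, i64) {
    (-(1i64 << (b - 1)), (1i64 << (b - 1)) - 1)
}

/// Bounds of an unsigned b-bit integer target.
fn unsigned_bounds(b: u32) -> (i64, i64) {
    (0, (1i64 << b) - 1)
}

fn main() {
    // INT8: b = 8 gives the familiar [-128, 127].
    let (q_min, q_max) = signed_bounds(8);
    assert_eq!((q_min, q_max), (-128, 127));
    // Values outside the range saturate at the bounds.
    assert_eq!(clip(300, q_min, q_max), 127);
    assert_eq!(clip(-200, q_min, q_max), -128);
    // UINT8: [0, 255].
    assert_eq!(unsigned_bounds(8), (0, 255));
}
```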

Arithmetic operations

We can now define the four operations. All four preserve the invariant that the result is again an element of $D$. Let $x_1 = (v_1, s_1)$ and $x_2 = (v_2, s_2)$, and write $\hat{v}_i = \lfloor v_i \cdot 2^{-(s_i - s)} \rceil$ for their aligned mantissas at $s = \min(s_1, s_2)$.

Addition and subtraction require alignment first, then operate directly on the mantissas

$x_1 \oplus x_2 = (\hat{v}_1 + \hat{v}_2, s) \qquad x_1 \ominus x_2 = (\hat{v}_1 - \hat{v}_2, s)$

Multiplication does not require alignment because the scales compose additively. The exact product of two dyadic numbers has mantissa $v_1 v_2$ and scale $s_1 + s_2$. However, the mantissa $v_1 v_2$ may be too large to store, so a quantization shift $q \in \mathbb{Z}$ is applied to rescale it back into a usable range

$x_1 \otimes x_2 = \left( \left\lfloor v_1 v_2 \cdot 2^{-q} \right\rceil, s_1 + s_2 - q \right)$

Note that the scale of the result is adjusted by $q$ to compensate, so the encoded value is approximately preserved: $[\![x_1 \otimes x_2]\!] \approx [\![x_1]\!] \cdot [\![x_2]\!]$

Division has the opposite problem. Integer division truncates, so a precision shift $p \in \mathbb{Z}$ is applied to the dividend before dividing, in order to preserve low-order bits that would otherwise be lost

$x_1 \oslash x_2 = \left( \left\lfloor \frac{v_1 \cdot 2^p}{v_2} \right\rfloor, s_1 - s_2 + p \right)$

Again the scale is adjusted to compensate, so $[\![x_1 \oslash x_2]\!] \approx [\![x_1]\!] / [\![x_2]\!]$.
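The four operations can be sketched together as follows, again with truncating shifts standing in for stochastic rounding; the names are hypothetical and all shifts are assumed non-negative for simplicity:

```rust
/// A dyadic tuple (v, s) encoding v * 2^{-s}; illustrative names.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Dyadic {
    v: i64,
    s: u32,
}

/// Truncating right-shift as a stand-in for stochastic down-scaling.
fn down(v: i64, k: u32) -> i64 {
    v >> k
}

/// Addition: align both operands to the coarser scale, then add mantissas.
fn add(a: Dyadic, b: Dyadic) -> Dyadic {
    let s = a.s.min(b.s);
    Dyadic { v: down(a.v, a.s - s) + down(b.v, b.s - s), s }
}

/// Multiplication: scales compose additively; the quantization shift q
/// rescales the wide product mantissa back into range.
fn mul(a: Dyadic, b: Dyadic, q: u32) -> Dyadic {
    Dyadic { v: down(a.v * b.v, q), s: a.s + b.s - q }
}

/// Division: the precision shift p keeps low-order bits of the dividend
/// that plain integer division would truncate away.
fn div(a: Dyadic, b: Dyadic, p: u32) -> Dyadic {
    Dyadic { v: (a.v << p) / b.v, s: a.s - b.s + p }
}

fn main() {
    let x = Dyadic { v: 3, s: 2 }; // 0.75
    let y = Dyadic { v: 8, s: 3 }; // 1.0
    assert_eq!(add(x, y), Dyadic { v: 7, s: 2 });    // 7/4 = 1.75, exact
    assert_eq!(mul(x, y, 3), Dyadic { v: 3, s: 2 }); // 3/4 = 0.75, exact
    assert_eq!(div(y, x, 2), Dyadic { v: 10, s: 3 }); // 10/8 = 1.25 vs 1.333...
}
```

The last assertion makes the approximation visible: with only $p = 2$ bits of extra precision, $1 \oslash 0.75$ lands at $1.25$ instead of $1.\overline{3}$.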

A note on approximation

An important consequence of the above is that $(D, \oplus, \otimes)$ is not a ring. The operations are approximate: addition loses the bits shifted away during alignment and multiplication loses the bits shifted away during quantization. The choices of $q$ and $p$ are free parameters. Choosing them well is exactly the shift management problem that sits at the heart of integer neural network training.

The lr_shift, quant_shift and grad_shift parameters are all instances of this single decision appearing in different contexts. This provides all necessary building blocks for the forward pass of a neural network using D.

Gradients

But this framework introduces a new problem. Due to the randomness of $SR$, i.e. sampling $U$ and applying $\mathbb{I}(\cdot)$, the forward pass becomes non-differentiable. This is a big problem, because we want to train neural networks, which requires differentiable functions; otherwise we cannot compute the gradient with standard chain-rule calculus. Unfortunately, $SR$ as well as $\operatorname{clip}$ are both step functions, meaning their true mathematical derivatives are exactly zero almost everywhere. The network would never train because the gradients would instantly vanish.

Straight-Through Estimator

To overcome this we will formalize a Straight-Through Estimator (STE). The STE essentially defines a "fake" derivative for the backward pass that ignores the discrete steps and pretends the function were continuous.

Let $z = (W \otimes x) \oplus b$ be the high-precision accumulated tuple in $D$ before requantization, so $y = R(z, s_t)$.

During backpropagation, we receive the gradient of the loss w.r.t. $y$, denoted $\nabla_y L \in D$. We need to compute the gradient w.r.t. $z$, denoted $\nabla_z L \in D$.

We break R into its two components and define the STE for each:

  1. The STE for SR:

In continuous space, the rounding function is an approximation of the identity function (scaled by the bit-shift). The STE assumes the gradient passes straight through the rounding operation unmodified.

$\frac{\partial}{\partial z} SR(z, k) \approx 1$

  2. The STE for clip:

For the clipping function, the gradient passes through identically if the value was within the representable bounds during the forward pass. If the value was clipped to Qmin or Qmax, the gradient is stopped (zeroed out).

Using the indicator function I, the derivative is:

$\frac{\partial}{\partial z} \operatorname{clip}(z) \approx \mathbb{I}(Q_{\min} \le v_y \le Q_{\max})$

Note that we evaluate the bounds check using $v_y$, the actual integer value of the output tuple $y$ from the forward pass.

Combining these, the gradient simply flows straight through the rounding operator but gets masked by the clipping bounds.

Let $0_y$ be the zero-gradient tuple that matches the shift of the incoming gradient, meaning $0_y = (0, s_y)$.

The gradient $\nabla_z L$ is defined as

$\nabla_z L = \begin{cases} \nabla_y L & \text{if } Q_{\min} \le v_y \le Q_{\max} \\ 0_y & \text{otherwise} \end{cases}$
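The combined STE backward can be sketched with a hypothetical `Requant` helper that records during the forward pass whether clipping fired; a truncating shift again stands in for $SR$:

```rust
/// Requantization with a straight-through estimator; names are
/// illustrative, not the actual implementation.
struct Requant {
    clipped: bool,
}

impl Requant {
    /// Forward: y = clip(SR(v, k)), recording whether clip saturated.
    fn forward(&mut self, v: i64, k: u32, q_min: i64, q_max: i64) -> i64 {
        let rounded = v >> k; // truncating stand-in for SR
        let y = rounded.clamp(q_min, q_max);
        self.clipped = y != rounded;
        y
    }

    /// Backward: d/dz SR ~ 1, d/dz clip ~ I(in bounds). The gradient
    /// tuple (v, s) passes straight through, or becomes 0_y = (0, s_y)
    /// where the forward pass clipped.
    fn backward(&self, grad: (i64, u32)) -> (i64, u32) {
        if self.clipped { (0, grad.1) } else { grad }
    }
}

fn main() {
    let mut r = Requant { clipped: false };
    // 1000 >> 2 = 250 saturates to 127: the gradient is stopped.
    assert_eq!(r.forward(1000, 2, -128, 127), 127);
    assert_eq!(r.backward((5, 3)), (0, 3));
    // 100 >> 2 = 25 stays in bounds: the gradient passes through.
    assert_eq!(r.forward(100, 2, -128, 127), 25);
    assert_eq!(r.backward((5, 3)), (5, 3));
}
```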

Implementation

Now that we have outlined the theory required to make integer-native training work, we must align our implementation with it. I will cover that in the next post.