2/11/2014

02-11-14 - Understanding ANS - 10

Not really an ANS topic, but a piece you need for ANS so I've had a look at it.

For ANS and many other statistical coders (eg. arithmetic coding) you need to create scaled frequencies (the Fs in ANS terminology) from the true counts.

But how do you do that? I've seen many heuristics over the years that are more or less good, but I've never actually seen the right answer. How do you scale to minimize total code len? Well let's do it.

Let's state the problem :


You are given some true counts Cs

Sum{Cs} = T  the total of true counts

the true probabilities then are

Ps = Cs/T

and the ideal code lens are log2(1/Ps)

You need to create scaled frequencies Fs
such that

Sum{Fs} = M

for some given M.

and our goal is to minimize the total code len under the counts Fs.

The ideal entropy of the given counts is :

H = Sum{ Ps * log2(1/Ps) }

The code len under the counts Fs is :

L = Sum{ Ps * log2(M/Fs) }

The code len is never better than the entropy (they are equal only when Fs/M = Ps for every symbol) :

L >= H
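
If you want to measure H and L on your own data, here is a minimal sketch of the two sums above. The function names (entropy_bits, codelen_bits) are made up for illustration, and it assumes Fs > 0 wherever Cs > 0 :

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// H = Sum{ Ps * log2(1/Ps) } , in bits per symbol, from the true counts Cs.
double entropy_bits(const std::vector<uint32_t>& counts)
{
    double total = 0;
    for (uint32_t c : counts) total += c;
    double h = 0;
    for (uint32_t c : counts)
        if (c != 0) h += (c / total) * std::log2(total / c);
    return h;
}

// L = Sum{ Ps * log2(M/Fs) } , in bits per symbol, where M = Sum{Fs}.
// Assumption : freqs[s] > 0 wherever counts[s] > 0.
double codelen_bits(const std::vector<uint32_t>& counts,
                    const std::vector<uint32_t>& freqs)
{
    double total = 0, m = 0;
    for (uint32_t c : counts) total += c;
    for (uint32_t f : freqs) m += f;
    double l = 0;
    for (std::size_t s = 0; s < counts.size(); s++)
        if (counts[s] != 0) l += (counts[s] / total) * std::log2(m / freqs[s]);
    return l;
}
```

For counts {1,1,2} the entropy is 1.5 bits; coding them with Fs = {2,1,1} (M = 4) costs 1.75 bits, and Fs = {1,1,2} gets back to 1.5.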

We must also meet the constraint

if ( Cs != 0 ) then Fs > 0

That is, all symbols that exist in the set must be codeable. (note that this is not actually optimal; it's usually better to replace all rare symbols with a single escape symbol, but we won't do that here).

The naive solution is :


Fs = round( M * Ps )

if ( Cs > 0 ) Fs = MAX(Fs,1);

which is just scaling up the Ps by M. This has two problems - one is that Sum{Fs} is not actually M. The other is that just rounding the float does not actually distribute the integer counts to minimize codelen.
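
To see the sum problem concretely, here is a small sketch of the naive scaling. naive_scale is a name invented for this example, not something from cblib :

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Naive scaling : Fs = round(M * Ps) , bumped to at least 1 for symbols that occur.
// The returned counts generally do NOT sum to M.
std::vector<uint32_t> naive_scale(const std::vector<uint32_t>& counts, uint32_t M)
{
    uint64_t T = 0;
    for (uint32_t c : counts) T += c;

    std::vector<uint32_t> freqs(counts.size(), 0);
    for (std::size_t s = 0; s < counts.size(); s++)
    {
        if (counts[s] == 0) continue; // unused symbols stay at zero
        double scaled = (double)counts[s] * M / (double)T;
        freqs[s] = std::max<uint32_t>((uint32_t)std::lround(scaled), 1u);
    }
    return freqs;
}
```

With counts {1,1,1} and M = 16 this gives {5,5,5}, which sums to 15 (one short). With counts {1000,1,1,1} and M = 8 it gives {8,1,1,1}, which sums to 11 (three over, because of the forced minimum of 1).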

The usual heuristic is to do something like the above, and then apply some fix to make the sum right.

So first let's address how to fix the sum. We will always have issues with the sum being off M because of integer rounding.

What you will have is some correction :


correction = M - Sum{Fs}

that can be positive or negative. This is a count that needs to be added onto some symbols. We want to add it to the symbols that give us the most benefit to L, the total code len. Well that's simple, we just measure the effect of changing each Fs :

correction_sign = correction > 0 ? 1 : -1;

Ls_before = Ps * log2(M/Fs)
Ls_after = Ps * log2(M/(Fs + correction_sign))

Ls_delta = Ls_after - Ls_before
Ls_delta = Ps * ( log2(M/(Fs + correction_sign)) - log2(M/Fs) )
Ls_delta = Ps * log2(Fs/(Fs + correction_sign))

so we need to just find the symbol that gives us the lowest Ls_delta. This is either an improvement to total L, or the least increase in L.
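
As a sketch of a single correction step, here is the brute-force version : scan every symbol and take the lowest delta. The names (ls_delta, best_symbol) are hypothetical. It uses the raw count Cs in place of Ps, which only changes the deltas by the constant factor T and so does not change which symbol wins :

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Change in total code len from moving freq by sign (+1 or -1), scaled by T.
// Caller must not pass freq == 1 with sign == -1.
double ls_delta(uint32_t count, uint32_t freq, int sign)
{
    return count * std::log2((double)freq / ((double)freq + sign));
}

// Linear scan for the symbol with the lowest delta ; returns -1 if none is eligible.
int best_symbol(const std::vector<uint32_t>& counts,
                const std::vector<uint32_t>& freqs, int sign)
{
    int best = -1;
    double best_delta = 0;
    for (std::size_t s = 0; s < counts.size(); s++)
    {
        if (counts[s] == 0) continue;            // unused symbols stay at zero
        if (sign < 0 && freqs[s] <= 1) continue; // must stay codeable
        double d = ls_delta(counts[s], freqs[s], sign);
        if (best < 0 || d < best_delta) { best = (int)s; best_delta = d; }
    }
    return best;
}
```

For example with counts {10,5} and Fs = {8,1}, adding one count to the second symbol halves its code len and wins, even though the first symbol is more probable.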

We need to apply multiple corrections. We don't want a solution that's O(alphabet*correction) , since that can be 256*256 in bad cases. (correction is <= alphabet and typically in the 1-50 range for a typical 256-symbol file). The obvious solution is a heap. In pseudocode :


For all s
    push_heap( Ls_delta , s )

For correction
    s = pop_heap
    adjust Fs
    compute new Ls_delta for s
    push_heap( Ls_delta , s )

note that after we adjust the count we need to recompute Ls_delta and repush that symbol, because we might want to choose the same symbol again later.

In STL+cblib this is :


to[] = Fs
from[] = original counts

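struct sort_sym
{
    int sym;
    float rank;
    sort_sym() { }
    sort_sym( int s, float r ) : sym(s) , rank(r) { }
    // std:: heaps are max-heaps ; invert the compare so pop_heap returns the lowest rank (lowest Ls_delta)
    bool operator < (const sort_sym & rhs) const { return rank > rhs.rank; }
};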

---------

    if ( correction != 0 )
    {
        //lprintfvar(correction);
        int32 correction_sign = (correction > 0) ? 1 : -1;

        vector<sort_sym> heap;
        heap.reserve(alphabet);

        for LOOP(i,alphabet)
        {
            if ( from[i] == 0 ) continue;
            ASSERT( to[i] != 0 );
            if ( to[i] > 1 || correction_sign == 1 )
            {
                double change = log( (double) to[i] / (to[i] + correction_sign) ) * from[i];
            
                heap.push_back( sort_sym(i,change) );
            }
        }
        
        std::make_heap(heap.begin(),heap.end());
        
        while( correction != 0 )
        {
            ASSERT_RELEASE( ! heap.empty() );
            std::pop_heap(heap.begin(),heap.end());
            sort_sym ss = heap.back();
            heap.pop_back();
            
            int i = ss.sym;
            ASSERT( from[i] != 0 );
            
            to[i] += correction_sign;
            correction -= correction_sign;
            ASSERT( to[i] != 0 );
        
            if ( to[i] > 1 || correction_sign == 1 )
            {
                double change = log( (double) to[i] / (to[i] + correction_sign) ) * from[i];
            
                heap.push_back( sort_sym(i,change) );
                std::push_heap(heap.begin(),heap.end());
            }               
        }
    
        ASSERT( cb::sum(to,to+alphabet) == (uint32)to_sum_desired );
    }

You may have noted that the above code is using natural log instead of log2. The difference is only a constant scaling factor, so it doesn't affect the heap order; you may use whatever log base is fastest.

Errkay. So our first attempt is to just use the naive scaling Fs = round( M * Ps ) and then fix the sum using the heap correction algorithm above.

Doing round+correct gets you 99% of the way there. I measured the difference between the total code len made that way and the optimal, and they are less than 0.001 bpb different on every file I tried. But it's still not quite right, so what is the right way?

To guide my search I had a look at the cases where round+correct was not optimal. When it's not optimal it means there is some symbol a and some symbol b such that { Fa+1 , Fb-1 } gives a better total code len than {Fa,Fb}. An example of that is :


count to inc : (1/1024) was (1866/1286152 = 0.0015)
count to dec : (380/1024) was (482110/1286152 = 0.3748)
to inc; cl before : 10.00 cl after : 9.00 , true cl : 9.43
to dec; cl before : 1.43 cl after : 1.43 , true cl : 1.42

The key point is on the 1 count :

count to inc : (1/1024) was (1866/1286152 = 0.0015)
to inc; cl before : 10.00 cl after : 9.00 , true cl : 9.43

1024*1866/1286152 = 1.485660
round(1.485660) = 1

so Fs = 1 , which is a codelen of 10

but Fs = 2 gives a codelen (9) closer to the true codelen (9.43)

And this provided the key observation : rather than rounding the scaled count, what we should be doing is either floor or ceil of the fraction, whichever gives a codelen closer to the true codelen.

BTW before you go off hacking a special case just for Fs==1, it also happens with higher counts :


count to inc : (2/1024) was (439/180084) scaled = 2.4963
to inc; cl before : 9.00 cl after : 8.42 , true cl : 8.68

count to inc : (4/1024) was (644/146557) scaled = 4.4997
to inc; cl before : 8.00 cl after : 7.68 , true cl : 7.83

though obviously the higher Fs, the less likely it is because the rounding gets closer to being perfect.

So it's easy enough to just solve exactly : simply pick the floor or ceil of the scaled count, whichever gives the closer codelen :


Ps = Cs/T from the true counts

down = floor( M * Ps )
down = MAX( down,1)

Fs = either down or (down+1)

true_codelen = -log2( Ps )
down_codelen = -log2( down/M )
  up_codelen = -log2( (down+1)/M )

if ( |down_codelen - true_codelen| < |up_codelen - true_codelen| )
  Fs = down
else
  Fs = down+1
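
Written out literally in C++, that log-domain comparison might look like the sketch below. pick_by_codelen is a hypothetical name, and it assumes 0 < count <= total :

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Returns floor or ceil of M*Cs/T , whichever has the codelen closer to the true codelen.
// Assumption : 0 < count <= total.
uint32_t pick_by_codelen(uint32_t count, uint32_t total, uint32_t M)
{
    double p = (double)count / (double)total;
    uint32_t down = std::max<uint32_t>((uint32_t)std::floor(M * p), 1u);

    double true_codelen = -std::log2(p);
    double down_codelen = -std::log2((double)down / M);
    double up_codelen   = -std::log2((double)(down + 1) / M);

    if (std::fabs(down_codelen - true_codelen) < std::fabs(up_codelen - true_codelen))
        return down;
    return down + 1;
}
```

Feeding it the three cases from the logs above, pick_by_codelen(1866,1286152,1024) gives 2, pick_by_codelen(439,180084,1024) gives 3, and pick_by_codelen(644,146557,1024) gives 5, where plain rounding would have given 1, 2 and 4.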

And since all we care about is the inequality, we can do some maths and simplify the expressions. I won't write out all the algebra to do the simplification because it's straightforward, but there are a few key steps :

| log(x) | = log( MAX(x,1/x) )

log(x) >= log(y)  is the same as x >= y

down <= M*Ps
down+1 >= M*Ps
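
For anyone who does want the steps spelled out, here is one way the algebra can go. The shorthand x = M*Ps and d = down is introduced just for this derivation, and it covers the case d >= 1 :

```latex
\begin{aligned}
x &= M P_s, \qquad d = \lfloor x \rfloor \ge 1, \qquad d \le x \le d+1 \\
|\text{down\_codelen} - \text{true\_codelen}|
  &= \left| \log_2\frac{M}{d} - \log_2\frac{1}{P_s} \right| = \log_2\frac{x}{d} \\
|\text{up\_codelen} - \text{true\_codelen}|
  &= \left| \log_2\frac{M}{d+1} - \log_2\frac{1}{P_s} \right| = \log_2\frac{d+1}{x} \\
\text{choose down} &\iff \log_2\frac{x}{d} < \log_2\frac{d+1}{x}
  \iff \frac{x}{d} < \frac{d+1}{x}
  \iff x^2 < d\,(d+1)
\end{aligned}
```

The absolute values drop out because x/d >= 1 and (d+1)/x >= 1. At exact equality both choices are equally close, so it doesn't matter which way the tie goes.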

the result of the simplification in code is :

from[] = original counts (Cs) , sum to T
to[] = normalized counts (Fs) , will sum to M

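    // scale in floating point ; with integer operands from[i] * M/T would truncate
    double from_scaled = (double) from[i] * M / T;

    uint32 down = (uint32)( from_scaled );
                
    // do the product in double so down*(down+1) can't overflow uint32 for large M
    to[i] = ( from_scaled*from_scaled <= (double) down*(down+1) ) ? down : down+1;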

Note that there's no special casing needed to ensure that (0 < from_scaled < 1) gives you to[i] = 1 ; we get that for free with this expression, because down is 0 and the test fails. A zero count passes the test (0 <= 0) and stays at zero.

I was delighted when I got to this extremely simple final form.

And that is the conclusion. Use that to find the initial scaled counts. There will still be some correction that needs to be applied to reach the target sum exactly, so use the heap correction algorithm above.
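
Putting the two pieces together, here is a self-contained sketch in plain STL. This is not the cblib code above : normalize_counts and its helpers are names made up for this example, it uses std::priority_queue instead of the raw heap functions, and it ranks by benefit (the negated delta) so that the max-heap pops the best symbol first. It assumes M is at least the number of symbols with nonzero count :

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <queue>
#include <utility>
#include <vector>

// Scale counts "from" so they sum to exactly M.
// Assumptions : at least one nonzero count, and M >= number of nonzero counts.
std::vector<uint32_t> normalize_counts(const std::vector<uint32_t>& from, uint32_t M)
{
    uint64_t T = 0;
    std::size_t used = 0;
    for (uint32_t c : from) { T += c; used += (c != 0); }
    assert(T > 0 && used <= M);

    // Step 1 : floor or ceil, whichever codelen is closer to the true codelen.
    std::vector<uint32_t> to(from.size(), 0);
    int64_t correction = M;
    for (std::size_t i = 0; i < from.size(); i++)
    {
        double from_scaled = (double)from[i] * M / (double)T;
        uint32_t down = (uint32_t)from_scaled;
        to[i] = (from_scaled * from_scaled <= (double)down * (down + 1)) ? down : down + 1;
        correction -= to[i];
    }
    if (correction == 0) return to;

    // Step 2 : heap correction, one count at a time.
    const int sign = (correction > 0) ? 1 : -1;
    // benefit = -Ls_delta up to a constant factor ; larger is better.
    auto benefit = [&](std::size_t i) {
        return from[i] * std::log(((double)to[i] + sign) / (double)to[i]);
    };
    // Symbols that occur must keep a count of at least 1.
    auto eligible = [&](std::size_t i) { return from[i] != 0 && (sign > 0 || to[i] > 1); };

    std::priority_queue<std::pair<double, std::size_t>> heap;
    for (std::size_t i = 0; i < from.size(); i++)
        if (eligible(i)) heap.push({benefit(i), i});

    while (correction != 0)
    {
        assert(!heap.empty());
        std::size_t i = heap.top().second;
        heap.pop();
        if (sign > 0) to[i]++; else to[i]--;
        correction -= sign;
        if (eligible(i)) heap.push({benefit(i), i}); // same symbol may win again
    }
    return to;
}
```

For counts {1000,1,1,1} and M = 8 the first step gives {8,1,1,1} (sum 11) and the correction takes three counts back off the big symbol, ending at {5,1,1,1}.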

As a final note, if we look at the final expression :


to[i] = ( from_scaled*from_scaled <= down*(down+1) ) ? down : down+1;

to[i] = ( test <= 0 ) ? down : down+1;

test = from_scaled*from_scaled - down*(down+1); 

from_scaled = down + frac

test = (down + frac)^2 - down*(down+1);

solve for frac where test = 0

frac = sqrt( down^2 + down ) - down

That gives you the threshold on the fractional part of the scaled count : at or below it you round down, above it you round up. It varies with down = floor(from_scaled). The actual values (down : threshold) are :

1 : 0.414214
2 : 0.449490
3 : 0.464102
4 : 0.472136
5 : 0.477226
6 : 0.480741
7 : 0.483315
8 : 0.485281
9 : 0.486833
10 : 0.488088
11 : 0.489125
12 : 0.489996
13 : 0.490738
14 : 0.491377
15 : 0.491933
16 : 0.492423
17 : 0.492856
18 : 0.493242
19 : 0.493589

You can see as Fs gets larger, it goes to 0.5 , so just using rounding is close to correct. It's really in the very low values where it's quite far from 0.5 that errors are most likely to occur.
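
If you want to regenerate that table, a few lines will do it. The function names here are just for this sketch :

```cpp
#include <cmath>
#include <cstdio>

// Fractional part above which a scaled count with floor == down should round up.
double round_up_threshold(unsigned down)
{
    double d = (double)down;
    return std::sqrt(d * d + d) - d;
}

// Prints "down : threshold" for down = 1..max_down.
void print_thresholds(unsigned max_down)
{
    for (unsigned down = 1; down <= max_down; down++)
        std::printf("%u : %f\n", down, round_up_threshold(down));
}
```

Calling print_thresholds(19) should reproduce the list above to six decimal places.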
