For ANS and many other statistical coders (e.g. arithmetic coding) you need to create scaled frequencies (the Fs in ANS terminology) from the true counts.
But how do you do that? I've seen many heuristics over the years that are more or less good, but I've never actually seen the right answer. How do you scale to minimize total code len? Well, let's do it.
Let's state the problem :
You are given some true counts Cs
Sum{Cs} = T the total of true counts
the true probabilities then are
Ps = Cs/T
and the ideal code lens are log2(1/Ps)
You need to create scaled frequencies Fs
such that
Sum{Fs} = M
for some given M.
and our goal is to minimize the total code len under the counts Fs.
The ideal entropy of the given counts is :
H = Sum{ Ps * log2(1/Ps) }
The code len under the counts Fs is :
L = Sum{ Ps * log2(M/Fs) }
The code len is never better than the entropy :
L >= H
with equality only when the Fs/M exactly match the Ps.
We must also meet the constraint
if ( Cs != 0 ) then Fs > 0
That is, all symbols that exist in the set must be codeable. (note that this is not actually optimal;
it's usually better to replace all rare symbols with a single escape symbol, but we won't do that here).
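To make the goal concrete, here's a small helper (my own, not from cblib) that measures H , L , and the loss (L - H) in bits per symbol from the true counts Cs and any candidate Fs :

#include <cmath>
#include <cstdint>

// measure the ideal entropy H and the code len L actually paid with Fs,
// and return the loss (L - H) in bits per symbol ; always >= 0
double codelen_loss(const uint32_t * Cs, const uint32_t * Fs, int alphabet, uint32_t M)
{
    double T = 0;
    for (int s = 0; s < alphabet; s++) T += Cs[s];

    double H = 0, L = 0;
    for (int s = 0; s < alphabet; s++)
    {
        if ( Cs[s] == 0 ) continue;
        double Ps = Cs[s] / T;
        H += Ps * std::log2( 1.0 / Ps );            // ideal code len
        L += Ps * std::log2( (double) M / Fs[s] );  // code len under the scaled counts
    }
    return L - H;
}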
The naive solution is :
Fs = round( M * Ps )
if ( Cs > 0 ) Fs = MAX(Fs,1);
which is just scaling up the Ps by M. This has two problems - one is that Sum{Fs} is not actually M. The other
is that just rounding the float does not actually distribute the integer counts to minimize codelen.
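In code, the naive version is something like this (assuming Cs and Fs are uint32 arrays and T , M , alphabet are as above) :

// naive scaling : Fs = round( M * Ps ) , then bump used symbols up to at least 1
for (int s = 0; s < alphabet; s++)
{
    Fs[s] = (uint32_t)( (double) M * Cs[s] / T + 0.5 );
    if ( Cs[s] > 0 && Fs[s] == 0 ) Fs[s] = 1;
}
// Sum{Fs} usually lands near M , but generally not exactly on it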
The usual heuristic is to do something like the above, and then apply some fix to make the sum right.
So first let's address how to fix the sum. We will always have issues with the sum being off M because of integer rounding.
What you will have is some correction :
correction = M - Sum{Fs}
that can be positive or negative. This is a count that needs to be added onto some symbols. We want to
add it to the symbols that give us the most benefit to L, the total code len.
Well that's simple, we just measure the effect of changing each Fs :
correction_sign = correction > 0 ? 1 : -1;
Ls_before = Ps * log2(M/Fs)
Ls_after = Ps * log2(M/(Fs + correction_sign))
Ls_delta = Ls_after - Ls_before
Ls_delta = Ps * ( log2(M/(Fs + correction_sign)) - log2(M/Fs) )
Ls_delta = Ps * log2(Fs/(Fs + correction_sign))
so we need to just find the symbol that gives us the lowest Ls_delta. This is either an improvement to total L,
or the least increase in L.
We need to apply multiple corrections. We don't want a solution that's O(alphabet*correction) , since that can
be 256*256 in bad cases. (correction is <= alphabet and usually in the 1-50 range for a typical 256-symbol file).
The obvious solution is a heap. In pseudocode :
For all s
    push_heap( Ls_delta , s )

For each step of correction
    s = pop_heap
    adjust Fs
    compute new Ls_delta for s
    push_heap( Ls_delta , s )
note that after we adjust the count we need to recompute Ls_delta and repush that symbol, because we might want
to choose the same symbol again later.
In STL+cblib this is :
to[] = Fs
from[] = original counts

struct sort_sym
{
    int sym;
    float rank;
    sort_sym() { }
    sort_sym( int s, float r ) : sym(s) , rank(r) { }
    bool operator < (const sort_sym & rhs) const
    {
        // note : compare is reversed so the STL max-heap pops the *lowest* rank first
        //  (rank here is Ls_delta, and we want the lowest Ls_delta)
        return rank > rhs.rank;
    }
};

---------

if ( correction != 0 )
{
    //lprintfvar(correction);
    int32 correction_sign = (correction > 0) ? 1 : -1;

    vector<sort_sym> heap;
    heap.reserve(alphabet);

    for LOOP(i,alphabet)
    {
        if ( from[i] == 0 ) continue;
        ASSERT( to[i] != 0 );
        if ( to[i] > 1 || correction_sign == 1 )
        {
            // Ls_delta for one step of correction_sign on symbol i
            //  (the constant factors 1/T and 1/log(2) don't change the ordering)
            double change = log( (double) to[i] / (to[i] + correction_sign) ) * from[i];

            heap.push_back( sort_sym(i,change) );
        }
    }

    std::make_heap(heap.begin(),heap.end());

    while( correction != 0 )
    {
        ASSERT_RELEASE( ! heap.empty() );
        std::pop_heap(heap.begin(),heap.end());
        sort_sym ss = heap.back();
        heap.pop_back();

        int i = ss.sym;
        ASSERT( from[i] != 0 );

        to[i] += correction_sign;
        correction -= correction_sign;
        ASSERT( to[i] != 0 );

        // recompute Ls_delta at the new to[i] and repush, so this symbol
        // can be chosen again later
        if ( to[i] > 1 || correction_sign == 1 )
        {
            double change = log( (double) to[i] / (to[i] + correction_sign) ) * from[i];

            heap.push_back( sort_sym(i,change) );
            std::push_heap(heap.begin(),heap.end());
        }
    }

    ASSERT( cb::sum(to,to+alphabet) == (uint32)to_sum_desired );
}
Errkay. So our first attempt is to just use the naive scaling Fs = round( M * Ps ) and then fix the sum using the heap correction algorithm above.
Doing round+correct gets you 99% of the way there. I measured the difference between the total code len made that way and the optimal, and they are less than 0.001 bpb different on every file I tried. But it's still not quite right, so what is the right way?
To guide my search I had a look at the cases where round+correct was not optimal. When it's not optimal
it means there is some symbol a and some symbol b such that { Fa+1 , Fb-1 } gives a better total code len
than {Fa,Fb}. An example of that is :
count to inc : (1/1024) was (1866/1286152 = 0.0015)
count to dec : (380/1024) was (482110/1286152 = 0.3748)
to inc; cl before : 10.00 cl after : 9.00 , true cl : 9.43
to dec; cl before : 1.43 cl after : 1.43 , true cl : 1.42
The key point is on the 1 count :
count to inc : (1/1024) was (1866/1286152 = 0.0015)
to inc; cl before : 10.00 cl after : 9.00 , true cl : 9.43
1024*1866/1286152 = 1.485660
round(1.485660) = 1
so Fs = 1 , which is a codelen of 10
but Fs = 2 gives a codelen (9) closer to the true codelen (9.43)
And this provided the key observation : rather than rounding the scaled count,
what we should be doing is either floor or ceil of the fraction,
whichever gives a codelen closer to the true codelen.
BTW before you go off hacking a special case just for Fs==1, it also happens with higher counts :
count to inc : (2/1024) was (439/180084) scaled = 2.4963
to inc; cl before : 9.00 cl after : 8.42 , true cl : 8.68
count to inc : (4/1024) was (644/146557) scaled = 4.4997
to inc; cl before : 8.00 cl after : 7.68 , true cl : 7.83
though obviously the higher Fs is, the less likely this is, because the rounding gets closer to being perfect.
So it's easy enough just to solve exactly, simply pick the floor or ceil of the ratio depending on which
makes the closer codelen :
Ps = Cs/T from the true counts
down = floor( M * Ps )
down = MAX( down,1)
Fs = either down or (down+1)
true_codelen = -log2( Ps )
down_codelen = -log2( down/M )
up_codelen = -log2( (down+1)/M )
if ( |down_codelen - true_codelen| < |up_codelen - true_codelen| )
    Fs = down
else
    Fs = down+1
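Checking that rule on the example above : M*Ps = 1.485660 , so down = 1. The true codelen is 9.43 ; down gives |10.00 - 9.43| = 0.57 , down+1 gives |9.00 - 9.43| = 0.43 , so we take Fs = 2 , as desired.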
And since all we care about is the inequality, we can do some maths and simplify the expressions.
I won't write out all the algebra to do the simplification because it's straightforward, but there
are a few key steps :
| log(x) | = log( MAX(x,1/x) )
log(x) >= log(y) is the same as x >= y
down <= M*Ps
down+1 >= M*Ps
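Spelled out, writing q = M*Ps as shorthand (the from_scaled of the code below), so that down <= q <= down+1 :

down_codelen - true_codelen = log2(M/down) - log2(1/Ps) = log2( q/down ) >= 0
up_codelen - true_codelen = log2(M/(down+1)) - log2(1/Ps) = log2( q/(down+1) ) <= 0

so "pick down" means

log2( q/down ) <= log2( (down+1)/q )
q/down <= (down+1)/q
q*q <= down*(down+1)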
the result of the simplification in code is :
from[] = original counts (Cs) , sum to T
to[] = normalized counts (Fs) , will sum to M

double from_scaled = from[i] * (double) M / T;
uint32 down = (uint32)( from_scaled );

to[i] = ( from_scaled*from_scaled <= down*(down+1) ) ? down : down+1;
Note that there's no special casing needed to ensure that (from_scaled < 1) gives you to[i] = 1 ; we get that for free from this expression.
I was delighted when I got to this extremely simple final form.
And that is the conclusion. Use that to find the initial scaled counts. There will still be some correction that needs to be applied to reach the target sum exactly, so use the heap correction algorithm above.
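For reference, here's a self-contained sketch of the whole procedure in plain STL (the function name and interface are my own, just to show the two steps end to end) :

#include <vector>
#include <queue>
#include <utility>
#include <cmath>
#include <cstdint>
#include <cassert>

// from = true counts Cs ; returns Fs summing to M , with Fs > 0 wherever Cs > 0
std::vector<uint32_t> normalize_counts(const std::vector<uint32_t> & from, uint32_t M)
{
    int alphabet = (int) from.size();

    uint64_t T = 0;
    int used = 0;
    for (uint32_t c : from) { T += c; if ( c > 0 ) used++; }
    assert( T > 0 && M >= (uint32_t) used ); // every used symbol needs Fs >= 1

    std::vector<uint32_t> to(alphabet, 0);

    // step 1 : floor or ceil of the scaled count, chosen by codelen distance
    int64_t sum = 0;
    for (int i = 0; i < alphabet; i++)
    {
        if ( from[i] == 0 ) continue;
        double q = (double) from[i] * M / (double) T;
        uint32_t down = (uint32_t) q;
        to[i] = ( q*q <= (double) down * (down+1) ) ? down : down+1;
        sum += to[i];
    }

    // step 2 : push the sum onto M exactly, one count at a time,
    // always taking the symbol with the best Ls_delta
    int64_t correction = (int64_t) M - sum;
    if ( correction != 0 )
    {
        int sign = (correction > 0) ? 1 : -1;

        // priority_queue pops the largest element, so we store the *benefit*
        // of a step, which is -Ls_delta = Cs * log( (Fs+sign)/Fs )
        // (the constant factors 1/T and 1/log(2) don't change the ordering)
        typedef std::pair<double,int> entry; // { benefit , symbol }
        std::priority_queue<entry> heap;

        auto benefit = [&](int i) {
            return (double) from[i] * std::log( (double)(to[i] + sign) / (double) to[i] );
        };

        for (int i = 0; i < alphabet; i++)
        {
            if ( from[i] == 0 ) continue;
            if ( to[i] > 1 || sign == 1 ) // never drop a used symbol to zero
                heap.push( entry( benefit(i), i ) );
        }

        while ( correction != 0 )
        {
            int i = heap.top().second; heap.pop();
            to[i] += sign;
            correction -= sign;
            if ( to[i] > 1 || sign == 1 )
                heap.push( entry( benefit(i), i ) ); // new Fs -> new benefit, repush
        }
    }

    return to;
}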
As a final note, if we look at the final expression :
to[i] = ( from_scaled*from_scaled < down*(down+1) ) ? down : down+1;
to[i] = ( test < 0 ) ? down : down+1;
test = from_scaled*from_scaled - down*(down+1);
from_scaled = down + frac
test = (down + frac)^2 - down*(down+1);
solve for frac where test = 0
frac = sqrt( down^2 + down ) - down
That gives you the threshold on the fractional part of the scaled count : below it you round down, above it you round up.
It varies with down = floor(from_scaled). The actual values are :
1 : 0.414214
2 : 0.449490
3 : 0.464102
4 : 0.472136
5 : 0.477226
6 : 0.480741
7 : 0.483315
8 : 0.485281
9 : 0.486833
10 : 0.488088
11 : 0.489125
12 : 0.489996
13 : 0.490738
14 : 0.491377
15 : 0.491933
16 : 0.492423
17 : 0.492856
18 : 0.493242
19 : 0.493589
You can see that as Fs gets larger, the threshold goes to 0.5 , so plain rounding is close to correct. It's really at the very
low values, where the threshold is quite far from 0.5 , that errors are most likely to occur.
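If you want to regenerate that table, it's just the frac formula above evaluated at each down :

#include <cstdio>
#include <cmath>

int main()
{
    // frac threshold = sqrt( down^2 + down ) - down , for down = 1..19
    for (int down = 1; down < 20; down++)
        printf("%d : %f\n", down, std::sqrt( (double) down*down + down ) - down);
    return 0;
}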