blob: 6f341914c67f05381ef2c0c037254e1bc799e7ab [file] [log] [blame]
// Copyright 2018 Klaus Post. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Based on work Copyright (c) 2013, Yann Collet, released under BSD License.
5
6package fse
7
8import (
9 "errors"
10 "fmt"
11)
12
13// Compress the input bytes. Input must be < 2GB.
14// Provide a Scratch buffer to avoid memory allocations.
15// Note that the output is also kept in the scratch buffer.
16// If input is too hard to compress, ErrIncompressible is returned.
17// If input is a single byte value repeated ErrUseRLE is returned.
18func Compress(in []byte, s *Scratch) ([]byte, error) {
19 if len(in) <= 1 {
20 return nil, ErrIncompressible
21 }
22 if len(in) > (2<<30)-1 {
23 return nil, errors.New("input too big, must be < 2GB")
24 }
25 s, err := s.prepare(in)
26 if err != nil {
27 return nil, err
28 }
29
30 // Create histogram, if none was provided.
31 maxCount := s.maxCount
32 if maxCount == 0 {
33 maxCount = s.countSimple(in)
34 }
35 // Reset for next run.
36 s.clearCount = true
37 s.maxCount = 0
38 if maxCount == len(in) {
39 // One symbol, use RLE
40 return nil, ErrUseRLE
41 }
42 if maxCount == 1 || maxCount < (len(in)>>7) {
43 // Each symbol present maximum once or too well distributed.
44 return nil, ErrIncompressible
45 }
46 s.optimalTableLog()
47 err = s.normalizeCount()
48 if err != nil {
49 return nil, err
50 }
51 err = s.writeCount()
52 if err != nil {
53 return nil, err
54 }
55
56 if false {
57 err = s.validateNorm()
58 if err != nil {
59 return nil, err
60 }
61 }
62
63 err = s.buildCTable()
64 if err != nil {
65 return nil, err
66 }
67 err = s.compress(in)
68 if err != nil {
69 return nil, err
70 }
71 s.Out = s.bw.out
72 // Check if we compressed.
73 if len(s.Out) >= len(in) {
74 return nil, ErrIncompressible
75 }
76 return s.Out, nil
77}
78
79// cState contains the compression state of a stream.
80type cState struct {
81 bw *bitWriter
82 stateTable []uint16
83 state uint16
84}
85
86// init will initialize the compression state to the first symbol of the stream.
87func (c *cState) init(bw *bitWriter, ct *cTable, tableLog uint8, first symbolTransform) {
88 c.bw = bw
89 c.stateTable = ct.stateTable
90
91 nbBitsOut := (first.deltaNbBits + (1 << 15)) >> 16
92 im := int32((nbBitsOut << 16) - first.deltaNbBits)
93 lu := (im >> nbBitsOut) + first.deltaFindState
94 c.state = c.stateTable[lu]
95}
96
97// encode the output symbol provided and write it to the bitstream.
98func (c *cState) encode(symbolTT symbolTransform) {
99 nbBitsOut := (uint32(c.state) + symbolTT.deltaNbBits) >> 16
100 dstState := int32(c.state>>(nbBitsOut&15)) + symbolTT.deltaFindState
101 c.bw.addBits16NC(c.state, uint8(nbBitsOut))
102 c.state = c.stateTable[dstState]
103}
104
105// encode the output symbol provided and write it to the bitstream.
106func (c *cState) encodeZero(symbolTT symbolTransform) {
107 nbBitsOut := (uint32(c.state) + symbolTT.deltaNbBits) >> 16
108 dstState := int32(c.state>>(nbBitsOut&15)) + symbolTT.deltaFindState
109 c.bw.addBits16ZeroNC(c.state, uint8(nbBitsOut))
110 c.state = c.stateTable[dstState]
111}
112
113// flush will write the tablelog to the output and flush the remaining full bytes.
114func (c *cState) flush(tableLog uint8) {
115 c.bw.flush32()
116 c.bw.addBits16NC(c.state, tableLog)
117 c.bw.flush()
118}
119
120// compress is the main compression loop that will encode the input from the last byte to the first.
121func (s *Scratch) compress(src []byte) error {
122 if len(src) <= 2 {
123 return errors.New("compress: src too small")
124 }
125 tt := s.ct.symbolTT[:256]
126 s.bw.reset(s.Out)
127
128 // Our two states each encodes every second byte.
129 // Last byte encoded (first byte decoded) will always be encoded by c1.
130 var c1, c2 cState
131
132 // Encode so remaining size is divisible by 4.
133 ip := len(src)
134 if ip&1 == 1 {
135 c1.init(&s.bw, &s.ct, s.actualTableLog, tt[src[ip-1]])
136 c2.init(&s.bw, &s.ct, s.actualTableLog, tt[src[ip-2]])
137 c1.encodeZero(tt[src[ip-3]])
138 ip -= 3
139 } else {
140 c2.init(&s.bw, &s.ct, s.actualTableLog, tt[src[ip-1]])
141 c1.init(&s.bw, &s.ct, s.actualTableLog, tt[src[ip-2]])
142 ip -= 2
143 }
144 if ip&2 != 0 {
145 c2.encodeZero(tt[src[ip-1]])
146 c1.encodeZero(tt[src[ip-2]])
147 ip -= 2
148 }
149
150 // Main compression loop.
151 switch {
152 case !s.zeroBits && s.actualTableLog <= 8:
153 // We can encode 4 symbols without requiring a flush.
154 // We do not need to check if any output is 0 bits.
155 for ip >= 4 {
156 s.bw.flush32()
157 v3, v2, v1, v0 := src[ip-4], src[ip-3], src[ip-2], src[ip-1]
158 c2.encode(tt[v0])
159 c1.encode(tt[v1])
160 c2.encode(tt[v2])
161 c1.encode(tt[v3])
162 ip -= 4
163 }
164 case !s.zeroBits:
165 // We do not need to check if any output is 0 bits.
166 for ip >= 4 {
167 s.bw.flush32()
168 v3, v2, v1, v0 := src[ip-4], src[ip-3], src[ip-2], src[ip-1]
169 c2.encode(tt[v0])
170 c1.encode(tt[v1])
171 s.bw.flush32()
172 c2.encode(tt[v2])
173 c1.encode(tt[v3])
174 ip -= 4
175 }
176 case s.actualTableLog <= 8:
177 // We can encode 4 symbols without requiring a flush
178 for ip >= 4 {
179 s.bw.flush32()
180 v3, v2, v1, v0 := src[ip-4], src[ip-3], src[ip-2], src[ip-1]
181 c2.encodeZero(tt[v0])
182 c1.encodeZero(tt[v1])
183 c2.encodeZero(tt[v2])
184 c1.encodeZero(tt[v3])
185 ip -= 4
186 }
187 default:
188 for ip >= 4 {
189 s.bw.flush32()
190 v3, v2, v1, v0 := src[ip-4], src[ip-3], src[ip-2], src[ip-1]
191 c2.encodeZero(tt[v0])
192 c1.encodeZero(tt[v1])
193 s.bw.flush32()
194 c2.encodeZero(tt[v2])
195 c1.encodeZero(tt[v3])
196 ip -= 4
197 }
198 }
199
200 // Flush final state.
201 // Used to initialize state when decoding.
202 c2.flush(s.actualTableLog)
203 c1.flush(s.actualTableLog)
204
205 return s.bw.close()
206}
207
208// writeCount will write the normalized histogram count to header.
209// This is read back by readNCount.
210func (s *Scratch) writeCount() error {
211 var (
212 tableLog = s.actualTableLog
213 tableSize = 1 << tableLog
214 previous0 bool
215 charnum uint16
216
217 maxHeaderSize = ((int(s.symbolLen) * int(tableLog)) >> 3) + 3
218
219 // Write Table Size
220 bitStream = uint32(tableLog - minTablelog)
221 bitCount = uint(4)
222 remaining = int16(tableSize + 1) /* +1 for extra accuracy */
223 threshold = int16(tableSize)
224 nbBits = uint(tableLog + 1)
225 )
226 if cap(s.Out) < maxHeaderSize {
227 s.Out = make([]byte, 0, s.br.remain()+maxHeaderSize)
228 }
229 outP := uint(0)
230 out := s.Out[:maxHeaderSize]
231
232 // stops at 1
233 for remaining > 1 {
234 if previous0 {
235 start := charnum
236 for s.norm[charnum] == 0 {
237 charnum++
238 }
239 for charnum >= start+24 {
240 start += 24
241 bitStream += uint32(0xFFFF) << bitCount
242 out[outP] = byte(bitStream)
243 out[outP+1] = byte(bitStream >> 8)
244 outP += 2
245 bitStream >>= 16
246 }
247 for charnum >= start+3 {
248 start += 3
249 bitStream += 3 << bitCount
250 bitCount += 2
251 }
252 bitStream += uint32(charnum-start) << bitCount
253 bitCount += 2
254 if bitCount > 16 {
255 out[outP] = byte(bitStream)
256 out[outP+1] = byte(bitStream >> 8)
257 outP += 2
258 bitStream >>= 16
259 bitCount -= 16
260 }
261 }
262
263 count := s.norm[charnum]
264 charnum++
265 max := (2*threshold - 1) - remaining
266 if count < 0 {
267 remaining += count
268 } else {
269 remaining -= count
270 }
271 count++ // +1 for extra accuracy
272 if count >= threshold {
273 count += max // [0..max[ [max..threshold[ (...) [threshold+max 2*threshold[
274 }
275 bitStream += uint32(count) << bitCount
276 bitCount += nbBits
277 if count < max {
278 bitCount--
279 }
280
281 previous0 = count == 1
282 if remaining < 1 {
283 return errors.New("internal error: remaining<1")
284 }
285 for remaining < threshold {
286 nbBits--
287 threshold >>= 1
288 }
289
290 if bitCount > 16 {
291 out[outP] = byte(bitStream)
292 out[outP+1] = byte(bitStream >> 8)
293 outP += 2
294 bitStream >>= 16
295 bitCount -= 16
296 }
297 }
298
299 out[outP] = byte(bitStream)
300 out[outP+1] = byte(bitStream >> 8)
301 outP += (bitCount + 7) / 8
302
303 if charnum > s.symbolLen {
304 return errors.New("internal error: charnum > s.symbolLen")
305 }
306 s.Out = out[:outP]
307 return nil
308}
309
// symbolTransform holds the two per-symbol values the encoder uses to
// advance its state.
type symbolTransform struct {
	// deltaFindState is added to the shifted state to index the state table.
	deltaFindState int32
	// deltaNbBits is added to the state; the upper 16 bits of the sum give
	// the number of bits to emit.
	deltaNbBits uint32
}

// String formats the transform for human inspection.
func (s symbolTransform) String() string {
	const layout = "dnbits: %08x, fs:%d"
	return fmt.Sprintf(layout, s.deltaNbBits, s.deltaFindState)
}
320
321// cTable contains tables used for compression.
322type cTable struct {
323 tableSymbol []byte
324 stateTable []uint16
325 symbolTT []symbolTransform
326}
327
328// allocCtable will allocate tables needed for compression.
329// If existing tables a re big enough, they are simply re-used.
330func (s *Scratch) allocCtable() {
331 tableSize := 1 << s.actualTableLog
332 // get tableSymbol that is big enough.
333 if cap(s.ct.tableSymbol) < tableSize {
334 s.ct.tableSymbol = make([]byte, tableSize)
335 }
336 s.ct.tableSymbol = s.ct.tableSymbol[:tableSize]
337
338 ctSize := tableSize
339 if cap(s.ct.stateTable) < ctSize {
340 s.ct.stateTable = make([]uint16, ctSize)
341 }
342 s.ct.stateTable = s.ct.stateTable[:ctSize]
343
344 if cap(s.ct.symbolTT) < 256 {
345 s.ct.symbolTT = make([]symbolTransform, 256)
346 }
347 s.ct.symbolTT = s.ct.symbolTT[:256]
348}
349
350// buildCTable will populate the compression table so it is ready to be used.
351func (s *Scratch) buildCTable() error {
352 tableSize := uint32(1 << s.actualTableLog)
353 highThreshold := tableSize - 1
354 var cumul [maxSymbolValue + 2]int16
355
356 s.allocCtable()
357 tableSymbol := s.ct.tableSymbol[:tableSize]
358 // symbol start positions
359 {
360 cumul[0] = 0
361 for ui, v := range s.norm[:s.symbolLen-1] {
362 u := byte(ui) // one less than reference
363 if v == -1 {
364 // Low proba symbol
365 cumul[u+1] = cumul[u] + 1
366 tableSymbol[highThreshold] = u
367 highThreshold--
368 } else {
369 cumul[u+1] = cumul[u] + v
370 }
371 }
372 // Encode last symbol separately to avoid overflowing u
373 u := int(s.symbolLen - 1)
374 v := s.norm[s.symbolLen-1]
375 if v == -1 {
376 // Low proba symbol
377 cumul[u+1] = cumul[u] + 1
378 tableSymbol[highThreshold] = byte(u)
379 highThreshold--
380 } else {
381 cumul[u+1] = cumul[u] + v
382 }
383 if uint32(cumul[s.symbolLen]) != tableSize {
384 return fmt.Errorf("internal error: expected cumul[s.symbolLen] (%d) == tableSize (%d)", cumul[s.symbolLen], tableSize)
385 }
386 cumul[s.symbolLen] = int16(tableSize) + 1
387 }
388 // Spread symbols
389 s.zeroBits = false
390 {
391 step := tableStep(tableSize)
392 tableMask := tableSize - 1
393 var position uint32
394 // if any symbol > largeLimit, we may have 0 bits output.
395 largeLimit := int16(1 << (s.actualTableLog - 1))
396 for ui, v := range s.norm[:s.symbolLen] {
397 symbol := byte(ui)
398 if v > largeLimit {
399 s.zeroBits = true
400 }
401 for nbOccurrences := int16(0); nbOccurrences < v; nbOccurrences++ {
402 tableSymbol[position] = symbol
403 position = (position + step) & tableMask
404 for position > highThreshold {
405 position = (position + step) & tableMask
406 } /* Low proba area */
407 }
408 }
409
410 // Check if we have gone through all positions
411 if position != 0 {
412 return errors.New("position!=0")
413 }
414 }
415
416 // Build table
417 table := s.ct.stateTable
418 {
419 tsi := int(tableSize)
420 for u, v := range tableSymbol {
421 // TableU16 : sorted by symbol order; gives next state value
422 table[cumul[v]] = uint16(tsi + u)
423 cumul[v]++
424 }
425 }
426
427 // Build Symbol Transformation Table
428 {
429 total := int16(0)
430 symbolTT := s.ct.symbolTT[:s.symbolLen]
431 tableLog := s.actualTableLog
432 tl := (uint32(tableLog) << 16) - (1 << tableLog)
433 for i, v := range s.norm[:s.symbolLen] {
434 switch v {
435 case 0:
436 case -1, 1:
437 symbolTT[i].deltaNbBits = tl
438 symbolTT[i].deltaFindState = int32(total - 1)
439 total++
440 default:
441 maxBitsOut := uint32(tableLog) - highBits(uint32(v-1))
442 minStatePlus := uint32(v) << maxBitsOut
443 symbolTT[i].deltaNbBits = (maxBitsOut << 16) - minStatePlus
444 symbolTT[i].deltaFindState = int32(total - v)
445 total += v
446 }
447 }
448 if total != int16(tableSize) {
449 return fmt.Errorf("total mismatch %d (got) != %d (want)", total, tableSize)
450 }
451 }
452 return nil
453}
454
455// countSimple will create a simple histogram in s.count.
456// Returns the biggest count.
457// Does not update s.clearCount.
458func (s *Scratch) countSimple(in []byte) (max int) {
459 for _, v := range in {
460 s.count[v]++
461 }
462 m := uint32(0)
463 for i, v := range s.count[:] {
464 if v > m {
465 m = v
466 }
467 if v > 0 {
468 s.symbolLen = uint16(i) + 1
469 }
470 }
471 return int(m)
472}
473
474// minTableLog provides the minimum logSize to safely represent a distribution.
475func (s *Scratch) minTableLog() uint8 {
476 minBitsSrc := highBits(uint32(s.br.remain()-1)) + 1
477 minBitsSymbols := highBits(uint32(s.symbolLen-1)) + 2
478 if minBitsSrc < minBitsSymbols {
479 return uint8(minBitsSrc)
480 }
481 return uint8(minBitsSymbols)
482}
483
484// optimalTableLog calculates and sets the optimal tableLog in s.actualTableLog
485func (s *Scratch) optimalTableLog() {
486 tableLog := s.TableLog
487 minBits := s.minTableLog()
488 maxBitsSrc := uint8(highBits(uint32(s.br.remain()-1))) - 2
489 if maxBitsSrc < tableLog {
490 // Accuracy can be reduced
491 tableLog = maxBitsSrc
492 }
493 if minBits > tableLog {
494 tableLog = minBits
495 }
496 // Need a minimum to safely represent all symbol values
497 if tableLog < minTablelog {
498 tableLog = minTablelog
499 }
500 if tableLog > maxTableLog {
501 tableLog = maxTableLog
502 }
503 s.actualTableLog = tableLog
504}
505
506var rtbTable = [...]uint32{0, 473195, 504333, 520860, 550000, 700000, 750000, 830000}
507
508// normalizeCount will normalize the count of the symbols so
509// the total is equal to the table size.
510func (s *Scratch) normalizeCount() error {
511 var (
512 tableLog = s.actualTableLog
513 scale = 62 - uint64(tableLog)
514 step = (1 << 62) / uint64(s.br.remain())
515 vStep = uint64(1) << (scale - 20)
516 stillToDistribute = int16(1 << tableLog)
517 largest int
518 largestP int16
519 lowThreshold = (uint32)(s.br.remain() >> tableLog)
520 )
521
522 for i, cnt := range s.count[:s.symbolLen] {
523 // already handled
524 // if (count[s] == s.length) return 0; /* rle special case */
525
526 if cnt == 0 {
527 s.norm[i] = 0
528 continue
529 }
530 if cnt <= lowThreshold {
531 s.norm[i] = -1
532 stillToDistribute--
533 } else {
534 proba := (int16)((uint64(cnt) * step) >> scale)
535 if proba < 8 {
536 restToBeat := vStep * uint64(rtbTable[proba])
537 v := uint64(cnt)*step - (uint64(proba) << scale)
538 if v > restToBeat {
539 proba++
540 }
541 }
542 if proba > largestP {
543 largestP = proba
544 largest = i
545 }
546 s.norm[i] = proba
547 stillToDistribute -= proba
548 }
549 }
550
551 if -stillToDistribute >= (s.norm[largest] >> 1) {
552 // corner case, need another normalization method
553 return s.normalizeCount2()
554 }
555 s.norm[largest] += stillToDistribute
556 return nil
557}
558
559// Secondary normalization method.
560// To be used when primary method fails.
561func (s *Scratch) normalizeCount2() error {
562 const notYetAssigned = -2
563 var (
564 distributed uint32
565 total = uint32(s.br.remain())
566 tableLog = s.actualTableLog
567 lowThreshold = total >> tableLog
568 lowOne = (total * 3) >> (tableLog + 1)
569 )
570 for i, cnt := range s.count[:s.symbolLen] {
571 if cnt == 0 {
572 s.norm[i] = 0
573 continue
574 }
575 if cnt <= lowThreshold {
576 s.norm[i] = -1
577 distributed++
578 total -= cnt
579 continue
580 }
581 if cnt <= lowOne {
582 s.norm[i] = 1
583 distributed++
584 total -= cnt
585 continue
586 }
587 s.norm[i] = notYetAssigned
588 }
589 toDistribute := (1 << tableLog) - distributed
590
591 if (total / toDistribute) > lowOne {
592 // risk of rounding to zero
593 lowOne = (total * 3) / (toDistribute * 2)
594 for i, cnt := range s.count[:s.symbolLen] {
595 if (s.norm[i] == notYetAssigned) && (cnt <= lowOne) {
596 s.norm[i] = 1
597 distributed++
598 total -= cnt
599 continue
600 }
601 }
602 toDistribute = (1 << tableLog) - distributed
603 }
604 if distributed == uint32(s.symbolLen)+1 {
605 // all values are pretty poor;
606 // probably incompressible data (should have already been detected);
607 // find max, then give all remaining points to max
608 var maxV int
609 var maxC uint32
610 for i, cnt := range s.count[:s.symbolLen] {
611 if cnt > maxC {
612 maxV = i
613 maxC = cnt
614 }
615 }
616 s.norm[maxV] += int16(toDistribute)
617 return nil
618 }
619
620 if total == 0 {
621 // all of the symbols were low enough for the lowOne or lowThreshold
622 for i := uint32(0); toDistribute > 0; i = (i + 1) % (uint32(s.symbolLen)) {
623 if s.norm[i] > 0 {
624 toDistribute--
625 s.norm[i]++
626 }
627 }
628 return nil
629 }
630
631 var (
632 vStepLog = 62 - uint64(tableLog)
633 mid = uint64((1 << (vStepLog - 1)) - 1)
634 rStep = (((1 << vStepLog) * uint64(toDistribute)) + mid) / uint64(total) // scale on remaining
635 tmpTotal = mid
636 )
637 for i, cnt := range s.count[:s.symbolLen] {
638 if s.norm[i] == notYetAssigned {
639 var (
640 end = tmpTotal + uint64(cnt)*rStep
641 sStart = uint32(tmpTotal >> vStepLog)
642 sEnd = uint32(end >> vStepLog)
643 weight = sEnd - sStart
644 )
645 if weight < 1 {
646 return errors.New("weight < 1")
647 }
648 s.norm[i] = int16(weight)
649 tmpTotal = end
650 }
651 }
652 return nil
653}
654
655// validateNorm validates the normalized histogram table.
656func (s *Scratch) validateNorm() (err error) {
657 var total int
658 for _, v := range s.norm[:s.symbolLen] {
659 if v >= 0 {
660 total += int(v)
661 } else {
662 total -= int(v)
663 }
664 }
665 defer func() {
666 if err == nil {
667 return
668 }
669 fmt.Printf("selected TableLog: %d, Symbol length: %d\n", s.actualTableLog, s.symbolLen)
670 for i, v := range s.norm[:s.symbolLen] {
671 fmt.Printf("%3d: %5d -> %4d \n", i, s.count[i], v)
672 }
673 }()
674 if total != (1 << s.actualTableLog) {
675 return fmt.Errorf("warning: Total == %d != %d", total, 1<<s.actualTableLog)
676 }
677 for i, v := range s.count[s.symbolLen:] {
678 if v != 0 {
679 return fmt.Errorf("warning: Found symbol out of range, %d after cut", i)
680 }
681 }
682 return nil
683}