blob: 8323dc053890b8b37b3cc69cd824cb58bbfa0431 [file] [log] [blame]
package huff0
2
3import (
4 "fmt"
5 "runtime"
6 "sync"
7)
8
9// Compress1X will compress the input.
10// The output can be decoded using Decompress1X.
11// Supply a Scratch object. The scratch object contains state about re-use,
12// So when sharing across independent encodes, be sure to set the re-use policy.
13func Compress1X(in []byte, s *Scratch) (out []byte, reUsed bool, err error) {
14 s, err = s.prepare(in)
15 if err != nil {
16 return nil, false, err
17 }
18 return compress(in, s, s.compress1X)
19}
20
// Compress4X will compress the input. The input is split into 4 independent blocks
// and compressed similar to Compress1X.
// The output can be decoded using Decompress4X.
// Supply a Scratch object. The scratch object contains state about re-use,
// So when sharing across independent encodes, be sure to set the re-use policy.
func Compress4X(in []byte, s *Scratch) (out []byte, reUsed bool, err error) {
	s, err = s.prepare(in)
	if err != nil {
		return nil, false, err
	}
	if false {
		// Disabled experimental path: kept for reference/benchmarking only.
		// TODO: compress4Xp only slightly faster.
		const parallelThreshold = 8 << 10
		if len(in) < parallelThreshold || runtime.GOMAXPROCS(0) == 1 {
			return compress(in, s, s.compress4X)
		}
		return compress(in, s, s.compress4Xp)
	}
	return compress(in, s, s.compress4X)
}
41
// compress runs the supplied compressor over in, handling histogram
// creation, incompressible/RLE detection and previous-table re-use
// according to s.Reuse. It returns the compressed output, whether the
// previous table was re-used, and any error.
func compress(in []byte, s *Scratch, compressor func(src []byte) ([]byte, error)) (out []byte, reUsed bool, err error) {
	// Nuke previous table if we cannot reuse anyway.
	if s.Reuse == ReusePolicyNone {
		s.prevTable = s.prevTable[:0]
	}

	// Create histogram, if none was provided.
	maxCount := s.maxCount
	var canReuse = false
	if maxCount == 0 {
		maxCount, canReuse = s.countSimple(in)
	} else {
		canReuse = s.canUseTable(s.prevTable)
	}

	// We want the output size to be less than this:
	wantSize := len(in)
	if s.WantLogLess > 0 {
		// Require at least a 1/2^WantLogLess size reduction.
		wantSize -= wantSize >> s.WantLogLess
	}

	// Reset for next run.
	s.clearCount = true
	s.maxCount = 0
	if maxCount >= len(in) {
		if maxCount > len(in) {
			// Histogram counts more occurrences than input bytes: invalid state.
			return nil, false, fmt.Errorf("maxCount (%d) > length (%d)", maxCount, len(in))
		}
		if len(in) == 1 {
			return nil, false, ErrIncompressible
		}
		// One symbol, use RLE
		return nil, false, ErrUseRLE
	}
	if maxCount == 1 || maxCount < (len(in)>>7) {
		// Each symbol present maximum once or too well distributed.
		return nil, false, ErrIncompressible
	}
	if s.Reuse == ReusePolicyMust && !canReuse {
		// We must reuse, but we can't.
		return nil, false, ErrIncompressible
	}
	if (s.Reuse == ReusePolicyPrefer || s.Reuse == ReusePolicyMust) && canReuse {
		// Try encoding with the previous table first, temporarily
		// swapping it in as the active cTable.
		keepTable := s.cTable
		keepTL := s.actualTableLog
		s.cTable = s.prevTable
		s.actualTableLog = s.prevTableLog
		s.Out, err = compressor(in)
		s.cTable = keepTable
		s.actualTableLog = keepTL
		if err == nil && len(s.Out) < wantSize {
			s.OutData = s.Out
			return s.Out, true, nil
		}
		if s.Reuse == ReusePolicyMust {
			return nil, false, ErrIncompressible
		}
		// Do not attempt to re-use later.
		s.prevTable = s.prevTable[:0]
	}

	// Calculate new table.
	err = s.buildCTable()
	if err != nil {
		return nil, false, err
	}

	if false && !s.canUseTable(s.cTable) {
		// Debug-only sanity check; disabled in normal builds.
		panic("invalid table generated")
	}

	if s.Reuse == ReusePolicyAllow && canReuse {
		// Estimate whether re-using the old table beats paying for a
		// new table header plus new-table-encoded payload.
		hSize := len(s.Out)
		oldSize := s.prevTable.estimateSize(s.count[:s.symbolLen])
		newSize := s.cTable.estimateSize(s.count[:s.symbolLen])
		if oldSize <= hSize+newSize || hSize+12 >= wantSize {
			// Retain cTable even if we re-use.
			keepTable := s.cTable
			keepTL := s.actualTableLog

			s.cTable = s.prevTable
			s.actualTableLog = s.prevTableLog
			s.Out, err = compressor(in)

			// Restore ctable.
			s.cTable = keepTable
			s.actualTableLog = keepTL
			if err != nil {
				return nil, false, err
			}
			if len(s.Out) >= wantSize {
				return nil, false, ErrIncompressible
			}
			s.OutData = s.Out
			return s.Out, true, nil
		}
	}

	// Use new table
	err = s.cTable.write(s)
	if err != nil {
		s.OutTable = nil
		return nil, false, err
	}
	s.OutTable = s.Out

	// Compress using new table
	s.Out, err = compressor(in)
	if err != nil {
		s.OutTable = nil
		return nil, false, err
	}
	if len(s.Out) >= wantSize {
		s.OutTable = nil
		return nil, false, ErrIncompressible
	}
	// Move current table into previous.
	s.prevTable, s.prevTableLog, s.cTable = s.cTable, s.actualTableLog, s.prevTable[:0]
	// OutData is the payload after the serialized table header.
	s.OutData = s.Out[len(s.OutTable):]
	return s.Out, false, nil
}
163
// EstimateSizes will estimate the data sizes
// Returns the estimated serialized table size, the payload size using a
// freshly built table, and the payload size when re-using the previous
// table (-1 when re-use is not possible).
func EstimateSizes(in []byte, s *Scratch) (tableSz, dataSz, reuseSz int, err error) {
	s, err = s.prepare(in)
	if err != nil {
		return 0, 0, 0, err
	}

	// Create histogram, if none was provided.
	// -1 means "not estimated".
	tableSz, dataSz, reuseSz = -1, -1, -1
	maxCount := s.maxCount
	var canReuse = false
	if maxCount == 0 {
		maxCount, canReuse = s.countSimple(in)
	} else {
		canReuse = s.canUseTable(s.prevTable)
	}

	// We want the output size to be less than this:
	// NOTE(review): wantSize is computed but not read below in this function.
	wantSize := len(in)
	if s.WantLogLess > 0 {
		wantSize -= wantSize >> s.WantLogLess
	}

	// Reset for next run.
	s.clearCount = true
	s.maxCount = 0
	if maxCount >= len(in) {
		if maxCount > len(in) {
			// Histogram counts more occurrences than input bytes: invalid state.
			return 0, 0, 0, fmt.Errorf("maxCount (%d) > length (%d)", maxCount, len(in))
		}
		if len(in) == 1 {
			return 0, 0, 0, ErrIncompressible
		}
		// One symbol, use RLE
		return 0, 0, 0, ErrUseRLE
	}
	if maxCount == 1 || maxCount < (len(in)>>7) {
		// Each symbol present maximum once or too well distributed.
		return 0, 0, 0, ErrIncompressible
	}

	// Calculate new table.
	err = s.buildCTable()
	if err != nil {
		return 0, 0, 0, err
	}

	if false && !s.canUseTable(s.cTable) {
		// Debug-only sanity check; disabled in normal builds.
		panic("invalid table generated")
	}

	// Estimated size of the serialized table header.
	tableSz, err = s.cTable.estTableSize(s)
	if err != nil {
		return 0, 0, 0, err
	}
	// Estimated payload size when re-using the previous table.
	if canReuse {
		reuseSz = s.prevTable.estimateSize(s.count[:s.symbolLen])
	}
	// Estimated payload size with the freshly built table.
	dataSz = s.cTable.estimateSize(s.count[:s.symbolLen])

	// Restore
	return tableSz, dataSz, reuseSz, nil
}
227
// compress1X compresses src as a single stream, appending to s.Out.
func (s *Scratch) compress1X(src []byte) ([]byte, error) {
	return s.compress1xDo(s.Out, src)
}
231
// compress1xDo encodes src as one huff0 stream, appending the compressed
// bits to dst and returning the extended slice.
// Symbols are written back-to-front so the decoder can consume forward.
func (s *Scratch) compress1xDo(dst, src []byte) ([]byte, error) {
	var bw = bitWriter{out: dst}

	// N is length divisible by 4.
	n := len(src)
	n -= n & 3
	cTable := s.cTable[:256]

	// Encode last bytes.
	for i := len(src) & 3; i > 0; i-- {
		bw.encSymbol(cTable, src[n+i-1])
	}
	n -= 4
	if s.actualTableLog <= 8 {
		// Codes are at most 8 bits: four symbols fit in 32 bits,
		// so a single flush per group of four is sufficient.
		for ; n >= 0; n -= 4 {
			tmp := src[n : n+4]
			// tmp should be len 4
			bw.flush32()
			bw.encTwoSymbols(cTable, tmp[3], tmp[2])
			bw.encTwoSymbols(cTable, tmp[1], tmp[0])
		}
	} else {
		// Longer codes: flush after every symbol pair to avoid
		// overflowing the bit container.
		for ; n >= 0; n -= 4 {
			tmp := src[n : n+4]
			// tmp should be len 4
			bw.flush32()
			bw.encTwoSymbols(cTable, tmp[3], tmp[2])
			bw.flush32()
			bw.encTwoSymbols(cTable, tmp[1], tmp[0])
		}
	}
	err := bw.close()
	return bw.out, err
}
266
// sixZeros is a zero-filled placeholder for the three 2-byte
// little-endian block lengths reserved by compress4X.
var sixZeros [6]byte
268
269func (s *Scratch) compress4X(src []byte) ([]byte, error) {
270 if len(src) < 12 {
271 return nil, ErrIncompressible
272 }
273 segmentSize := (len(src) + 3) / 4
274
275 // Add placeholder for output length
276 offsetIdx := len(s.Out)
277 s.Out = append(s.Out, sixZeros[:]...)
278
279 for i := 0; i < 4; i++ {
280 toDo := src
281 if len(toDo) > segmentSize {
282 toDo = toDo[:segmentSize]
283 }
284 src = src[len(toDo):]
285
286 var err error
287 idx := len(s.Out)
288 s.Out, err = s.compress1xDo(s.Out, toDo)
289 if err != nil {
290 return nil, err
291 }
292 // Write compressed length as little endian before block.
293 if i < 3 {
294 // Last length is not written.
295 length := len(s.Out) - idx
296 s.Out[i*2+offsetIdx] = byte(length)
297 s.Out[i*2+offsetIdx+1] = byte(length >> 8)
298 }
299 }
300
301 return s.Out, nil
302}
303
// compress4Xp will compress 4 streams using separate goroutines.
func (s *Scratch) compress4Xp(src []byte) ([]byte, error) {
	if len(src) < 12 {
		return nil, ErrIncompressible
	}
	// Add placeholder for output length
	// NOTE(review): assumes cap(s.Out) >= 6 — confirm prepare guarantees this.
	s.Out = s.Out[:6]

	segmentSize := (len(src) + 3) / 4
	var wg sync.WaitGroup
	var errs [4]error
	wg.Add(4)
	for i := 0; i < 4; i++ {
		toDo := src
		if len(toDo) > segmentSize {
			toDo = toDo[:segmentSize]
		}
		src = src[len(toDo):]

		// Separate goroutine for each block.
		// Each goroutine writes only its own s.tmpOut[i]/errs[i] slot,
		// so no further synchronization beyond the WaitGroup is needed.
		go func(i int) {
			s.tmpOut[i], errs[i] = s.compress1xDo(s.tmpOut[i][:0], toDo)
			wg.Done()
		}(i)
	}
	wg.Wait()
	for i := 0; i < 4; i++ {
		if errs[i] != nil {
			return nil, errs[i]
		}
		o := s.tmpOut[i]
		// Write compressed length as little endian before block.
		if i < 3 {
			// Last length is not written.
			s.Out[i*2] = byte(len(o))
			s.Out[i*2+1] = byte(len(o) >> 8)
		}

		// Write output.
		s.Out = append(s.Out, o...)
	}
	return s.Out, nil
}
347
348// countSimple will create a simple histogram in s.count.
349// Returns the biggest count.
350// Does not update s.clearCount.
351func (s *Scratch) countSimple(in []byte) (max int, reuse bool) {
352 reuse = true
353 for _, v := range in {
354 s.count[v]++
355 }
356 m := uint32(0)
357 if len(s.prevTable) > 0 {
358 for i, v := range s.count[:] {
359 if v > m {
360 m = v
361 }
362 if v > 0 {
363 s.symbolLen = uint16(i) + 1
364 if i >= len(s.prevTable) {
365 reuse = false
366 } else {
367 if s.prevTable[i].nBits == 0 {
368 reuse = false
369 }
370 }
371 }
372 }
373 return int(m), reuse
374 }
375 for i, v := range s.count[:] {
376 if v > m {
377 m = v
378 }
379 if v > 0 {
380 s.symbolLen = uint16(i) + 1
381 }
382 }
383 return int(m), false
384}
385
386func (s *Scratch) canUseTable(c cTable) bool {
387 if len(c) < int(s.symbolLen) {
388 return false
389 }
390 for i, v := range s.count[:s.symbolLen] {
391 if v != 0 && c[i].nBits == 0 {
392 return false
393 }
394 }
395 return true
396}
397
398func (s *Scratch) validateTable(c cTable) bool {
399 if len(c) < int(s.symbolLen) {
400 return false
401 }
402 for i, v := range s.count[:s.symbolLen] {
403 if v != 0 {
404 if c[i].nBits == 0 {
405 return false
406 }
407 if c[i].nBits > s.actualTableLog {
408 return false
409 }
410 }
411 }
412 return true
413}
414
415// minTableLog provides the minimum logSize to safely represent a distribution.
416func (s *Scratch) minTableLog() uint8 {
417 minBitsSrc := highBit32(uint32(s.br.remain())) + 1
418 minBitsSymbols := highBit32(uint32(s.symbolLen-1)) + 2
419 if minBitsSrc < minBitsSymbols {
420 return uint8(minBitsSrc)
421 }
422 return uint8(minBitsSymbols)
423}
424
425// optimalTableLog calculates and sets the optimal tableLog in s.actualTableLog
426func (s *Scratch) optimalTableLog() {
427 tableLog := s.TableLog
428 minBits := s.minTableLog()
429 maxBitsSrc := uint8(highBit32(uint32(s.br.remain()-1))) - 1
430 if maxBitsSrc < tableLog {
431 // Accuracy can be reduced
432 tableLog = maxBitsSrc
433 }
434 if minBits > tableLog {
435 tableLog = minBits
436 }
437 // Need a minimum to safely represent all symbol values
438 if tableLog < minTablelog {
439 tableLog = minTablelog
440 }
441 if tableLog > tableLogMax {
442 tableLog = tableLogMax
443 }
444 s.actualTableLog = tableLog
445}
446
// cTableEntry holds the compression table entry for one symbol:
// its code value and the code length in bits (0 = symbol unused).
type cTableEntry struct {
	val   uint16
	nBits uint8
	// We have 8 bits extra
}

// huffNodesMask enables cheap wrap-around indexing into the node buffer
// (huffNodesLen is assumed to be a power of two).
const huffNodesMask = huffNodesLen - 1
454
// buildCTable builds a canonical Huffman table into s.cTable from the
// histogram in s.count, limiting code lengths via setMaxHeight.
func (s *Scratch) buildCTable() error {
	s.optimalTableLog()
	s.huffSort()
	if cap(s.cTable) < maxSymbolValue+1 {
		s.cTable = make([]cTableEntry, s.symbolLen, maxSymbolValue+1)
	} else {
		// Reuse the backing array, but zero all entries.
		s.cTable = s.cTable[:s.symbolLen]
		for i := range s.cTable {
			s.cTable[i] = cTableEntry{}
		}
	}

	// Leaves occupy [0, symbolLen); internal nodes start at startNode.
	var startNode = int16(s.symbolLen)
	nonNullRank := s.symbolLen - 1

	nodeNb := startNode
	huffNode := s.nodes[1 : huffNodesLen+1]

	// This overlays the slice above, but allows "-1" index lookups.
	// Different from reference implementation.
	huffNode0 := s.nodes[0 : huffNodesLen+1]

	// Skip trailing zero-count entries (huffSort placed them last).
	for huffNode[nonNullRank].count == 0 {
		nonNullRank--
	}

	// Merge the two least frequent leaves into the first internal node.
	lowS := int16(nonNullRank)
	nodeRoot := nodeNb + lowS - 1
	lowN := nodeNb
	huffNode[nodeNb].count = huffNode[lowS].count + huffNode[lowS-1].count
	huffNode[lowS].parent, huffNode[lowS-1].parent = uint16(nodeNb), uint16(nodeNb)
	nodeNb++
	lowS -= 2
	for n := nodeNb; n <= nodeRoot; n++ {
		huffNode[n].count = 1 << 30
	}
	// fake entry, strong barrier
	huffNode0[0].count = 1 << 31

	// create parents
	for nodeNb <= nodeRoot {
		// Pick the two smallest available nodes (leaf or internal)
		// and merge them into a new internal node.
		var n1, n2 int16
		if huffNode0[lowS+1].count < huffNode0[lowN+1].count {
			n1 = lowS
			lowS--
		} else {
			n1 = lowN
			lowN++
		}
		if huffNode0[lowS+1].count < huffNode0[lowN+1].count {
			n2 = lowS
			lowS--
		} else {
			n2 = lowN
			lowN++
		}

		huffNode[nodeNb].count = huffNode0[n1+1].count + huffNode0[n2+1].count
		huffNode0[n1+1].parent, huffNode0[n2+1].parent = uint16(nodeNb), uint16(nodeNb)
		nodeNb++
	}

	// distribute weights (unlimited tree height)
	huffNode[nodeRoot].nbBits = 0
	for n := nodeRoot - 1; n >= startNode; n-- {
		huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1
	}
	for n := uint16(0); n <= nonNullRank; n++ {
		huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1
	}
	// Enforce the maximum code length; may lower the table log.
	s.actualTableLog = s.setMaxHeight(int(nonNullRank))
	maxNbBits := s.actualTableLog

	// fill result into tree (val, nbBits)
	if maxNbBits > tableLogMax {
		return fmt.Errorf("internal error: maxNbBits (%d) > tableLogMax (%d)", maxNbBits, tableLogMax)
	}
	var nbPerRank [tableLogMax + 1]uint16
	var valPerRank [16]uint16
	for _, v := range huffNode[:nonNullRank+1] {
		nbPerRank[v.nbBits]++
	}
	// determine stating value per rank
	{
		min := uint16(0)
		for n := maxNbBits; n > 0; n-- {
			// get starting value within each rank
			valPerRank[n] = min
			min += nbPerRank[n]
			min >>= 1
		}
	}

	// push nbBits per symbol, symbol order
	for _, v := range huffNode[:nonNullRank+1] {
		s.cTable[v.symbol].nBits = v.nbBits
	}

	// assign value within rank, symbol order
	t := s.cTable[:s.symbolLen]
	for n, val := range t {
		nbits := val.nBits & 15
		v := valPerRank[nbits]
		t[n].val = v
		valPerRank[nbits] = v + 1
	}

	return nil
}
564
// huffSort will sort symbols, decreasing order.
// Symbols are bucketed by log2 of their count, then each bucket is kept
// internally sorted by count via a short insertion step.
func (s *Scratch) huffSort() {
	type rankPos struct {
		base    uint32
		current uint32
	}

	// Clear nodes
	nodes := s.nodes[:huffNodesLen+1]
	s.nodes = nodes
	nodes = nodes[1 : huffNodesLen+1]

	// Sort into buckets based on length of symbol count.
	var rank [32]rankPos
	for _, v := range s.count[:s.symbolLen] {
		r := highBit32(v+1) & 31
		rank[r].base++
	}
	// maxBitLength is log2(BlockSizeMax) + 1
	const maxBitLength = 18 + 1
	// Descending prefix sums: rank[n].base becomes the number of symbols
	// in buckets >= n, i.e. the start offset for bucket n-1.
	for n := maxBitLength; n > 0; n-- {
		rank[n-1].base += rank[n].base
	}
	for n := range rank[:maxBitLength] {
		rank[n].current = rank[n].base
	}
	// Place each symbol, shifting smaller-count entries right so the
	// bucket stays sorted by decreasing count. Indexing is masked to
	// stay within the node buffer.
	for n, c := range s.count[:s.symbolLen] {
		r := (highBit32(c+1) + 1) & 31
		pos := rank[r].current
		rank[r].current++
		prev := nodes[(pos-1)&huffNodesMask]
		for pos > rank[r].base && c > prev.count {
			nodes[pos&huffNodesMask] = prev
			pos--
			prev = nodes[(pos-1)&huffNodesMask]
		}
		nodes[pos&huffNodesMask] = nodeElt{count: c, symbol: byte(n)}
	}
}
604
// setMaxHeight limits code lengths to s.actualTableLog: codes longer than
// the limit are truncated, and the resulting bit-budget "cost" is repaid
// by lengthening codes of other symbols (mirrors the reference
// HUF_setMaxHeight strategy). Returns the resulting maximum bit length.
func (s *Scratch) setMaxHeight(lastNonNull int) uint8 {
	maxNbBits := s.actualTableLog
	huffNode := s.nodes[1 : huffNodesLen+1]
	//huffNode = huffNode[: huffNodesLen]

	largestBits := huffNode[lastNonNull].nbBits

	// early exit : no elt > maxNbBits
	if largestBits <= maxNbBits {
		return largestBits
	}
	totalCost := int(0)
	baseCost := int(1) << (largestBits - maxNbBits)
	n := uint32(lastNonNull)

	// Truncate all over-long codes, accumulating the cost of doing so.
	for huffNode[n].nbBits > maxNbBits {
		totalCost += baseCost - (1 << (largestBits - huffNode[n].nbBits))
		huffNode[n].nbBits = maxNbBits
		n--
	}
	// n stops at huffNode[n].nbBits <= maxNbBits

	for huffNode[n].nbBits == maxNbBits {
		n--
	}
	// n end at index of smallest symbol using < maxNbBits

	// renorm totalCost
	totalCost >>= largestBits - maxNbBits /* note : totalCost is necessarily a multiple of baseCost */

	// repay normalized cost
	{
		const noSymbol = 0xF0F0F0F0
		// rankLast[r] is the position of the last (smallest-count) symbol
		// whose code length is maxNbBits-r; noSymbol marks an empty rank.
		var rankLast [tableLogMax + 2]uint32

		for i := range rankLast[:] {
			rankLast[i] = noSymbol
		}

		// Get pos of last (smallest) symbol per rank
		{
			currentNbBits := maxNbBits
			for pos := int(n); pos >= 0; pos-- {
				if huffNode[pos].nbBits >= currentNbBits {
					continue
				}
				currentNbBits = huffNode[pos].nbBits // < maxNbBits
				rankLast[maxNbBits-currentNbBits] = uint32(pos)
			}
		}

		for totalCost > 0 {
			// Choose which rank to take a bit from, preferring the
			// cheapest (lowest-count) eligible symbol.
			nBitsToDecrease := uint8(highBit32(uint32(totalCost))) + 1

			for ; nBitsToDecrease > 1; nBitsToDecrease-- {
				highPos := rankLast[nBitsToDecrease]
				lowPos := rankLast[nBitsToDecrease-1]
				if highPos == noSymbol {
					continue
				}
				if lowPos == noSymbol {
					break
				}
				highTotal := huffNode[highPos].count
				lowTotal := 2 * huffNode[lowPos].count
				if highTotal <= lowTotal {
					break
				}
			}
			// only triggered when no more rank 1 symbol left => find closest one (note : there is necessarily at least one !)
			// HUF_MAX_TABLELOG test just to please gcc 5+; but it should not be necessary
			// FIXME: try to remove
			for (nBitsToDecrease <= tableLogMax) && (rankLast[nBitsToDecrease] == noSymbol) {
				nBitsToDecrease++
			}
			totalCost -= 1 << (nBitsToDecrease - 1)
			if rankLast[nBitsToDecrease-1] == noSymbol {
				// this rank is no longer empty
				rankLast[nBitsToDecrease-1] = rankLast[nBitsToDecrease]
			}
			huffNode[rankLast[nBitsToDecrease]].nbBits++
			if rankLast[nBitsToDecrease] == 0 {
				/* special case, reached largest symbol */
				rankLast[nBitsToDecrease] = noSymbol
			} else {
				rankLast[nBitsToDecrease]--
				if huffNode[rankLast[nBitsToDecrease]].nbBits != maxNbBits-nBitsToDecrease {
					rankLast[nBitsToDecrease] = noSymbol /* this rank is now empty */
				}
			}
		}

		for totalCost < 0 { /* Sometimes, cost correction overshoot */
			if rankLast[1] == noSymbol { /* special case : no rank 1 symbol (using maxNbBits-1); let's create one from largest rank 0 (using maxNbBits) */
				for huffNode[n].nbBits == maxNbBits {
					n--
				}
				huffNode[n+1].nbBits--
				rankLast[1] = n + 1
				totalCost++
				continue
			}
			huffNode[rankLast[1]+1].nbBits--
			rankLast[1]++
			totalCost++
		}
	}
	return maxNbBits
}
714
// nodeElt is a single node in the Huffman construction tree.
type nodeElt struct {
	count  uint32 // symbol frequency; combined frequency for internal nodes
	parent uint16 // index of the parent node within the node list
	symbol byte   // symbol value (meaningful for leaves)
	nbBits uint8  // assigned code length in bits
}