package huff0

import (
	"fmt"
	"runtime"
	"sync"
)

// Compress1X will compress the input.
// The output can be decoded using Decompress1X.
// Supply a Scratch object. The scratch object contains state about re-use,
// so when sharing across independent encodes, be sure to set the re-use policy.
func Compress1X(in []byte, s *Scratch) (out []byte, reUsed bool, err error) {
	s, err = s.prepare(in)
	if err != nil {
		return nil, false, err
	}
	return compress(in, s, s.compress1X)
}
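
// exampleCompress1X is an illustrative sketch added for this edit; it is not
// part of the original file. It shows typical Compress1X usage and how the
// sentinel errors are commonly handled: ErrIncompressible and ErrUseRLE mean
// the caller should store the block raw or RLE-encoded instead.
func exampleCompress1X(in []byte) {
	var s Scratch
	out, reUsed, err := Compress1X(in, &s)
	switch err {
	case nil:
		fmt.Printf("compressed %d -> %d bytes (table reused: %v)\n", len(in), len(out), reUsed)
	case ErrIncompressible:
		fmt.Println("no gain; store the block uncompressed")
	case ErrUseRLE:
		fmt.Println("single repeated symbol; use RLE instead")
	default:
		fmt.Println("compression error:", err)
	}
}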

// Compress4X will compress the input. The input is split into 4 independent blocks
// and each is compressed similarly to Compress1X.
// The output can be decoded using Decompress4X.
// Supply a Scratch object. The scratch object contains state about re-use,
// so when sharing across independent encodes, be sure to set the re-use policy.
func Compress4X(in []byte, s *Scratch) (out []byte, reUsed bool, err error) {
	s, err = s.prepare(in)
	if err != nil {
		return nil, false, err
	}
	if false {
		// TODO: compress4Xp is only slightly faster.
		const parallelThreshold = 8 << 10
		if len(in) < parallelThreshold || runtime.GOMAXPROCS(0) == 1 {
			return compress(in, s, s.compress4X)
		}
		return compress(in, s, s.compress4Xp)
	}
	return compress(in, s, s.compress4X)
}
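
// exampleCompress4XStream is an illustrative sketch added for this edit; it
// is not part of the original file. It shows sharing one Scratch across a
// sequence of blocks: with ReusePolicyAllow the encoder may reuse the
// previous table when it estimates that is cheaper than emitting a new one,
// and reUsed reports what happened.
func exampleCompress4XStream(blocks [][]byte) {
	var s Scratch
	s.Reuse = ReusePolicyAllow
	for _, b := range blocks {
		out, reUsed, err := Compress4X(b, &s)
		if err != nil {
			// ErrIncompressible / ErrUseRLE: store this block another way.
			continue
		}
		fmt.Printf("block: %d -> %d bytes, table reused: %v\n", len(b), len(out), reUsed)
	}
}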

func compress(in []byte, s *Scratch, compressor func(src []byte) ([]byte, error)) (out []byte, reUsed bool, err error) {
	// Nuke previous table if we cannot reuse anyway.
	if s.Reuse == ReusePolicyNone {
		s.prevTable = s.prevTable[:0]
	}

	// Create histogram, if none was provided.
	maxCount := s.maxCount
	var canReuse = false
	if maxCount == 0 {
		maxCount, canReuse = s.countSimple(in)
	} else {
		canReuse = s.canUseTable(s.prevTable)
	}

	// We want the output size to be less than this:
	wantSize := len(in)
	if s.WantLogLess > 0 {
		wantSize -= wantSize >> s.WantLogLess
	}

	// Reset for next run.
	s.clearCount = true
	s.maxCount = 0
	if maxCount >= len(in) {
		if maxCount > len(in) {
			return nil, false, fmt.Errorf("maxCount (%d) > length (%d)", maxCount, len(in))
		}
		if len(in) == 1 {
			return nil, false, ErrIncompressible
		}
		// One symbol, use RLE
		return nil, false, ErrUseRLE
	}
	if maxCount == 1 || maxCount < (len(in)>>7) {
		// Each symbol present maximum once or too well distributed.
		return nil, false, ErrIncompressible
	}
	if s.Reuse == ReusePolicyMust && !canReuse {
		// We must reuse, but we can't.
		return nil, false, ErrIncompressible
	}
	if (s.Reuse == ReusePolicyPrefer || s.Reuse == ReusePolicyMust) && canReuse {
		keepTable := s.cTable
		keepTL := s.actualTableLog
		s.cTable = s.prevTable
		s.actualTableLog = s.prevTableLog
		s.Out, err = compressor(in)
		s.cTable = keepTable
		s.actualTableLog = keepTL
		if err == nil && len(s.Out) < wantSize {
			s.OutData = s.Out
			return s.Out, true, nil
		}
		if s.Reuse == ReusePolicyMust {
			return nil, false, ErrIncompressible
		}
		// Do not attempt to re-use later.
		s.prevTable = s.prevTable[:0]
	}

	// Calculate new table.
	err = s.buildCTable()
	if err != nil {
		return nil, false, err
	}

	if false && !s.canUseTable(s.cTable) {
		panic("invalid table generated")
	}

	if s.Reuse == ReusePolicyAllow && canReuse {
		hSize := len(s.Out)
		oldSize := s.prevTable.estimateSize(s.count[:s.symbolLen])
		newSize := s.cTable.estimateSize(s.count[:s.symbolLen])
		if oldSize <= hSize+newSize || hSize+12 >= wantSize {
			// Retain cTable even if we re-use.
			keepTable := s.cTable
			keepTL := s.actualTableLog

			s.cTable = s.prevTable
			s.actualTableLog = s.prevTableLog
			s.Out, err = compressor(in)

			// Restore ctable.
			s.cTable = keepTable
			s.actualTableLog = keepTL
			if err != nil {
				return nil, false, err
			}
			if len(s.Out) >= wantSize {
				return nil, false, ErrIncompressible
			}
			s.OutData = s.Out
			return s.Out, true, nil
		}
	}

	// Use new table
	err = s.cTable.write(s)
	if err != nil {
		s.OutTable = nil
		return nil, false, err
	}
	s.OutTable = s.Out

	// Compress using new table
	s.Out, err = compressor(in)
	if err != nil {
		s.OutTable = nil
		return nil, false, err
	}
	if len(s.Out) >= wantSize {
		s.OutTable = nil
		return nil, false, ErrIncompressible
	}
	// Move current table into previous.
	s.prevTable, s.prevTableLog, s.cTable = s.cTable, s.actualTableLog, s.prevTable[:0]
	s.OutData = s.Out[len(s.OutTable):]
	return s.Out, false, nil
}
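
// Worked example of the wantSize check in compress (a note added for
// clarity): with len(in) = 4096 and WantLogLess = 4, wantSize becomes
// 4096 - 4096>>4 = 3840, so any result of 3840 bytes or more is rejected
// as ErrIncompressible; the output must save at least 1/16th of the input.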

func (s *Scratch) compress1X(src []byte) ([]byte, error) {
	return s.compress1xDo(s.Out, src)
}

func (s *Scratch) compress1xDo(dst, src []byte) ([]byte, error) {
	var bw = bitWriter{out: dst}

	// n is the input length rounded down to a multiple of 4.
	n := len(src)
	n -= n & 3
	cTable := s.cTable[:256]

	// Encode the trailing len(src)&3 bytes first; symbols are emitted back to front.
	for i := len(src) & 3; i > 0; i-- {
		bw.encSymbol(cTable, src[n+i-1])
	}
	n -= 4
	if s.actualTableLog <= 8 {
		for ; n >= 0; n -= 4 {
			tmp := src[n : n+4]
			// tmp is always 4 bytes.
			bw.flush32()
			bw.encTwoSymbols(cTable, tmp[3], tmp[2])
			bw.encTwoSymbols(cTable, tmp[1], tmp[0])
		}
	} else {
		for ; n >= 0; n -= 4 {
			tmp := src[n : n+4]
			// tmp is always 4 bytes.
			bw.flush32()
			bw.encTwoSymbols(cTable, tmp[3], tmp[2])
			bw.flush32()
			bw.encTwoSymbols(cTable, tmp[1], tmp[0])
		}
	}
	err := bw.close()
	return bw.out, err
}
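
// Note on the flush cadence above (added for clarity, assuming the 64-bit
// bitWriter used by this package, whose flush32 only drains the buffer when
// 32 or more bits are pending): with actualTableLog <= 8, four symbols add
// at most 32 bits, so one flush per iteration keeps the buffer under 64
// bits; with longer codes two symbols alone may exceed 32 bits, hence the
// second flush32 between the encTwoSymbols calls.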

var sixZeros [6]byte

func (s *Scratch) compress4X(src []byte) ([]byte, error) {
	if len(src) < 12 {
		return nil, ErrIncompressible
	}
	segmentSize := (len(src) + 3) / 4

	// Add placeholder for output length
	offsetIdx := len(s.Out)
	s.Out = append(s.Out, sixZeros[:]...)

	for i := 0; i < 4; i++ {
		toDo := src
		if len(toDo) > segmentSize {
			toDo = toDo[:segmentSize]
		}
		src = src[len(toDo):]

		var err error
		idx := len(s.Out)
		s.Out, err = s.compress1xDo(s.Out, toDo)
		if err != nil {
			return nil, err
		}
		// Write compressed length as little endian before block.
		if i < 3 {
			// Last length is not written.
			length := len(s.Out) - idx
			s.Out[i*2+offsetIdx] = byte(length)
			s.Out[i*2+offsetIdx+1] = byte(length >> 8)
		}
	}

	return s.Out, nil
}
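
// parse4XJumpTable is an illustrative sketch added for this edit; it is not
// the package's actual decoder. Given the portion of a 4X output starting at
// the jump table (after any table header), it shows how the 6-byte header
// written by compress4X can be read back: three little-endian uint16 lengths
// for the first three streams, with the fourth stream's length implicit.
func parse4XJumpTable(block []byte) (streams [4][]byte, ok bool) {
	if len(block) < 6 {
		return streams, false
	}
	body := block[6:]
	total := 0
	var lengths [4]int
	for i := 0; i < 3; i++ {
		// Little-endian uint16, matching the byte writes in compress4X.
		lengths[i] = int(block[i*2]) | int(block[i*2+1])<<8
		total += lengths[i]
	}
	if total > len(body) {
		return streams, false
	}
	lengths[3] = len(body) - total
	for i := 0; i < 4; i++ {
		streams[i] = body[:lengths[i]]
		body = body[lengths[i]:]
	}
	return streams, true
}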

// compress4Xp will compress 4 streams using separate goroutines.
func (s *Scratch) compress4Xp(src []byte) ([]byte, error) {
	if len(src) < 12 {
		return nil, ErrIncompressible
	}
	// Add placeholder for output length
	s.Out = s.Out[:6]

	segmentSize := (len(src) + 3) / 4
	var wg sync.WaitGroup
	var errs [4]error
	wg.Add(4)
	for i := 0; i < 4; i++ {
		toDo := src
		if len(toDo) > segmentSize {
			toDo = toDo[:segmentSize]
		}
		src = src[len(toDo):]

		// Separate goroutine for each block.
		go func(i int) {
			s.tmpOut[i], errs[i] = s.compress1xDo(s.tmpOut[i][:0], toDo)
			wg.Done()
		}(i)
	}
	wg.Wait()
	for i := 0; i < 4; i++ {
		if errs[i] != nil {
			return nil, errs[i]
		}
		o := s.tmpOut[i]
		// Write compressed length as little endian before block.
		if i < 3 {
			// Last length is not written.
			s.Out[i*2] = byte(len(o))
			s.Out[i*2+1] = byte(len(o) >> 8)
		}

		// Write output.
		s.Out = append(s.Out, o...)
	}
	return s.Out, nil
}

// countSimple will create a simple histogram in s.count.
// Returns the biggest count.
// Does not update s.clearCount.
func (s *Scratch) countSimple(in []byte) (max int, reuse bool) {
	reuse = true
	for _, v := range in {
		s.count[v]++
	}
	m := uint32(0)
	if len(s.prevTable) > 0 {
		for i, v := range s.count[:] {
			if v > m {
				m = v
			}
			if v > 0 {
				s.symbolLen = uint16(i) + 1
				if i >= len(s.prevTable) {
					reuse = false
				} else {
					if s.prevTable[i].nBits == 0 {
						reuse = false
					}
				}
			}
		}
		return int(m), reuse
	}
	for i, v := range s.count[:] {
		if v > m {
			m = v
		}
		if v > 0 {
			s.symbolLen = uint16(i) + 1
		}
	}
	return int(m), false
}

func (s *Scratch) canUseTable(c cTable) bool {
	if len(c) < int(s.symbolLen) {
		return false
	}
	for i, v := range s.count[:s.symbolLen] {
		if v != 0 && c[i].nBits == 0 {
			return false
		}
	}
	return true
}

func (s *Scratch) validateTable(c cTable) bool {
	if len(c) < int(s.symbolLen) {
		return false
	}
	for i, v := range s.count[:s.symbolLen] {
		if v != 0 {
			if c[i].nBits == 0 {
				return false
			}
			if c[i].nBits > s.actualTableLog {
				return false
			}
		}
	}
	return true
}

// minTableLog provides the minimum logSize to safely represent a distribution.
func (s *Scratch) minTableLog() uint8 {
	minBitsSrc := highBit32(uint32(s.br.remain())) + 1
	minBitsSymbols := highBit32(uint32(s.symbolLen-1)) + 2
	if minBitsSrc < minBitsSymbols {
		return uint8(minBitsSrc)
	}
	return uint8(minBitsSymbols)
}
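
// Worked example (a note added for clarity): with 1000 bytes remaining,
// minBitsSrc = highBit32(1000) + 1 = 10; with symbolLen = 40,
// minBitsSymbols = highBit32(39) + 2 = 7. The smaller of the two, 7, is
// returned as the minimum safe table log.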

// optimalTableLog calculates and sets the optimal tableLog in s.actualTableLog
func (s *Scratch) optimalTableLog() {
	tableLog := s.TableLog
	minBits := s.minTableLog()
	maxBitsSrc := uint8(highBit32(uint32(s.br.remain()-1))) - 1
	if maxBitsSrc < tableLog {
		// Accuracy can be reduced
		tableLog = maxBitsSrc
	}
	if minBits > tableLog {
		tableLog = minBits
	}
	// Need a minimum to safely represent all symbol values
	if tableLog < minTablelog {
		tableLog = minTablelog
	}
	if tableLog > tableLogMax {
		tableLog = tableLogMax
	}
	s.actualTableLog = tableLog
}

type cTableEntry struct {
	val   uint16
	nBits uint8
	// We have 8 bits extra
}

const huffNodesMask = huffNodesLen - 1

func (s *Scratch) buildCTable() error {
	s.optimalTableLog()
	s.huffSort()
	if cap(s.cTable) < maxSymbolValue+1 {
		s.cTable = make([]cTableEntry, s.symbolLen, maxSymbolValue+1)
	} else {
		s.cTable = s.cTable[:s.symbolLen]
		for i := range s.cTable {
			s.cTable[i] = cTableEntry{}
		}
	}

	var startNode = int16(s.symbolLen)
	nonNullRank := s.symbolLen - 1

	nodeNb := startNode
	huffNode := s.nodes[1 : huffNodesLen+1]

	// This overlays the slice above, but allows "-1" index lookups.
	// Different from reference implementation.
	huffNode0 := s.nodes[0 : huffNodesLen+1]

	for huffNode[nonNullRank].count == 0 {
		nonNullRank--
	}

	lowS := int16(nonNullRank)
	nodeRoot := nodeNb + lowS - 1
	lowN := nodeNb
	huffNode[nodeNb].count = huffNode[lowS].count + huffNode[lowS-1].count
	huffNode[lowS].parent, huffNode[lowS-1].parent = uint16(nodeNb), uint16(nodeNb)
	nodeNb++
	lowS -= 2
	for n := nodeNb; n <= nodeRoot; n++ {
		huffNode[n].count = 1 << 30
	}
	// fake entry, strong barrier
	huffNode0[0].count = 1 << 31

	// create parents
	for nodeNb <= nodeRoot {
		var n1, n2 int16
		if huffNode0[lowS+1].count < huffNode0[lowN+1].count {
			n1 = lowS
			lowS--
		} else {
			n1 = lowN
			lowN++
		}
		if huffNode0[lowS+1].count < huffNode0[lowN+1].count {
			n2 = lowS
			lowS--
		} else {
			n2 = lowN
			lowN++
		}

		huffNode[nodeNb].count = huffNode0[n1+1].count + huffNode0[n2+1].count
		huffNode0[n1+1].parent, huffNode0[n2+1].parent = uint16(nodeNb), uint16(nodeNb)
		nodeNb++
	}

	// distribute weights (unlimited tree height)
	huffNode[nodeRoot].nbBits = 0
	for n := nodeRoot - 1; n >= startNode; n-- {
		huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1
	}
	for n := uint16(0); n <= nonNullRank; n++ {
		huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1
	}
	s.actualTableLog = s.setMaxHeight(int(nonNullRank))
	maxNbBits := s.actualTableLog

	// fill result into tree (val, nbBits)
	if maxNbBits > tableLogMax {
		return fmt.Errorf("internal error: maxNbBits (%d) > tableLogMax (%d)", maxNbBits, tableLogMax)
	}
	var nbPerRank [tableLogMax + 1]uint16
	var valPerRank [16]uint16
	for _, v := range huffNode[:nonNullRank+1] {
		nbPerRank[v.nbBits]++
	}
	// determine starting value per rank
	{
		min := uint16(0)
		for n := maxNbBits; n > 0; n-- {
			// get starting value within each rank
			valPerRank[n] = min
			min += nbPerRank[n]
			min >>= 1
		}
	}

	// push nbBits per symbol, symbol order
	for _, v := range huffNode[:nonNullRank+1] {
		s.cTable[v.symbol].nBits = v.nbBits
	}

	// assign value within rank, symbol order
	t := s.cTable[:s.symbolLen]
	for n, val := range t {
		nbits := val.nBits & 15
		v := valPerRank[nbits]
		t[n].val = v
		valPerRank[nbits] = v + 1
	}

	return nil
}
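
// Worked example of the canonical value assignment above (a note added for
// clarity): with maxNbBits = 3 and one 1-bit, one 2-bit and two 3-bit codes,
// the loop yields valPerRank[3] = 0, valPerRank[2] = 1, valPerRank[1] = 1.
// Read as fixed-width bit strings, the symbols receive 000, 001, 01 and 1,
// a valid canonical prefix code (ignoring the bit order the bit writer
// actually emits).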

// huffSort will sort symbols, decreasing order.
func (s *Scratch) huffSort() {
	type rankPos struct {
		base    uint32
		current uint32
	}

	// Clear nodes
	nodes := s.nodes[:huffNodesLen+1]
	s.nodes = nodes
	nodes = nodes[1 : huffNodesLen+1]

	// Sort into buckets based on length of symbol count.
	var rank [32]rankPos
	for _, v := range s.count[:s.symbolLen] {
		r := highBit32(v+1) & 31
		rank[r].base++
	}
	// maxBitLength is log2(BlockSizeMax) + 1
	const maxBitLength = 18 + 1
	for n := maxBitLength; n > 0; n-- {
		rank[n-1].base += rank[n].base
	}
	for n := range rank[:maxBitLength] {
		rank[n].current = rank[n].base
	}
	for n, c := range s.count[:s.symbolLen] {
		// Buckets are indexed by highBit+1 here; after the suffix sums above,
		// rank[r].base is the number of symbols in strictly higher buckets,
		// i.e. the first position of this bucket.
		r := (highBit32(c+1) + 1) & 31
		pos := rank[r].current
		rank[r].current++
		prev := nodes[(pos-1)&huffNodesMask]
		for pos > rank[r].base && c > prev.count {
			nodes[pos&huffNodesMask] = prev
			pos--
			prev = nodes[(pos-1)&huffNodesMask]
		}
		nodes[pos&huffNodesMask] = nodeElt{count: c, symbol: byte(n)}
	}
}

func (s *Scratch) setMaxHeight(lastNonNull int) uint8 {
	maxNbBits := s.actualTableLog
	huffNode := s.nodes[1 : huffNodesLen+1]

	largestBits := huffNode[lastNonNull].nbBits

	// early exit: no element exceeds maxNbBits
	if largestBits <= maxNbBits {
		return largestBits
	}
	totalCost := int(0)
	baseCost := int(1) << (largestBits - maxNbBits)
	n := uint32(lastNonNull)

	for huffNode[n].nbBits > maxNbBits {
		totalCost += baseCost - (1 << (largestBits - huffNode[n].nbBits))
		huffNode[n].nbBits = maxNbBits
		n--
	}
	// n stops at huffNode[n].nbBits <= maxNbBits

	for huffNode[n].nbBits == maxNbBits {
		n--
	}
	// n ends at the index of the smallest symbol using < maxNbBits

	// renorm totalCost
	totalCost >>= largestBits - maxNbBits // note: totalCost is necessarily a multiple of baseCost

	// repay normalized cost
	{
		const noSymbol = 0xF0F0F0F0
		var rankLast [tableLogMax + 2]uint32

		for i := range rankLast[:] {
			rankLast[i] = noSymbol
		}

		// Get pos of last (smallest) symbol per rank
		{
			currentNbBits := maxNbBits
			for pos := int(n); pos >= 0; pos-- {
				if huffNode[pos].nbBits >= currentNbBits {
					continue
				}
				currentNbBits = huffNode[pos].nbBits // < maxNbBits
				rankLast[maxNbBits-currentNbBits] = uint32(pos)
			}
		}

		for totalCost > 0 {
			nBitsToDecrease := uint8(highBit32(uint32(totalCost))) + 1

			for ; nBitsToDecrease > 1; nBitsToDecrease-- {
				highPos := rankLast[nBitsToDecrease]
				lowPos := rankLast[nBitsToDecrease-1]
				if highPos == noSymbol {
					continue
				}
				if lowPos == noSymbol {
					break
				}
				highTotal := huffNode[highPos].count
				lowTotal := 2 * huffNode[lowPos].count
				if highTotal <= lowTotal {
					break
				}
			}
			// Only triggered when no more rank 1 symbols are left; find the
			// closest one (there is necessarily at least one).
			// The HUF_MAX_TABLELOG test in the reference implementation is just
			// to please gcc 5+; it should not be necessary.
			// FIXME: try to remove
			for (nBitsToDecrease <= tableLogMax) && (rankLast[nBitsToDecrease] == noSymbol) {
				nBitsToDecrease++
			}
			totalCost -= 1 << (nBitsToDecrease - 1)
			if rankLast[nBitsToDecrease-1] == noSymbol {
				// this rank is no longer empty
				rankLast[nBitsToDecrease-1] = rankLast[nBitsToDecrease]
			}
			huffNode[rankLast[nBitsToDecrease]].nbBits++
			if rankLast[nBitsToDecrease] == 0 {
				// special case: reached largest symbol
				rankLast[nBitsToDecrease] = noSymbol
			} else {
				rankLast[nBitsToDecrease]--
				if huffNode[rankLast[nBitsToDecrease]].nbBits != maxNbBits-nBitsToDecrease {
					rankLast[nBitsToDecrease] = noSymbol // this rank is now empty
				}
			}
		}

		for totalCost < 0 { // Sometimes the cost correction overshoots.
			if rankLast[1] == noSymbol {
				// special case: no rank 1 symbol (using maxNbBits-1);
				// create one from the largest rank 0 (using maxNbBits).
				for huffNode[n].nbBits == maxNbBits {
					n--
				}
				huffNode[n+1].nbBits--
				rankLast[1] = n + 1
				totalCost++
				continue
			}
			huffNode[rankLast[1]+1].nbBits--
			rankLast[1]++
			totalCost++
		}
	}
	return maxNbBits
}
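
// Worked example of the cost accounting in setMaxHeight (a note added for
// clarity): with largestBits = 13 and maxNbBits = 11, baseCost = 1<<2 = 4.
// Truncating a depth-13 leaf to 11 bits adds 4-1 = 3 to totalCost, and a
// depth-12 leaf adds 4-2 = 2. After the normalizing shift (>> 2), each unit
// of debt is repaid by demoting one symbol a single bit: a symbol at
// maxNbBits-k bits moved to maxNbBits-k+1 repays 1<<(k-1) units, which is
// the totalCost -= 1 << (nBitsToDecrease - 1) line above.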

type nodeElt struct {
	count  uint32
	parent uint16
	symbol byte
	nbBits uint8
}