blob: 0843cb014ffce1a2be628e61244346066e8a2aff [file] [log] [blame]
Pragya Arya324337e2020-02-20 14:35:08 +05301package huff0
2
3import (
4 "fmt"
5 "runtime"
6 "sync"
7)
8
9// Compress1X will compress the input.
10// The output can be decoded using Decompress1X.
11// Supply a Scratch object. The scratch object contains state about re-use,
12// So when sharing across independent encodes, be sure to set the re-use policy.
13func Compress1X(in []byte, s *Scratch) (out []byte, reUsed bool, err error) {
14 s, err = s.prepare(in)
15 if err != nil {
16 return nil, false, err
17 }
18 return compress(in, s, s.compress1X)
19}
20
21// Compress4X will compress the input. The input is split into 4 independent blocks
22// and compressed similar to Compress1X.
23// The output can be decoded using Decompress4X.
24// Supply a Scratch object. The scratch object contains state about re-use,
25// So when sharing across independent encodes, be sure to set the re-use policy.
26func Compress4X(in []byte, s *Scratch) (out []byte, reUsed bool, err error) {
27 s, err = s.prepare(in)
28 if err != nil {
29 return nil, false, err
30 }
31 if false {
32 // TODO: compress4Xp only slightly faster.
33 const parallelThreshold = 8 << 10
34 if len(in) < parallelThreshold || runtime.GOMAXPROCS(0) == 1 {
35 return compress(in, s, s.compress4X)
36 }
37 return compress(in, s, s.compress4Xp)
38 }
39 return compress(in, s, s.compress4X)
40}
41
// compress is the pipeline shared by Compress1X and Compress4X.
//
// It obtains a histogram (s.count, built here unless s.maxCount was supplied)
// and rejects input that is not worth Huffman coding. It then either encodes
// with the previous table (s.prevTable), according to s.Reuse, or builds a new
// table, writes it, and encodes with that.
//
// compressor encodes its argument using s.cTable / s.actualTableLog and
// returns the (grown) s.Out. compress1X and compress4X both append to s.Out.
//
// Returns:
//   - out: s.Out. When a new table is written, s.OutTable is the table part
//     and s.OutData the encoded part. When reUsed is true, s.OutData == s.Out.
//   - reUsed: true when the previous table was used.
//   - err: ErrUseRLE if all len(in) > 1 bytes are one symbol. ErrIncompressible
//     for single-byte input, too-flat distributions, or output not smaller than
//     wantSize. Otherwise, errors from buildCTable, cTable.write, or compressor.
//
// Side effects: resets s.maxCount to 0 and sets s.clearCount. On success with
// a new table, the new table becomes s.prevTable/s.prevTableLog.
func compress(in []byte, s *Scratch, compressor func(src []byte) ([]byte, error)) (out []byte, reUsed bool, err error) {
	// Nuke previous table if we cannot reuse anyway.
	if s.Reuse == ReusePolicyNone {
		s.prevTable = s.prevTable[:0]
	}

	// Create histogram, if none was provided.
	// canReuse reports whether every present symbol has a code in s.prevTable.
	maxCount := s.maxCount
	var canReuse = false
	if maxCount == 0 {
		maxCount, canReuse = s.countSimple(in)
	} else {
		canReuse = s.canUseTable(s.prevTable)
	}

	// We want the output size to be less than this:
	// len(in) minus len(in)>>WantLogLess, i.e. a minimum required saving.
	wantSize := len(in)
	if s.WantLogLess > 0 {
		wantSize -= wantSize >> s.WantLogLess
	}

	// Reset for next run.
	s.clearCount = true
	s.maxCount = 0
	if maxCount >= len(in) {
		if maxCount > len(in) {
			// Only possible with an inconsistent caller-supplied histogram.
			return nil, false, fmt.Errorf("maxCount (%d) > length (%d)", maxCount, len(in))
		}
		if len(in) == 1 {
			return nil, false, ErrIncompressible
		}
		// One symbol, use RLE
		return nil, false, ErrUseRLE
	}
	if maxCount == 1 || maxCount < (len(in)>>7) {
		// Each symbol present maximum once or too well distributed.
		return nil, false, ErrIncompressible
	}

	if s.Reuse == ReusePolicyPrefer && canReuse {
		// Temporarily swap in the previous table, encode, then restore.
		keepTable := s.cTable
		keepTL := s.actualTableLog
		s.cTable = s.prevTable
		s.actualTableLog = s.prevTableLog
		s.Out, err = compressor(in)
		s.cTable = keepTable
		s.actualTableLog = keepTL
		if err == nil && len(s.Out) < wantSize {
			s.OutData = s.Out
			return s.Out, true, nil
		}
		// Do not attempt to re-use later.
		// Falls through to building a fresh table.
		s.prevTable = s.prevTable[:0]
	}

	// Calculate new table.
	err = s.buildCTable()
	if err != nil {
		return nil, false, err
	}

	// Disabled debug check; flip the constant to validate generated tables.
	if false && !s.canUseTable(s.cTable) {
		panic("invalid table generated")
	}

	if s.Reuse == ReusePolicyAllow && canReuse {
		// Compare estimated sizes: old table vs. new table plus its header.
		// NOTE(review): hSize is read from len(s.Out) before any table has been
		// written in this call. It is presumably meant to be the header (table
		// description) size; confirm what s.Out holds here.
		hSize := len(s.Out)
		oldSize := s.prevTable.estimateSize(s.count[:s.symbolLen])
		newSize := s.cTable.estimateSize(s.count[:s.symbolLen])
		if oldSize <= hSize+newSize || hSize+12 >= wantSize {
			// Retain cTable even if we re-use.
			keepTable := s.cTable
			keepTL := s.actualTableLog

			s.cTable = s.prevTable
			s.actualTableLog = s.prevTableLog
			s.Out, err = compressor(in)

			// Restore ctable.
			s.cTable = keepTable
			s.actualTableLog = keepTL
			if err != nil {
				return nil, false, err
			}
			if len(s.Out) >= wantSize {
				return nil, false, ErrIncompressible
			}
			s.OutData = s.Out
			return s.Out, true, nil
		}
	}

	// Use new table
	// Serialize the table description; s.OutTable records s.Out right after.
	err = s.cTable.write(s)
	if err != nil {
		s.OutTable = nil
		return nil, false, err
	}
	s.OutTable = s.Out

	// Compress using new table
	// The compressor appends after the table bytes already in s.Out.
	s.Out, err = compressor(in)
	if err != nil {
		s.OutTable = nil
		return nil, false, err
	}
	if len(s.Out) >= wantSize {
		s.OutTable = nil
		return nil, false, ErrIncompressible
	}
	// Move current table into previous.
	// The old prevTable's storage is recycled as the next cTable.
	s.prevTable, s.prevTableLog, s.cTable = s.cTable, s.actualTableLog, s.prevTable[:0]
	s.OutData = s.Out[len(s.OutTable):]
	return s.Out, false, nil
}
157
158func (s *Scratch) compress1X(src []byte) ([]byte, error) {
159 return s.compress1xDo(s.Out, src)
160}
161
162func (s *Scratch) compress1xDo(dst, src []byte) ([]byte, error) {
163 var bw = bitWriter{out: dst}
164
165 // N is length divisible by 4.
166 n := len(src)
167 n -= n & 3
168 cTable := s.cTable[:256]
169
170 // Encode last bytes.
171 for i := len(src) & 3; i > 0; i-- {
172 bw.encSymbol(cTable, src[n+i-1])
173 }
174 n -= 4
175 if s.actualTableLog <= 8 {
176 for ; n >= 0; n -= 4 {
177 tmp := src[n : n+4]
178 // tmp should be len 4
179 bw.flush32()
180 bw.encTwoSymbols(cTable, tmp[3], tmp[2])
181 bw.encTwoSymbols(cTable, tmp[1], tmp[0])
182 }
183 } else {
184 for ; n >= 0; n -= 4 {
185 tmp := src[n : n+4]
186 // tmp should be len 4
187 bw.flush32()
188 bw.encTwoSymbols(cTable, tmp[3], tmp[2])
189 bw.flush32()
190 bw.encTwoSymbols(cTable, tmp[1], tmp[0])
191 }
192 }
193 err := bw.close()
194 return bw.out, err
195}
196
197var sixZeros [6]byte
198
199func (s *Scratch) compress4X(src []byte) ([]byte, error) {
200 if len(src) < 12 {
201 return nil, ErrIncompressible
202 }
203 segmentSize := (len(src) + 3) / 4
204
205 // Add placeholder for output length
206 offsetIdx := len(s.Out)
207 s.Out = append(s.Out, sixZeros[:]...)
208
209 for i := 0; i < 4; i++ {
210 toDo := src
211 if len(toDo) > segmentSize {
212 toDo = toDo[:segmentSize]
213 }
214 src = src[len(toDo):]
215
216 var err error
217 idx := len(s.Out)
218 s.Out, err = s.compress1xDo(s.Out, toDo)
219 if err != nil {
220 return nil, err
221 }
222 // Write compressed length as little endian before block.
223 if i < 3 {
224 // Last length is not written.
225 length := len(s.Out) - idx
226 s.Out[i*2+offsetIdx] = byte(length)
227 s.Out[i*2+offsetIdx+1] = byte(length >> 8)
228 }
229 }
230
231 return s.Out, nil
232}
233
234// compress4Xp will compress 4 streams using separate goroutines.
235func (s *Scratch) compress4Xp(src []byte) ([]byte, error) {
236 if len(src) < 12 {
237 return nil, ErrIncompressible
238 }
239 // Add placeholder for output length
240 s.Out = s.Out[:6]
241
242 segmentSize := (len(src) + 3) / 4
243 var wg sync.WaitGroup
244 var errs [4]error
245 wg.Add(4)
246 for i := 0; i < 4; i++ {
247 toDo := src
248 if len(toDo) > segmentSize {
249 toDo = toDo[:segmentSize]
250 }
251 src = src[len(toDo):]
252
253 // Separate goroutine for each block.
254 go func(i int) {
255 s.tmpOut[i], errs[i] = s.compress1xDo(s.tmpOut[i][:0], toDo)
256 wg.Done()
257 }(i)
258 }
259 wg.Wait()
260 for i := 0; i < 4; i++ {
261 if errs[i] != nil {
262 return nil, errs[i]
263 }
264 o := s.tmpOut[i]
265 // Write compressed length as little endian before block.
266 if i < 3 {
267 // Last length is not written.
268 s.Out[i*2] = byte(len(o))
269 s.Out[i*2+1] = byte(len(o) >> 8)
270 }
271
272 // Write output.
273 s.Out = append(s.Out, o...)
274 }
275 return s.Out, nil
276}
277
278// countSimple will create a simple histogram in s.count.
279// Returns the biggest count.
280// Does not update s.clearCount.
281func (s *Scratch) countSimple(in []byte) (max int, reuse bool) {
282 reuse = true
283 for _, v := range in {
284 s.count[v]++
285 }
286 m := uint32(0)
287 if len(s.prevTable) > 0 {
288 for i, v := range s.count[:] {
289 if v > m {
290 m = v
291 }
292 if v > 0 {
293 s.symbolLen = uint16(i) + 1
294 if i >= len(s.prevTable) {
295 reuse = false
296 } else {
297 if s.prevTable[i].nBits == 0 {
298 reuse = false
299 }
300 }
301 }
302 }
303 return int(m), reuse
304 }
305 for i, v := range s.count[:] {
306 if v > m {
307 m = v
308 }
309 if v > 0 {
310 s.symbolLen = uint16(i) + 1
311 }
312 }
313 return int(m), false
314}
315
316func (s *Scratch) canUseTable(c cTable) bool {
317 if len(c) < int(s.symbolLen) {
318 return false
319 }
320 for i, v := range s.count[:s.symbolLen] {
321 if v != 0 && c[i].nBits == 0 {
322 return false
323 }
324 }
325 return true
326}
327
328func (s *Scratch) validateTable(c cTable) bool {
329 if len(c) < int(s.symbolLen) {
330 return false
331 }
332 for i, v := range s.count[:s.symbolLen] {
333 if v != 0 {
334 if c[i].nBits == 0 {
335 return false
336 }
337 if c[i].nBits > s.actualTableLog {
338 return false
339 }
340 }
341 }
342 return true
343}
344
345// minTableLog provides the minimum logSize to safely represent a distribution.
346func (s *Scratch) minTableLog() uint8 {
347 minBitsSrc := highBit32(uint32(s.br.remain())) + 1
348 minBitsSymbols := highBit32(uint32(s.symbolLen-1)) + 2
349 if minBitsSrc < minBitsSymbols {
350 return uint8(minBitsSrc)
351 }
352 return uint8(minBitsSymbols)
353}
354
355// optimalTableLog calculates and sets the optimal tableLog in s.actualTableLog
356func (s *Scratch) optimalTableLog() {
357 tableLog := s.TableLog
358 minBits := s.minTableLog()
359 maxBitsSrc := uint8(highBit32(uint32(s.br.remain()-1))) - 1
360 if maxBitsSrc < tableLog {
361 // Accuracy can be reduced
362 tableLog = maxBitsSrc
363 }
364 if minBits > tableLog {
365 tableLog = minBits
366 }
367 // Need a minimum to safely represent all symbol values
368 if tableLog < minTablelog {
369 tableLog = minTablelog
370 }
371 if tableLog > tableLogMax {
372 tableLog = tableLogMax
373 }
374 s.actualTableLog = tableLog
375}
376
// cTableEntry is one symbol's code in an encoding table: val holds the code
// bits assigned by buildCTable, and nBits is the code length (0 means the
// symbol has no code).
type cTableEntry struct {
	val   uint16
	nBits uint8
	// We have 8 bits extra
	// (the 3 bytes of fields are presumably padded to 4, leaving one spare byte).
}

// huffNodesMask masks node indices into range. This assumes huffNodesLen is a
// power of two (TODO confirm where huffNodesLen is declared).
const huffNodesMask = huffNodesLen - 1
384
// buildCTable builds a new encoding table in s.cTable from the histogram in
// s.count[:s.symbolLen] and sets s.actualTableLog to its maximum code length.
//
// Steps:
//  1. choose the table log (optimalTableLog),
//  2. sort symbols by decreasing count (huffSort),
//  3. build the Huffman tree by repeatedly merging the two smallest nodes,
//  4. derive code lengths and limit them (setMaxHeight),
//  5. assign consecutive code values within each length, in symbol order.
//
// It returns an error only if the limited maximum length exceeds tableLogMax.
func (s *Scratch) buildCTable() error {
	s.optimalTableLog()
	s.huffSort()
	// Reuse s.cTable's storage when possible, zeroing the used entries.
	if cap(s.cTable) < maxSymbolValue+1 {
		s.cTable = make([]cTableEntry, s.symbolLen, maxSymbolValue+1)
	} else {
		s.cTable = s.cTable[:s.symbolLen]
		for i := range s.cTable {
			s.cTable[i] = cTableEntry{}
		}
	}

	// Leaves occupy huffNode[0:symbolLen] (sorted by huffSort). Internal nodes
	// are created from index startNode upward.
	var startNode = int16(s.symbolLen)
	nonNullRank := s.symbolLen - 1

	nodeNb := int16(startNode)
	huffNode := s.nodes[1 : huffNodesLen+1]

	// This overlays the slice above, but allows "-1" index lookups.
	// Different from reference implementation.
	// huffNode0[i+1] is the same element as huffNode[i].
	huffNode0 := s.nodes[0 : huffNodesLen+1]

	// Skip trailing zero-count symbols; nonNullRank is the last used leaf.
	for huffNode[nonNullRank].count == 0 {
		nonNullRank--
	}

	// lowS walks leaves from the smallest upward, lowN walks internal nodes.
	// The first internal node merges the two smallest leaves.
	lowS := int16(nonNullRank)
	nodeRoot := nodeNb + lowS - 1
	lowN := nodeNb
	huffNode[nodeNb].count = huffNode[lowS].count + huffNode[lowS-1].count
	huffNode[lowS].parent, huffNode[lowS-1].parent = uint16(nodeNb), uint16(nodeNb)
	nodeNb++
	lowS -= 2
	// Not-yet-built internal nodes get a large placeholder count, so lowN
	// never picks them before they exist.
	for n := nodeNb; n <= nodeRoot; n++ {
		huffNode[n].count = 1 << 30
	}
	// fake entry, strong barrier
	// huffNode0[0] (== huffNode[-1]) gets a count above the placeholders, so
	// once lowS runs past leaf 0, internal nodes are always chosen.
	huffNode0[0].count = 1 << 31

	// create parents
	// Each step merges the two smallest available nodes (leaf or internal).
	for nodeNb <= nodeRoot {
		var n1, n2 int16
		if huffNode0[lowS+1].count < huffNode0[lowN+1].count {
			n1 = lowS
			lowS--
		} else {
			n1 = lowN
			lowN++
		}
		if huffNode0[lowS+1].count < huffNode0[lowN+1].count {
			n2 = lowS
			lowS--
		} else {
			n2 = lowN
			lowN++
		}

		huffNode[nodeNb].count = huffNode0[n1+1].count + huffNode0[n2+1].count
		huffNode0[n1+1].parent, huffNode0[n2+1].parent = uint16(nodeNb), uint16(nodeNb)
		nodeNb++
	}

	// distribute weights (unlimited tree height)
	// Depth of each node = parent's depth + 1, computed top-down.
	huffNode[nodeRoot].nbBits = 0
	for n := nodeRoot - 1; n >= startNode; n-- {
		huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1
	}
	for n := uint16(0); n <= nonNullRank; n++ {
		huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1
	}
	// Limit code lengths to s.actualTableLog and record the resulting maximum.
	s.actualTableLog = s.setMaxHeight(int(nonNullRank))
	maxNbBits := s.actualTableLog

	// fill result into tree (val, nbBits)
	if maxNbBits > tableLogMax {
		return fmt.Errorf("internal error: maxNbBits (%d) > tableLogMax (%d)", maxNbBits, tableLogMax)
	}
	// nbPerRank[b] counts symbols with code length b.
	var nbPerRank [tableLogMax + 1]uint16
	var valPerRank [16]uint16
	for _, v := range huffNode[:nonNullRank+1] {
		nbPerRank[v.nbBits]++
	}
	// determine starting value per rank
	// Walks from the longest length down, halving the running value at each step.
	{
		min := uint16(0)
		for n := maxNbBits; n > 0; n-- {
			// get starting value within each rank
			valPerRank[n] = min
			min += nbPerRank[n]
			min >>= 1
		}
	}

	// push nbBits per symbol, symbol order
	for _, v := range huffNode[:nonNullRank+1] {
		s.cTable[v.symbol].nBits = v.nbBits
	}

	// assign value within rank, symbol order
	// Symbols of equal length get consecutive values in ascending symbol order.
	t := s.cTable[:s.symbolLen]
	for n, val := range t {
		nbits := val.nBits & 15
		v := valPerRank[nbits]
		t[n].val = v
		valPerRank[nbits] = v + 1
	}

	return nil
}
494
// huffSort will sort symbols, decreasing order.
//
// It fills s.nodes[1:] with one nodeElt per symbol in s.count[:s.symbolLen],
// ordered by decreasing count. Zero-count symbols are included and end up last.
// Symbols with equal counts keep ascending symbol order, because insertion
// only moves past strictly smaller counts.
func (s *Scratch) huffSort() {
	type rankPos struct {
		base    uint32
		current uint32
	}

	// Reslice s.nodes to its full working length (this does not clear it).
	// Index 0 is skipped here; buildCTable uses it as a sentinel slot.
	nodes := s.nodes[:huffNodesLen+1]
	s.nodes = nodes
	nodes = nodes[1 : huffNodesLen+1]

	// Sort into buckets based on length of symbol count.
	// Bucket b = highBit32(count+1), roughly log2 of the count.
	var rank [32]rankPos
	for _, v := range s.count[:s.symbolLen] {
		r := highBit32(v+1) & 31
		rank[r].base++
	}
	// maxBitLength is log2(BlockSizeMax) + 1
	// Suffix sum: rank[k].base becomes the number of symbols in buckets >= k
	// (assumes no bucket index exceeds maxBitLength).
	const maxBitLength = 18 + 1
	for n := maxBitLength; n > 0; n-- {
		rank[n-1].base += rank[n].base
	}
	for n := range rank[:maxBitLength] {
		rank[n].current = rank[n].base
	}
	// Bucket b's region starts at rank[b+1].base (the count of symbols in
	// higher buckets), so rank[b+1] serves as bucket b's cursor. Within a bucket,
	// do an insertion sort by count. The & huffNodesMask keeps indices in range,
	// including pos-1 when pos == 0.
	for n, c := range s.count[:s.symbolLen] {
		r := (highBit32(c+1) + 1) & 31
		pos := rank[r].current
		rank[r].current++
		prev := nodes[(pos-1)&huffNodesMask]
		for pos > rank[r].base && c > prev.count {
			nodes[pos&huffNodesMask] = prev
			pos--
			prev = nodes[(pos-1)&huffNodesMask]
		}
		nodes[pos&huffNodesMask] = nodeElt{count: c, symbol: byte(n)}
	}
	return
}
535
// setMaxHeight limits Huffman code lengths to s.actualTableLog bits. It
// appears to be a port of HUF_setMaxHeight from the zstd reference C code.
//
// Preconditions: s.nodes[1:] holds symbols sorted by decreasing count
// (huffSort), with nbBits set from the unrestricted tree. lastNonNull is the
// index of the last symbol with a non-zero count.
//
// Codes longer than maxNbBits are clamped, which overfills the Kraft sum. The
// excess ("cost", in units of 2^-maxNbBits) is then repaid by lengthening
// shorter codes. If repayment overshoots, some codes are shortened again.
//
// Returns the resulting maximum code length: largestBits if no clamping was
// needed, otherwise maxNbBits.
func (s *Scratch) setMaxHeight(lastNonNull int) uint8 {
	maxNbBits := s.actualTableLog
	huffNode := s.nodes[1 : huffNodesLen+1]
	//huffNode = huffNode[: huffNodesLen]

	// The last (least frequent) symbol has the longest code.
	largestBits := huffNode[lastNonNull].nbBits

	// early exit : no elt > maxNbBits
	if largestBits <= maxNbBits {
		return largestBits
	}
	totalCost := int(0)
	baseCost := int(1) << (largestBits - maxNbBits)
	n := uint32(lastNonNull)

	// Clamp every over-long code, accumulating the Kraft-sum excess scaled
	// by 2^largestBits.
	for huffNode[n].nbBits > maxNbBits {
		totalCost += baseCost - (1 << (largestBits - huffNode[n].nbBits))
		huffNode[n].nbBits = maxNbBits
		n--
	}
	// n stops at huffNode[n].nbBits <= maxNbBits

	for huffNode[n].nbBits == maxNbBits {
		n--
	}
	// n end at index of smallest symbol using < maxNbBits

	// renorm totalCost
	// Now in units of 2^-maxNbBits.
	totalCost >>= largestBits - maxNbBits /* note : totalCost is necessarily a multiple of baseCost */

	// repay normalized cost
	{
		const noSymbol = 0xF0F0F0F0
		// rankLast[k] is the position of the last (smallest-count) symbol whose
		// length is maxNbBits-k, or noSymbol if there is none.
		var rankLast [tableLogMax + 2]uint32

		for i := range rankLast[:] {
			rankLast[i] = noSymbol
		}

		// Get pos of last (smallest) symbol per rank
		{
			currentNbBits := uint8(maxNbBits)
			for pos := int(n); pos >= 0; pos-- {
				if huffNode[pos].nbBits >= currentNbBits {
					continue
				}
				currentNbBits = huffNode[pos].nbBits // < maxNbBits
				rankLast[maxNbBits-currentNbBits] = uint32(pos)
			}
		}

		// Lengthening a code in rank k by one bit repays 1<<(k-1) units.
		for totalCost > 0 {
			// Start with the largest repayment that does not exceed totalCost.
			nBitsToDecrease := uint8(highBit32(uint32(totalCost))) + 1

			// Step down to a smaller rank while lengthening two of its symbols
			// is cheaper (by count) than lengthening one symbol of the larger rank.
			for ; nBitsToDecrease > 1; nBitsToDecrease-- {
				highPos := rankLast[nBitsToDecrease]
				lowPos := rankLast[nBitsToDecrease-1]
				if highPos == noSymbol {
					continue
				}
				if lowPos == noSymbol {
					break
				}
				highTotal := huffNode[highPos].count
				lowTotal := 2 * huffNode[lowPos].count
				if highTotal <= lowTotal {
					break
				}
			}
			// only triggered when no more rank 1 symbol left => find closest one (note : there is necessarily at least one !)
			// The tableLogMax bound is inherited from the C reference (HUF_MAX_TABLELOG,
			// added there to quiet gcc 5+). Here it also keeps the rankLast index
			// within its tableLogMax+2 entries.
			// FIXME: try to remove
			for (nBitsToDecrease <= tableLogMax) && (rankLast[nBitsToDecrease] == noSymbol) {
				nBitsToDecrease++
			}
			totalCost -= 1 << (nBitsToDecrease - 1)
			if rankLast[nBitsToDecrease-1] == noSymbol {
				// this rank is no longer empty
				rankLast[nBitsToDecrease-1] = rankLast[nBitsToDecrease]
			}
			huffNode[rankLast[nBitsToDecrease]].nbBits++
			if rankLast[nBitsToDecrease] == 0 {
				/* special case, reached largest symbol */
				rankLast[nBitsToDecrease] = noSymbol
			} else {
				rankLast[nBitsToDecrease]--
				if huffNode[rankLast[nBitsToDecrease]].nbBits != maxNbBits-nBitsToDecrease {
					rankLast[nBitsToDecrease] = noSymbol /* this rank is now empty */
				}
			}
		}

		// Overshoot: shorten maxNbBits-long codes to maxNbBits-1, one unit at a time.
		for totalCost < 0 { /* Sometimes, cost correction overshoot */
			if rankLast[1] == noSymbol { /* special case : no rank 1 symbol (using maxNbBits-1); let's create one from largest rank 0 (using maxNbBits) */
				for huffNode[n].nbBits == maxNbBits {
					n--
				}
				huffNode[n+1].nbBits--
				rankLast[1] = n + 1
				totalCost++
				continue
			}
			huffNode[rankLast[1]+1].nbBits--
			rankLast[1]++
			totalCost++
		}
	}
	return maxNbBits
}
645
// nodeElt is a node of the Huffman tree built in s.nodes: either a symbol
// leaf (sorted by huffSort) or an internal node created by buildCTable.
type nodeElt struct {
	count  uint32 // frequency (leaf) or sum of children's counts (internal)
	parent uint16 // index of the parent node within s.nodes[1:]
	symbol byte   // symbol value; meaningful for leaves
	nbBits uint8  // depth in the tree, i.e. code length after setMaxHeight
}