package huff0
2
3import (
4 "fmt"
5 "runtime"
6 "sync"
7)
8
9// Compress1X will compress the input.
10// The output can be decoded using Decompress1X.
11// Supply a Scratch object. The scratch object contains state about re-use,
12// So when sharing across independent encodes, be sure to set the re-use policy.
13func Compress1X(in []byte, s *Scratch) (out []byte, reUsed bool, err error) {
14 s, err = s.prepare(in)
15 if err != nil {
16 return nil, false, err
17 }
18 return compress(in, s, s.compress1X)
19}
20
// Compress4X will compress the input. The input is split into 4 independent blocks
// and compressed similar to Compress1X.
// The output can be decoded using Decompress4X.
// Supply a Scratch object. The scratch object contains state about re-use,
// So when sharing across independent encodes, be sure to set the re-use policy.
func Compress4X(in []byte, s *Scratch) (out []byte, reUsed bool, err error) {
	s, err = s.prepare(in)
	if err != nil {
		return nil, false, err
	}
	if false {
		// Disabled parallel path: compress4Xp uses one goroutine per block
		// but measured only slightly faster, so it is kept dormant.
		// TODO: compress4Xp only slightly faster.
		const parallelThreshold = 8 << 10
		if len(in) < parallelThreshold || runtime.GOMAXPROCS(0) == 1 {
			return compress(in, s, s.compress4X)
		}
		return compress(in, s, s.compress4Xp)
	}
	return compress(in, s, s.compress4X)
}
41
// compress is the shared driver behind Compress1X/Compress4X. It builds or
// validates the histogram, applies the configured re-use policy
// (None/Prefer/Allow), builds a new Huffman table when needed and runs the
// supplied compressor (single- or four-stream variant).
// Returns ErrIncompressible/ErrUseRLE when Huffman coding cannot help.
func compress(in []byte, s *Scratch, compressor func(src []byte) ([]byte, error)) (out []byte, reUsed bool, err error) {
	// Nuke previous table if we cannot reuse anyway.
	if s.Reuse == ReusePolicyNone {
		s.prevTable = s.prevTable[:0]
	}

	// Create histogram, if none was provided.
	maxCount := s.maxCount
	var canReuse = false
	if maxCount == 0 {
		maxCount, canReuse = s.countSimple(in)
	} else {
		// Caller supplied a histogram; check it against the previous table.
		canReuse = s.canUseTable(s.prevTable)
	}

	// Reset for next run.
	s.clearCount = true
	s.maxCount = 0
	if maxCount >= len(in) {
		if maxCount > len(in) {
			// A single count exceeding the input length means the supplied
			// histogram cannot belong to this input.
			return nil, false, fmt.Errorf("maxCount (%d) > length (%d)", maxCount, len(in))
		}
		if len(in) == 1 {
			return nil, false, ErrIncompressible
		}
		// One symbol, use RLE
		return nil, false, ErrUseRLE
	}
	if maxCount == 1 || maxCount < (len(in)>>7) {
		// Each symbol present maximum once or too well distributed.
		return nil, false, ErrIncompressible
	}

	if s.Reuse == ReusePolicyPrefer && canReuse {
		// Try the previous table first; if it yields a gain we are done and
		// no table header needs to be emitted.
		keepTable := s.cTable
		s.cTable = s.prevTable
		s.Out, err = compressor(in)
		s.cTable = keepTable
		if err == nil && len(s.Out) < len(in) {
			s.OutData = s.Out
			return s.Out, true, nil
		}
		// Do not attempt to re-use later.
		s.prevTable = s.prevTable[:0]
	}

	// Calculate new table.
	s.optimalTableLog()
	err = s.buildCTable()
	if err != nil {
		return nil, false, err
	}

	if false && !s.canUseTable(s.cTable) {
		// Disabled sanity check for debugging table generation.
		panic("invalid table generated")
	}

	if s.Reuse == ReusePolicyAllow && canReuse {
		// Decide between the previous and the new table by comparing
		// estimated encoded sizes.
		hSize := len(s.Out)
		oldSize := s.prevTable.estimateSize(s.count[:s.symbolLen])
		newSize := s.cTable.estimateSize(s.count[:s.symbolLen])
		if oldSize <= hSize+newSize || hSize+12 >= len(in) {
			// Retain cTable even if we re-use.
			keepTable := s.cTable
			s.cTable = s.prevTable
			s.Out, err = compressor(in)
			s.cTable = keepTable
			if len(s.Out) >= len(in) {
				return nil, false, ErrIncompressible
			}
			s.OutData = s.Out
			return s.Out, true, nil
		}
	}

	// Use new table
	err = s.cTable.write(s)
	if err != nil {
		s.OutTable = nil
		return nil, false, err
	}
	s.OutTable = s.Out

	// Compress using new table
	s.Out, err = compressor(in)
	if err != nil {
		s.OutTable = nil
		return nil, false, err
	}
	if len(s.Out) >= len(in) {
		s.OutTable = nil
		return nil, false, ErrIncompressible
	}
	// Move current table into previous.
	s.prevTable, s.cTable = s.cTable, s.prevTable[:0]
	// OutData is the payload that follows the serialized table header.
	s.OutData = s.Out[len(s.OutTable):]
	return s.Out, false, nil
}
140
141func (s *Scratch) compress1X(src []byte) ([]byte, error) {
142 return s.compress1xDo(s.Out, src)
143}
144
145func (s *Scratch) compress1xDo(dst, src []byte) ([]byte, error) {
146 var bw = bitWriter{out: dst}
147
148 // N is length divisible by 4.
149 n := len(src)
150 n -= n & 3
151 cTable := s.cTable[:256]
152
153 // Encode last bytes.
154 for i := len(src) & 3; i > 0; i-- {
155 bw.encSymbol(cTable, src[n+i-1])
156 }
157 if s.actualTableLog <= 8 {
158 n -= 4
159 for ; n >= 0; n -= 4 {
160 tmp := src[n : n+4]
161 // tmp should be len 4
162 bw.flush32()
163 bw.encSymbol(cTable, tmp[3])
164 bw.encSymbol(cTable, tmp[2])
165 bw.encSymbol(cTable, tmp[1])
166 bw.encSymbol(cTable, tmp[0])
167 }
168 } else {
169 n -= 4
170 for ; n >= 0; n -= 4 {
171 tmp := src[n : n+4]
172 // tmp should be len 4
173 bw.flush32()
174 bw.encSymbol(cTable, tmp[3])
175 bw.encSymbol(cTable, tmp[2])
176 bw.flush32()
177 bw.encSymbol(cTable, tmp[1])
178 bw.encSymbol(cTable, tmp[0])
179 }
180 }
181 err := bw.close()
182 return bw.out, err
183}
184
185var sixZeros [6]byte
186
187func (s *Scratch) compress4X(src []byte) ([]byte, error) {
188 if len(src) < 12 {
189 return nil, ErrIncompressible
190 }
191 segmentSize := (len(src) + 3) / 4
192
193 // Add placeholder for output length
194 offsetIdx := len(s.Out)
195 s.Out = append(s.Out, sixZeros[:]...)
196
197 for i := 0; i < 4; i++ {
198 toDo := src
199 if len(toDo) > segmentSize {
200 toDo = toDo[:segmentSize]
201 }
202 src = src[len(toDo):]
203
204 var err error
205 idx := len(s.Out)
206 s.Out, err = s.compress1xDo(s.Out, toDo)
207 if err != nil {
208 return nil, err
209 }
210 // Write compressed length as little endian before block.
211 if i < 3 {
212 // Last length is not written.
213 length := len(s.Out) - idx
214 s.Out[i*2+offsetIdx] = byte(length)
215 s.Out[i*2+offsetIdx+1] = byte(length >> 8)
216 }
217 }
218
219 return s.Out, nil
220}
221
222// compress4Xp will compress 4 streams using separate goroutines.
223func (s *Scratch) compress4Xp(src []byte) ([]byte, error) {
224 if len(src) < 12 {
225 return nil, ErrIncompressible
226 }
227 // Add placeholder for output length
228 s.Out = s.Out[:6]
229
230 segmentSize := (len(src) + 3) / 4
231 var wg sync.WaitGroup
232 var errs [4]error
233 wg.Add(4)
234 for i := 0; i < 4; i++ {
235 toDo := src
236 if len(toDo) > segmentSize {
237 toDo = toDo[:segmentSize]
238 }
239 src = src[len(toDo):]
240
241 // Separate goroutine for each block.
242 go func(i int) {
243 s.tmpOut[i], errs[i] = s.compress1xDo(s.tmpOut[i][:0], toDo)
244 wg.Done()
245 }(i)
246 }
247 wg.Wait()
248 for i := 0; i < 4; i++ {
249 if errs[i] != nil {
250 return nil, errs[i]
251 }
252 o := s.tmpOut[i]
253 // Write compressed length as little endian before block.
254 if i < 3 {
255 // Last length is not written.
256 s.Out[i*2] = byte(len(o))
257 s.Out[i*2+1] = byte(len(o) >> 8)
258 }
259
260 // Write output.
261 s.Out = append(s.Out, o...)
262 }
263 return s.Out, nil
264}
265
266// countSimple will create a simple histogram in s.count.
267// Returns the biggest count.
268// Does not update s.clearCount.
269func (s *Scratch) countSimple(in []byte) (max int, reuse bool) {
270 reuse = true
271 for _, v := range in {
272 s.count[v]++
273 }
274 m := uint32(0)
275 if len(s.prevTable) > 0 {
276 for i, v := range s.count[:] {
277 if v > m {
278 m = v
279 }
280 if v > 0 {
281 s.symbolLen = uint16(i) + 1
282 if i >= len(s.prevTable) {
283 reuse = false
284 } else {
285 if s.prevTable[i].nBits == 0 {
286 reuse = false
287 }
288 }
289 }
290 }
291 return int(m), reuse
292 }
293 for i, v := range s.count[:] {
294 if v > m {
295 m = v
296 }
297 if v > 0 {
298 s.symbolLen = uint16(i) + 1
299 }
300 }
301 return int(m), false
302}
303
304func (s *Scratch) canUseTable(c cTable) bool {
305 if len(c) < int(s.symbolLen) {
306 return false
307 }
308 for i, v := range s.count[:s.symbolLen] {
309 if v != 0 && c[i].nBits == 0 {
310 return false
311 }
312 }
313 return true
314}
315
316// minTableLog provides the minimum logSize to safely represent a distribution.
317func (s *Scratch) minTableLog() uint8 {
318 minBitsSrc := highBit32(uint32(s.br.remain()-1)) + 1
319 minBitsSymbols := highBit32(uint32(s.symbolLen-1)) + 2
320 if minBitsSrc < minBitsSymbols {
321 return uint8(minBitsSrc)
322 }
323 return uint8(minBitsSymbols)
324}
325
326// optimalTableLog calculates and sets the optimal tableLog in s.actualTableLog
327func (s *Scratch) optimalTableLog() {
328 tableLog := s.TableLog
329 minBits := s.minTableLog()
330 maxBitsSrc := uint8(highBit32(uint32(s.br.remain()-1))) - 2
331 if maxBitsSrc < tableLog {
332 // Accuracy can be reduced
333 tableLog = maxBitsSrc
334 }
335 if minBits > tableLog {
336 tableLog = minBits
337 }
338 // Need a minimum to safely represent all symbol values
339 if tableLog < minTablelog {
340 tableLog = minTablelog
341 }
342 if tableLog > tableLogMax {
343 tableLog = tableLogMax
344 }
345 s.actualTableLog = tableLog
346}
347
// cTableEntry is one symbol's entry in the compression (encoding) table.
type cTableEntry struct {
	val   uint16 // canonical code value assigned within the symbol's rank
	nBits uint8  // code length in bits; 0 means the symbol has no code
	// We have 8 bits extra
}
353
354const huffNodesMask = huffNodesLen - 1
355
// buildCTable constructs the canonical Huffman code table (s.cTable) from the
// current histogram: sort symbols by count, build the tree bottom-up in
// s.nodes, cap the code lengths via setMaxHeight, then assign canonical code
// values per bit-length rank.
func (s *Scratch) buildCTable() error {
	s.huffSort()
	if cap(s.cTable) < maxSymbolValue+1 {
		s.cTable = make([]cTableEntry, s.symbolLen, maxSymbolValue+1)
	} else {
		// Reuse the previous allocation; zero the entries we will use.
		s.cTable = s.cTable[:s.symbolLen]
		for i := range s.cTable {
			s.cTable[i] = cTableEntry{}
		}
	}

	// Leaves occupy huffNode[:symbolLen]; internal nodes are placed after.
	var startNode = int16(s.symbolLen)
	nonNullRank := s.symbolLen - 1

	nodeNb := int16(startNode)
	huffNode := s.nodes[1 : huffNodesLen+1]

	// This overlays the slice above, but allows "-1" index lookups.
	// Different from reference implementation.
	huffNode0 := s.nodes[0 : huffNodesLen+1]

	// Skip trailing zero-count entries (symbols absent from the input).
	for huffNode[nonNullRank].count == 0 {
		nonNullRank--
	}

	// Merge the two smallest leaves into the first internal node.
	lowS := int16(nonNullRank)
	nodeRoot := nodeNb + lowS - 1
	lowN := nodeNb
	huffNode[nodeNb].count = huffNode[lowS].count + huffNode[lowS-1].count
	huffNode[lowS].parent, huffNode[lowS-1].parent = uint16(nodeNb), uint16(nodeNb)
	nodeNb++
	lowS -= 2
	// Sentinel counts so unbuilt internal nodes never compare as smallest.
	for n := nodeNb; n <= nodeRoot; n++ {
		huffNode[n].count = 1 << 30
	}
	// fake entry, strong barrier
	huffNode0[0].count = 1 << 31

	// create parents
	for nodeNb <= nodeRoot {
		// Pick the two smallest remaining nodes (leaf or internal) and merge
		// them into a new internal node.
		var n1, n2 int16
		if huffNode0[lowS+1].count < huffNode0[lowN+1].count {
			n1 = lowS
			lowS--
		} else {
			n1 = lowN
			lowN++
		}
		if huffNode0[lowS+1].count < huffNode0[lowN+1].count {
			n2 = lowS
			lowS--
		} else {
			n2 = lowN
			lowN++
		}

		huffNode[nodeNb].count = huffNode0[n1+1].count + huffNode0[n2+1].count
		huffNode0[n1+1].parent, huffNode0[n2+1].parent = uint16(nodeNb), uint16(nodeNb)
		nodeNb++
	}

	// distribute weights (unlimited tree height)
	huffNode[nodeRoot].nbBits = 0
	for n := nodeRoot - 1; n >= startNode; n-- {
		huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1
	}
	for n := uint16(0); n <= nonNullRank; n++ {
		huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1
	}
	// Enforce the maximum code length, redistributing bits as needed.
	s.actualTableLog = s.setMaxHeight(int(nonNullRank))
	maxNbBits := s.actualTableLog

	// fill result into tree (val, nbBits)
	if maxNbBits > tableLogMax {
		return fmt.Errorf("internal error: maxNbBits (%d) > tableLogMax (%d)", maxNbBits, tableLogMax)
	}
	var nbPerRank [tableLogMax + 1]uint16
	var valPerRank [tableLogMax + 1]uint16
	for _, v := range huffNode[:nonNullRank+1] {
		nbPerRank[v.nbBits]++
	}
	// determine starting value per rank
	{
		min := uint16(0)
		for n := maxNbBits; n > 0; n-- {
			// get starting value within each rank
			valPerRank[n] = min
			min += nbPerRank[n]
			min >>= 1
		}
	}

	// push nbBits per symbol, symbol order
	// TODO: changed `s.symbolLen` -> `nonNullRank+1` (micro-opt)
	for _, v := range huffNode[:nonNullRank+1] {
		s.cTable[v.symbol].nBits = v.nbBits
	}

	// assign value within rank, symbol order
	for n, val := range s.cTable[:s.symbolLen] {
		v := valPerRank[val.nBits]
		s.cTable[n].val = v
		valPerRank[val.nBits] = v + 1
	}

	return nil
}
463
// huffSort will sort symbols, decreasing order.
// Symbols are bucketed by log2 of their count, then insertion-sorted within
// each bucket, leaving s.nodes[1:] ordered by decreasing count.
func (s *Scratch) huffSort() {
	type rankPos struct {
		base    uint32 // first slot of this bucket in the sorted output
		current uint32 // next free slot within this bucket
	}

	// Clear nodes
	nodes := s.nodes[:huffNodesLen+1]
	s.nodes = nodes
	nodes = nodes[1 : huffNodesLen+1]

	// Sort into buckets based on length of symbol count.
	var rank [32]rankPos
	for _, v := range s.count[:s.symbolLen] {
		r := highBit32(v+1) & 31
		rank[r].base++
	}
	// Accumulate from the top so higher-count buckets come first.
	for n := 30; n > 0; n-- {
		rank[n-1].base += rank[n].base
	}
	for n := range rank[:] {
		rank[n].current = rank[n].base
	}
	for n, c := range s.count[:s.symbolLen] {
		// Place symbol n in the next slot of its bucket (note the +1: the
		// bucket's start is the cumulative total of all strictly higher
		// buckets), then bubble it up while it outranks the previous entry.
		r := (highBit32(c+1) + 1) & 31
		pos := rank[r].current
		rank[r].current++
		prev := nodes[(pos-1)&huffNodesMask]
		for pos > rank[r].base && c > prev.count {
			nodes[pos&huffNodesMask] = prev
			pos--
			prev = nodes[(pos-1)&huffNodesMask]
		}
		nodes[pos&huffNodesMask] = nodeElt{count: c, symbol: byte(n)}
	}
	return
}
502
// setMaxHeight limits the code lengths in s.nodes to at most s.TableLog bits.
// Overlong codes are truncated to maxNbBits; the resulting "cost" debt is
// then repaid by lengthening the codes of the cheapest suitable symbols.
// lastNonNull is the index of the last used symbol. Returns the resulting
// maximum bit length.
func (s *Scratch) setMaxHeight(lastNonNull int) uint8 {
	maxNbBits := s.TableLog
	huffNode := s.nodes[1 : huffNodesLen+1]
	//huffNode = huffNode[: huffNodesLen]

	largestBits := huffNode[lastNonNull].nbBits

	// early exit : no elt > maxNbBits
	if largestBits <= maxNbBits {
		return largestBits
	}
	totalCost := int(0)
	baseCost := int(1) << (largestBits - maxNbBits)
	n := uint32(lastNonNull)

	// Truncate all codes longer than maxNbBits, accumulating the cost
	// (in units of 1/2^largestBits) that must be repaid elsewhere.
	for huffNode[n].nbBits > maxNbBits {
		totalCost += baseCost - (1 << (largestBits - huffNode[n].nbBits))
		huffNode[n].nbBits = maxNbBits
		n--
	}
	// n stops at huffNode[n].nbBits <= maxNbBits

	for huffNode[n].nbBits == maxNbBits {
		n--
	}
	// n end at index of smallest symbol using < maxNbBits

	// renorm totalCost
	totalCost >>= largestBits - maxNbBits /* note : totalCost is necessarily a multiple of baseCost */

	// repay normalized cost
	{
		const noSymbol = 0xF0F0F0F0
		var rankLast [tableLogMax + 2]uint32

		for i := range rankLast[:] {
			rankLast[i] = noSymbol
		}

		// Get pos of last (smallest) symbol per rank
		{
			currentNbBits := uint8(maxNbBits)
			for pos := int(n); pos >= 0; pos-- {
				if huffNode[pos].nbBits >= currentNbBits {
					continue
				}
				currentNbBits = huffNode[pos].nbBits // < maxNbBits
				rankLast[maxNbBits-currentNbBits] = uint32(pos)
			}
		}

		for totalCost > 0 {
			// Start from the largest rank that could repay the remaining
			// cost in one step, then search downward for a usable rank.
			nBitsToDecrease := uint8(highBit32(uint32(totalCost))) + 1

			for ; nBitsToDecrease > 1; nBitsToDecrease-- {
				highPos := rankLast[nBitsToDecrease]
				lowPos := rankLast[nBitsToDecrease-1]
				if highPos == noSymbol {
					continue
				}
				if lowPos == noSymbol {
					break
				}
				highTotal := huffNode[highPos].count
				lowTotal := 2 * huffNode[lowPos].count
				if highTotal <= lowTotal {
					break
				}
			}
			// only triggered when no more rank 1 symbol left => find closest one (note : there is necessarily at least one !)
			// HUF_MAX_TABLELOG test just to please gcc 5+; but it should not be necessary
			// FIXME: try to remove
			for (nBitsToDecrease <= tableLogMax) && (rankLast[nBitsToDecrease] == noSymbol) {
				nBitsToDecrease++
			}
			totalCost -= 1 << (nBitsToDecrease - 1)
			if rankLast[nBitsToDecrease-1] == noSymbol {
				// this rank is no longer empty
				rankLast[nBitsToDecrease-1] = rankLast[nBitsToDecrease]
			}
			// Lengthen the chosen symbol's code by one bit, repaying cost.
			huffNode[rankLast[nBitsToDecrease]].nbBits++
			if rankLast[nBitsToDecrease] == 0 {
				/* special case, reached largest symbol */
				rankLast[nBitsToDecrease] = noSymbol
			} else {
				rankLast[nBitsToDecrease]--
				if huffNode[rankLast[nBitsToDecrease]].nbBits != maxNbBits-nBitsToDecrease {
					rankLast[nBitsToDecrease] = noSymbol /* this rank is now empty */
				}
			}
		}

		for totalCost < 0 { /* Sometimes, cost correction overshoot */
			if rankLast[1] == noSymbol { /* special case : no rank 1 symbol (using maxNbBits-1); let's create one from largest rank 0 (using maxNbBits) */
				for huffNode[n].nbBits == maxNbBits {
					n--
				}
				huffNode[n+1].nbBits--
				rankLast[1] = n + 1
				totalCost++
				continue
			}
			huffNode[rankLast[1]+1].nbBits--
			rankLast[1]++
			totalCost++
		}
	}
	return maxNbBits
}
612
// nodeElt is a node in the Huffman construction tree held in Scratch.nodes.
type nodeElt struct {
	count  uint32 // symbol frequency; for internal nodes, sum of children
	parent uint16 // index of the parent node within the huffNode slice
	symbol byte   // the symbol value (meaningful for leaf nodes)
	nbBits uint8  // assigned code length in bits
}