blob: 41703bba4d65bcd80a68efb9c4c0b8657cc1cce2 [file] [log] [blame]
khenaidoo7d3c5582021-08-11 18:09:44 -04001package huff0
2
3import (
4 "errors"
5 "fmt"
6 "io"
7
8 "github.com/klauspost/compress/fse"
9)
10
// dTable groups the lookup tables used when decoding.
type dTable struct {
	// single is the single-symbol decoding table. ReadTable always sizes it
	// to 1<<tableLogMax entries and fills it from the symbol weights.
	single []dEntrySingle
	// double is intended for double-symbol decoding; none of the decoders
	// visible in this file read or write it.
	double []dEntryDouble
}
15
// dEntrySingle is one entry of the single-symbol decoding table.
type dEntrySingle struct {
	// entry packs two values as written by ReadTable:
	// the low 8 bits hold the number of bits the code consumes and
	// the high 8 bits hold the decoded symbol.
	entry uint16
}
20
// dEntryDouble is one entry of the double-symbol decoding table.
// NOTE(review): no code visible in this file fills or reads these fields,
// so the meanings below are inferred from the names only - confirm before relying on them.
type dEntryDouble struct {
	seq   uint16 // presumably the decoded symbol pair
	nBits uint8  // presumably the number of bits consumed
	len   uint8  // presumably the number of symbols held in seq
}
27
// use8BitTables enables the specialised byte-oriented decoders.
// When true, Decompress1X and Decompress4X dispatch to the 8 bit variants
// for every table whose table log is <= 8 bits.
const use8BitTables = true
30
31// ReadTable will read a table from the input.
32// The size of the input may be larger than the table definition.
33// Any content remaining after the table definition will be returned.
34// If no Scratch is provided a new one is allocated.
35// The returned Scratch can be used for encoding or decoding input using this table.
36func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) {
37 s, err = s.prepare(in)
38 if err != nil {
39 return s, nil, err
40 }
41 if len(in) <= 1 {
42 return s, nil, errors.New("input too small for table")
43 }
44 iSize := in[0]
45 in = in[1:]
46 if iSize >= 128 {
47 // Uncompressed
48 oSize := iSize - 127
49 iSize = (oSize + 1) / 2
50 if int(iSize) > len(in) {
51 return s, nil, errors.New("input too small for table")
52 }
53 for n := uint8(0); n < oSize; n += 2 {
54 v := in[n/2]
55 s.huffWeight[n] = v >> 4
56 s.huffWeight[n+1] = v & 15
57 }
58 s.symbolLen = uint16(oSize)
59 in = in[iSize:]
60 } else {
61 if len(in) < int(iSize) {
62 return s, nil, fmt.Errorf("input too small for table, want %d bytes, have %d", iSize, len(in))
63 }
64 // FSE compressed weights
65 s.fse.DecompressLimit = 255
66 hw := s.huffWeight[:]
67 s.fse.Out = hw
68 b, err := fse.Decompress(in[:iSize], s.fse)
69 s.fse.Out = nil
70 if err != nil {
71 return s, nil, err
72 }
73 if len(b) > 255 {
74 return s, nil, errors.New("corrupt input: output table too large")
75 }
76 s.symbolLen = uint16(len(b))
77 in = in[iSize:]
78 }
79
80 // collect weight stats
81 var rankStats [16]uint32
82 weightTotal := uint32(0)
83 for _, v := range s.huffWeight[:s.symbolLen] {
84 if v > tableLogMax {
85 return s, nil, errors.New("corrupt input: weight too large")
86 }
87 v2 := v & 15
88 rankStats[v2]++
89 // (1 << (v2-1)) is slower since the compiler cannot prove that v2 isn't 0.
90 weightTotal += (1 << v2) >> 1
91 }
92 if weightTotal == 0 {
93 return s, nil, errors.New("corrupt input: weights zero")
94 }
95
96 // get last non-null symbol weight (implied, total must be 2^n)
97 {
98 tableLog := highBit32(weightTotal) + 1
99 if tableLog > tableLogMax {
100 return s, nil, errors.New("corrupt input: tableLog too big")
101 }
102 s.actualTableLog = uint8(tableLog)
103 // determine last weight
104 {
105 total := uint32(1) << tableLog
106 rest := total - weightTotal
107 verif := uint32(1) << highBit32(rest)
108 lastWeight := highBit32(rest) + 1
109 if verif != rest {
110 // last value must be a clean power of 2
111 return s, nil, errors.New("corrupt input: last value not power of two")
112 }
113 s.huffWeight[s.symbolLen] = uint8(lastWeight)
114 s.symbolLen++
115 rankStats[lastWeight]++
116 }
117 }
118
119 if (rankStats[1] < 2) || (rankStats[1]&1 != 0) {
120 // by construction : at least 2 elts of rank 1, must be even
121 return s, nil, errors.New("corrupt input: min elt size, even check failed ")
122 }
123
124 // TODO: Choose between single/double symbol decoding
125
126 // Calculate starting value for each rank
127 {
128 var nextRankStart uint32
129 for n := uint8(1); n < s.actualTableLog+1; n++ {
130 current := nextRankStart
131 nextRankStart += rankStats[n] << (n - 1)
132 rankStats[n] = current
133 }
134 }
135
136 // fill DTable (always full size)
137 tSize := 1 << tableLogMax
138 if len(s.dt.single) != tSize {
139 s.dt.single = make([]dEntrySingle, tSize)
140 }
141 cTable := s.prevTable
142 if cap(cTable) < maxSymbolValue+1 {
143 cTable = make([]cTableEntry, 0, maxSymbolValue+1)
144 }
145 cTable = cTable[:maxSymbolValue+1]
146 s.prevTable = cTable[:s.symbolLen]
147 s.prevTableLog = s.actualTableLog
148
149 for n, w := range s.huffWeight[:s.symbolLen] {
150 if w == 0 {
151 cTable[n] = cTableEntry{
152 val: 0,
153 nBits: 0,
154 }
155 continue
156 }
157 length := (uint32(1) << w) >> 1
158 d := dEntrySingle{
159 entry: uint16(s.actualTableLog+1-w) | (uint16(n) << 8),
160 }
161
162 rank := &rankStats[w]
163 cTable[n] = cTableEntry{
164 val: uint16(*rank >> (w - 1)),
165 nBits: uint8(d.entry),
166 }
167
168 single := s.dt.single[*rank : *rank+length]
169 for i := range single {
170 single[i] = d
171 }
172 *rank += length
173 }
174
175 return s, in, nil
176}
177
178// Decompress1X will decompress a 1X encoded stream.
179// The length of the supplied input must match the end of a block exactly.
180// Before this is called, the table must be initialized with ReadTable unless
181// the encoder re-used the table.
182// deprecated: Use the stateless Decoder() to get a concurrent version.
183func (s *Scratch) Decompress1X(in []byte) (out []byte, err error) {
184 if cap(s.Out) < s.MaxDecodedSize {
185 s.Out = make([]byte, s.MaxDecodedSize)
186 }
187 s.Out = s.Out[:0:s.MaxDecodedSize]
188 s.Out, err = s.Decoder().Decompress1X(s.Out, in)
189 return s.Out, err
190}
191
192// Decompress4X will decompress a 4X encoded stream.
193// Before this is called, the table must be initialized with ReadTable unless
194// the encoder re-used the table.
195// The length of the supplied input must match the end of a block exactly.
196// The destination size of the uncompressed data must be known and provided.
197// deprecated: Use the stateless Decoder() to get a concurrent version.
198func (s *Scratch) Decompress4X(in []byte, dstSize int) (out []byte, err error) {
199 if dstSize > s.MaxDecodedSize {
200 return nil, ErrMaxDecodedSizeExceeded
201 }
202 if cap(s.Out) < dstSize {
203 s.Out = make([]byte, s.MaxDecodedSize)
204 }
205 s.Out = s.Out[:0:dstSize]
206 s.Out, err = s.Decoder().Decompress4X(s.Out, in)
207 return s.Out, err
208}
209
210// Decoder will return a stateless decoder that can be used by multiple
211// decompressors concurrently.
212// Before this is called, the table must be initialized with ReadTable.
213// The Decoder is still linked to the scratch buffer so that cannot be reused.
214// However, it is safe to discard the scratch.
215func (s *Scratch) Decoder() *Decoder {
216 return &Decoder{
217 dt: s.dt,
218 actualTableLog: s.actualTableLog,
219 }
220}
221
// Decoder provides stateless decoding.
// It is created by Scratch.Decoder, which copies both fields from the Scratch.
type Decoder struct {
	// dt holds the decoding tables. The slices inside share their backing
	// storage with the Scratch the Decoder was created from.
	dt dTable
	// actualTableLog is the table log of the table in dt; it selects between
	// the generic and the 8 bit decoders.
	actualTableLog uint8
}
227
228// Decompress1X will decompress a 1X encoded stream.
229// The cap of the output buffer will be the maximum decompressed size.
230// The length of the supplied input must match the end of a block exactly.
231func (d *Decoder) Decompress1X(dst, src []byte) ([]byte, error) {
232 if len(d.dt.single) == 0 {
233 return nil, errors.New("no table loaded")
234 }
235 if use8BitTables && d.actualTableLog <= 8 {
236 return d.decompress1X8Bit(dst, src)
237 }
238 var br bitReaderShifted
239 err := br.init(src)
240 if err != nil {
241 return dst, err
242 }
243 maxDecodedSize := cap(dst)
244 dst = dst[:0]
245
246 // Avoid bounds check by always having full sized table.
247 const tlSize = 1 << tableLogMax
248 const tlMask = tlSize - 1
249 dt := d.dt.single[:tlSize]
250
251 // Use temp table to avoid bound checks/append penalty.
252 var buf [256]byte
253 var off uint8
254
255 for br.off >= 8 {
256 br.fillFast()
257 v := dt[br.peekBitsFast(d.actualTableLog)&tlMask]
258 br.advance(uint8(v.entry))
259 buf[off+0] = uint8(v.entry >> 8)
260
261 v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
262 br.advance(uint8(v.entry))
263 buf[off+1] = uint8(v.entry >> 8)
264
265 // Refill
266 br.fillFast()
267
268 v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
269 br.advance(uint8(v.entry))
270 buf[off+2] = uint8(v.entry >> 8)
271
272 v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
273 br.advance(uint8(v.entry))
274 buf[off+3] = uint8(v.entry >> 8)
275
276 off += 4
277 if off == 0 {
278 if len(dst)+256 > maxDecodedSize {
279 br.close()
280 return nil, ErrMaxDecodedSizeExceeded
281 }
282 dst = append(dst, buf[:]...)
283 }
284 }
285
286 if len(dst)+int(off) > maxDecodedSize {
287 br.close()
288 return nil, ErrMaxDecodedSizeExceeded
289 }
290 dst = append(dst, buf[:off]...)
291
292 // br < 8, so uint8 is fine
293 bitsLeft := uint8(br.off)*8 + 64 - br.bitsRead
294 for bitsLeft > 0 {
295 br.fill()
296 if false && br.bitsRead >= 32 {
297 if br.off >= 4 {
298 v := br.in[br.off-4:]
299 v = v[:4]
300 low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
301 br.value = (br.value << 32) | uint64(low)
302 br.bitsRead -= 32
303 br.off -= 4
304 } else {
305 for br.off > 0 {
306 br.value = (br.value << 8) | uint64(br.in[br.off-1])
307 br.bitsRead -= 8
308 br.off--
309 }
310 }
311 }
312 if len(dst) >= maxDecodedSize {
313 br.close()
314 return nil, ErrMaxDecodedSizeExceeded
315 }
316 v := d.dt.single[br.peekBitsFast(d.actualTableLog)&tlMask]
317 nBits := uint8(v.entry)
318 br.advance(nBits)
319 bitsLeft -= nBits
320 dst = append(dst, uint8(v.entry>>8))
321 }
322 return dst, br.close()
323}
324
// decompress1X8Bit will decompress a 1X encoded stream with tablelog <= 8.
// A table log of exactly 8 is forwarded to decompress1X8BitExactly.
// The cap of the output buffer will be the maximum decompressed size.
// The length of the supplied input must match the end of a block exactly.
//
// The decoded bytes are appended to dst[:0]. ErrMaxDecodedSizeExceeded is
// returned when the output would grow beyond cap(dst).
func (d *Decoder) decompress1X8Bit(dst, src []byte) ([]byte, error) {
	if d.actualTableLog == 8 {
		return d.decompress1X8BitExactly(dst, src)
	}
	var br bitReaderBytes
	err := br.init(src)
	if err != nil {
		return dst, err
	}
	maxDecodedSize := cap(dst)
	dst = dst[:0]

	// Avoid bounds check by always having full sized table.
	// A byte index can never exceed 255, so 256 entries are enough.
	dt := d.dt.single[:256]

	// Use temp table to avoid bound checks/append penalty.
	// off is a uint8, so it wraps to 0 exactly when all 256 bytes are filled.
	var buf [256]byte
	var off uint8

	// Only the top actualTableLog bits of the peeked byte index the table;
	// the shift drops the rest. The mask keeps the shift amount within 0..7.
	shift := (8 - d.actualTableLog) & 7

	//fmt.Printf("mask: %b, tl:%d\n", mask, d.actualTableLog)
	// Fast path: one refill followed by four symbols per iteration while at
	// least 4 input bytes are left.
	for br.off >= 4 {
		br.fillFast()
		v := dt[br.peekByteFast()>>shift]
		br.advance(uint8(v.entry))
		buf[off+0] = uint8(v.entry >> 8)

		v = dt[br.peekByteFast()>>shift]
		br.advance(uint8(v.entry))
		buf[off+1] = uint8(v.entry >> 8)

		v = dt[br.peekByteFast()>>shift]
		br.advance(uint8(v.entry))
		buf[off+2] = uint8(v.entry >> 8)

		v = dt[br.peekByteFast()>>shift]
		br.advance(uint8(v.entry))
		buf[off+3] = uint8(v.entry >> 8)

		off += 4
		if off == 0 {
			// The temp table is full, flush all of it.
			if len(dst)+256 > maxDecodedSize {
				br.close()
				return nil, ErrMaxDecodedSizeExceeded
			}
			dst = append(dst, buf[:]...)
		}
	}

	// Flush what is left in the temp table.
	if len(dst)+int(off) > maxDecodedSize {
		br.close()
		return nil, ErrMaxDecodedSizeExceeded
	}
	dst = append(dst, buf[:off]...)

	// br < 4, so uint8 is fine
	// The counter is signed so that consuming more bits than remain ends the loop.
	bitsLeft := int8(uint8(br.off)*8 + (64 - br.bitsRead))
	for bitsLeft > 0 {
		// When fewer than 8 unread bits are buffered, pull in every remaining input byte.
		if br.bitsRead >= 64-8 {
			for br.off > 0 {
				br.value |= uint64(br.in[br.off-1]) << (br.bitsRead - 8)
				br.bitsRead -= 8
				br.off--
			}
		}
		if len(dst) >= maxDecodedSize {
			br.close()
			return nil, ErrMaxDecodedSizeExceeded
		}
		v := dt[br.peekByteFast()>>shift]
		nBits := uint8(v.entry)
		br.advance(nBits)
		bitsLeft -= int8(nBits)
		dst = append(dst, uint8(v.entry>>8))
	}
	return dst, br.close()
}
406
// decompress1X8BitExactly will decompress a 1X encoded stream with tablelog == 8.
// It is reached from decompress1X8Bit, which checks the table log.
// The cap of the output buffer will be the maximum decompressed size.
// The length of the supplied input must match the end of a block exactly.
//
// The decoded bytes are appended to dst[:0]. ErrMaxDecodedSizeExceeded is
// returned when the output would grow beyond cap(dst).
func (d *Decoder) decompress1X8BitExactly(dst, src []byte) ([]byte, error) {
	var br bitReaderBytes
	err := br.init(src)
	if err != nil {
		return dst, err
	}
	maxDecodedSize := cap(dst)
	dst = dst[:0]

	// Avoid bounds check by always having full sized table.
	// A byte index can never exceed 255, so 256 entries are enough.
	dt := d.dt.single[:256]

	// Use temp table to avoid bound checks/append penalty.
	// off is a uint8, so it wraps to 0 exactly when all 256 bytes are filled.
	var buf [256]byte
	var off uint8

	// With a table log of 8 the whole peeked byte is the table index.
	const shift = 0

	//fmt.Printf("mask: %b, tl:%d\n", mask, d.actualTableLog)
	// Fast path: one refill followed by four symbols per iteration while at
	// least 4 input bytes are left.
	for br.off >= 4 {
		br.fillFast()
		v := dt[br.peekByteFast()>>shift]
		br.advance(uint8(v.entry))
		buf[off+0] = uint8(v.entry >> 8)

		v = dt[br.peekByteFast()>>shift]
		br.advance(uint8(v.entry))
		buf[off+1] = uint8(v.entry >> 8)

		v = dt[br.peekByteFast()>>shift]
		br.advance(uint8(v.entry))
		buf[off+2] = uint8(v.entry >> 8)

		v = dt[br.peekByteFast()>>shift]
		br.advance(uint8(v.entry))
		buf[off+3] = uint8(v.entry >> 8)

		off += 4
		if off == 0 {
			// The temp table is full, flush all of it.
			if len(dst)+256 > maxDecodedSize {
				br.close()
				return nil, ErrMaxDecodedSizeExceeded
			}
			dst = append(dst, buf[:]...)
		}
	}

	// Flush what is left in the temp table.
	if len(dst)+int(off) > maxDecodedSize {
		br.close()
		return nil, ErrMaxDecodedSizeExceeded
	}
	dst = append(dst, buf[:off]...)

	// br < 4, so uint8 is fine
	// The counter is signed so that consuming more bits than remain ends the loop.
	bitsLeft := int8(uint8(br.off)*8 + (64 - br.bitsRead))
	for bitsLeft > 0 {
		// When fewer than 8 unread bits are buffered, pull in every remaining input byte.
		if br.bitsRead >= 64-8 {
			for br.off > 0 {
				br.value |= uint64(br.in[br.off-1]) << (br.bitsRead - 8)
				br.bitsRead -= 8
				br.off--
			}
		}
		if len(dst) >= maxDecodedSize {
			br.close()
			return nil, ErrMaxDecodedSizeExceeded
		}
		v := dt[br.peekByteFast()>>shift]
		nBits := uint8(v.entry)
		br.advance(nBits)
		bitsLeft -= int8(nBits)
		dst = append(dst, uint8(v.entry>>8))
	}
	return dst, br.close()
}
485
486// Decompress4X will decompress a 4X encoded stream.
487// The length of the supplied input must match the end of a block exactly.
488// The *capacity* of the dst slice must match the destination size of
489// the uncompressed data exactly.
490func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
491 if len(d.dt.single) == 0 {
492 return nil, errors.New("no table loaded")
493 }
494 if len(src) < 6+(4*1) {
495 return nil, errors.New("input too small")
496 }
497 if use8BitTables && d.actualTableLog <= 8 {
498 return d.decompress4X8bit(dst, src)
499 }
500
501 var br [4]bitReaderShifted
502 start := 6
503 for i := 0; i < 3; i++ {
504 length := int(src[i*2]) | (int(src[i*2+1]) << 8)
505 if start+length >= len(src) {
506 return nil, errors.New("truncated input (or invalid offset)")
507 }
508 err := br[i].init(src[start : start+length])
509 if err != nil {
510 return nil, err
511 }
512 start += length
513 }
514 err := br[3].init(src[start:])
515 if err != nil {
516 return nil, err
517 }
518
519 // destination, offset to match first output
520 dstSize := cap(dst)
521 dst = dst[:dstSize]
522 out := dst
523 dstEvery := (dstSize + 3) / 4
524
525 const tlSize = 1 << tableLogMax
526 const tlMask = tlSize - 1
527 single := d.dt.single[:tlSize]
528
529 // Use temp table to avoid bound checks/append penalty.
530 var buf [256]byte
531 var off uint8
532 var decoded int
533
534 // Decode 2 values from each decoder/loop.
535 const bufoff = 256 / 4
536 for {
537 if br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4 {
538 break
539 }
540
541 {
542 const stream = 0
543 const stream2 = 1
544 br[stream].fillFast()
545 br[stream2].fillFast()
546
547 val := br[stream].peekBitsFast(d.actualTableLog)
548 v := single[val&tlMask]
549 br[stream].advance(uint8(v.entry))
550 buf[off+bufoff*stream] = uint8(v.entry >> 8)
551
552 val2 := br[stream2].peekBitsFast(d.actualTableLog)
553 v2 := single[val2&tlMask]
554 br[stream2].advance(uint8(v2.entry))
555 buf[off+bufoff*stream2] = uint8(v2.entry >> 8)
556
557 val = br[stream].peekBitsFast(d.actualTableLog)
558 v = single[val&tlMask]
559 br[stream].advance(uint8(v.entry))
560 buf[off+bufoff*stream+1] = uint8(v.entry >> 8)
561
562 val2 = br[stream2].peekBitsFast(d.actualTableLog)
563 v2 = single[val2&tlMask]
564 br[stream2].advance(uint8(v2.entry))
565 buf[off+bufoff*stream2+1] = uint8(v2.entry >> 8)
566 }
567
568 {
569 const stream = 2
570 const stream2 = 3
571 br[stream].fillFast()
572 br[stream2].fillFast()
573
574 val := br[stream].peekBitsFast(d.actualTableLog)
575 v := single[val&tlMask]
576 br[stream].advance(uint8(v.entry))
577 buf[off+bufoff*stream] = uint8(v.entry >> 8)
578
579 val2 := br[stream2].peekBitsFast(d.actualTableLog)
580 v2 := single[val2&tlMask]
581 br[stream2].advance(uint8(v2.entry))
582 buf[off+bufoff*stream2] = uint8(v2.entry >> 8)
583
584 val = br[stream].peekBitsFast(d.actualTableLog)
585 v = single[val&tlMask]
586 br[stream].advance(uint8(v.entry))
587 buf[off+bufoff*stream+1] = uint8(v.entry >> 8)
588
589 val2 = br[stream2].peekBitsFast(d.actualTableLog)
590 v2 = single[val2&tlMask]
591 br[stream2].advance(uint8(v2.entry))
592 buf[off+bufoff*stream2+1] = uint8(v2.entry >> 8)
593 }
594
595 off += 2
596
597 if off == bufoff {
598 if bufoff > dstEvery {
599 return nil, errors.New("corruption detected: stream overrun 1")
600 }
601 copy(out, buf[:bufoff])
602 copy(out[dstEvery:], buf[bufoff:bufoff*2])
603 copy(out[dstEvery*2:], buf[bufoff*2:bufoff*3])
604 copy(out[dstEvery*3:], buf[bufoff*3:bufoff*4])
605 off = 0
606 out = out[bufoff:]
607 decoded += 256
608 // There must at least be 3 buffers left.
609 if len(out) < dstEvery*3 {
610 return nil, errors.New("corruption detected: stream overrun 2")
611 }
612 }
613 }
614 if off > 0 {
615 ioff := int(off)
616 if len(out) < dstEvery*3+ioff {
617 return nil, errors.New("corruption detected: stream overrun 3")
618 }
619 copy(out, buf[:off])
620 copy(out[dstEvery:dstEvery+ioff], buf[bufoff:bufoff*2])
621 copy(out[dstEvery*2:dstEvery*2+ioff], buf[bufoff*2:bufoff*3])
622 copy(out[dstEvery*3:dstEvery*3+ioff], buf[bufoff*3:bufoff*4])
623 decoded += int(off) * 4
624 out = out[off:]
625 }
626
627 // Decode remaining.
628 for i := range br {
629 offset := dstEvery * i
630 br := &br[i]
631 bitsLeft := br.off*8 + uint(64-br.bitsRead)
632 for bitsLeft > 0 {
633 br.fill()
634 if false && br.bitsRead >= 32 {
635 if br.off >= 4 {
636 v := br.in[br.off-4:]
637 v = v[:4]
638 low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
639 br.value = (br.value << 32) | uint64(low)
640 br.bitsRead -= 32
641 br.off -= 4
642 } else {
643 for br.off > 0 {
644 br.value = (br.value << 8) | uint64(br.in[br.off-1])
645 br.bitsRead -= 8
646 br.off--
647 }
648 }
649 }
650 // end inline...
651 if offset >= len(out) {
652 return nil, errors.New("corruption detected: stream overrun 4")
653 }
654
655 // Read value and increment offset.
656 val := br.peekBitsFast(d.actualTableLog)
657 v := single[val&tlMask].entry
658 nBits := uint8(v)
659 br.advance(nBits)
660 bitsLeft -= uint(nBits)
661 out[offset] = uint8(v >> 8)
662 offset++
663 }
664 decoded += offset - dstEvery*i
665 err = br.close()
666 if err != nil {
667 return nil, err
668 }
669 }
670 if dstSize != decoded {
671 return nil, errors.New("corruption detected: short output block")
672 }
673 return dst, nil
674}
675
676// Decompress4X will decompress a 4X encoded stream.
677// The length of the supplied input must match the end of a block exactly.
678// The *capacity* of the dst slice must match the destination size of
679// the uncompressed data exactly.
680func (d *Decoder) decompress4X8bit(dst, src []byte) ([]byte, error) {
681 if d.actualTableLog == 8 {
682 return d.decompress4X8bitExactly(dst, src)
683 }
684
685 var br [4]bitReaderBytes
686 start := 6
687 for i := 0; i < 3; i++ {
688 length := int(src[i*2]) | (int(src[i*2+1]) << 8)
689 if start+length >= len(src) {
690 return nil, errors.New("truncated input (or invalid offset)")
691 }
692 err := br[i].init(src[start : start+length])
693 if err != nil {
694 return nil, err
695 }
696 start += length
697 }
698 err := br[3].init(src[start:])
699 if err != nil {
700 return nil, err
701 }
702
703 // destination, offset to match first output
704 dstSize := cap(dst)
705 dst = dst[:dstSize]
706 out := dst
707 dstEvery := (dstSize + 3) / 4
708
709 shift := (8 - d.actualTableLog) & 7
710
711 const tlSize = 1 << 8
712 const tlMask = tlSize - 1
713 single := d.dt.single[:tlSize]
714
715 // Use temp table to avoid bound checks/append penalty.
716 var buf [256]byte
717 var off uint8
718 var decoded int
719
720 // Decode 4 values from each decoder/loop.
721 const bufoff = 256 / 4
722 for {
723 if br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4 {
724 break
725 }
726
727 {
728 // Interleave 2 decodes.
729 const stream = 0
730 const stream2 = 1
731 br[stream].fillFast()
732 br[stream2].fillFast()
733
734 v := single[br[stream].peekByteFast()>>shift].entry
735 buf[off+bufoff*stream] = uint8(v >> 8)
736 br[stream].advance(uint8(v))
737
738 v2 := single[br[stream2].peekByteFast()>>shift].entry
739 buf[off+bufoff*stream2] = uint8(v2 >> 8)
740 br[stream2].advance(uint8(v2))
741
742 v = single[br[stream].peekByteFast()>>shift].entry
743 buf[off+bufoff*stream+1] = uint8(v >> 8)
744 br[stream].advance(uint8(v))
745
746 v2 = single[br[stream2].peekByteFast()>>shift].entry
747 buf[off+bufoff*stream2+1] = uint8(v2 >> 8)
748 br[stream2].advance(uint8(v2))
749
750 v = single[br[stream].peekByteFast()>>shift].entry
751 buf[off+bufoff*stream+2] = uint8(v >> 8)
752 br[stream].advance(uint8(v))
753
754 v2 = single[br[stream2].peekByteFast()>>shift].entry
755 buf[off+bufoff*stream2+2] = uint8(v2 >> 8)
756 br[stream2].advance(uint8(v2))
757
758 v = single[br[stream].peekByteFast()>>shift].entry
759 buf[off+bufoff*stream+3] = uint8(v >> 8)
760 br[stream].advance(uint8(v))
761
762 v2 = single[br[stream2].peekByteFast()>>shift].entry
763 buf[off+bufoff*stream2+3] = uint8(v2 >> 8)
764 br[stream2].advance(uint8(v2))
765 }
766
767 {
768 const stream = 2
769 const stream2 = 3
770 br[stream].fillFast()
771 br[stream2].fillFast()
772
773 v := single[br[stream].peekByteFast()>>shift].entry
774 buf[off+bufoff*stream] = uint8(v >> 8)
775 br[stream].advance(uint8(v))
776
777 v2 := single[br[stream2].peekByteFast()>>shift].entry
778 buf[off+bufoff*stream2] = uint8(v2 >> 8)
779 br[stream2].advance(uint8(v2))
780
781 v = single[br[stream].peekByteFast()>>shift].entry
782 buf[off+bufoff*stream+1] = uint8(v >> 8)
783 br[stream].advance(uint8(v))
784
785 v2 = single[br[stream2].peekByteFast()>>shift].entry
786 buf[off+bufoff*stream2+1] = uint8(v2 >> 8)
787 br[stream2].advance(uint8(v2))
788
789 v = single[br[stream].peekByteFast()>>shift].entry
790 buf[off+bufoff*stream+2] = uint8(v >> 8)
791 br[stream].advance(uint8(v))
792
793 v2 = single[br[stream2].peekByteFast()>>shift].entry
794 buf[off+bufoff*stream2+2] = uint8(v2 >> 8)
795 br[stream2].advance(uint8(v2))
796
797 v = single[br[stream].peekByteFast()>>shift].entry
798 buf[off+bufoff*stream+3] = uint8(v >> 8)
799 br[stream].advance(uint8(v))
800
801 v2 = single[br[stream2].peekByteFast()>>shift].entry
802 buf[off+bufoff*stream2+3] = uint8(v2 >> 8)
803 br[stream2].advance(uint8(v2))
804 }
805
806 off += 4
807
808 if off == bufoff {
809 if bufoff > dstEvery {
810 return nil, errors.New("corruption detected: stream overrun 1")
811 }
812 copy(out, buf[:bufoff])
813 copy(out[dstEvery:], buf[bufoff:bufoff*2])
814 copy(out[dstEvery*2:], buf[bufoff*2:bufoff*3])
815 copy(out[dstEvery*3:], buf[bufoff*3:bufoff*4])
816 off = 0
817 out = out[bufoff:]
818 decoded += 256
819 // There must at least be 3 buffers left.
820 if len(out) < dstEvery*3 {
821 return nil, errors.New("corruption detected: stream overrun 2")
822 }
823 }
824 }
825 if off > 0 {
826 ioff := int(off)
827 if len(out) < dstEvery*3+ioff {
828 return nil, errors.New("corruption detected: stream overrun 3")
829 }
830 copy(out, buf[:off])
831 copy(out[dstEvery:dstEvery+ioff], buf[bufoff:bufoff*2])
832 copy(out[dstEvery*2:dstEvery*2+ioff], buf[bufoff*2:bufoff*3])
833 copy(out[dstEvery*3:dstEvery*3+ioff], buf[bufoff*3:bufoff*4])
834 decoded += int(off) * 4
835 out = out[off:]
836 }
837
838 // Decode remaining.
839 for i := range br {
840 offset := dstEvery * i
841 br := &br[i]
842 bitsLeft := int(br.off*8) + int(64-br.bitsRead)
843 for bitsLeft > 0 {
844 if br.finished() {
845 return nil, io.ErrUnexpectedEOF
846 }
847 if br.bitsRead >= 56 {
848 if br.off >= 4 {
849 v := br.in[br.off-4:]
850 v = v[:4]
851 low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
852 br.value |= uint64(low) << (br.bitsRead - 32)
853 br.bitsRead -= 32
854 br.off -= 4
855 } else {
856 for br.off > 0 {
857 br.value |= uint64(br.in[br.off-1]) << (br.bitsRead - 8)
858 br.bitsRead -= 8
859 br.off--
860 }
861 }
862 }
863 // end inline...
864 if offset >= len(out) {
865 return nil, errors.New("corruption detected: stream overrun 4")
866 }
867
868 // Read value and increment offset.
869 v := single[br.peekByteFast()>>shift].entry
870 nBits := uint8(v)
871 br.advance(nBits)
872 bitsLeft -= int(nBits)
873 out[offset] = uint8(v >> 8)
874 offset++
875 }
876 decoded += offset - dstEvery*i
877 err = br.close()
878 if err != nil {
879 return nil, err
880 }
881 }
882 if dstSize != decoded {
883 return nil, errors.New("corruption detected: short output block")
884 }
885 return dst, nil
886}
887
888// Decompress4X will decompress a 4X encoded stream.
889// The length of the supplied input must match the end of a block exactly.
890// The *capacity* of the dst slice must match the destination size of
891// the uncompressed data exactly.
892func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) {
893 var br [4]bitReaderBytes
894 start := 6
895 for i := 0; i < 3; i++ {
896 length := int(src[i*2]) | (int(src[i*2+1]) << 8)
897 if start+length >= len(src) {
898 return nil, errors.New("truncated input (or invalid offset)")
899 }
900 err := br[i].init(src[start : start+length])
901 if err != nil {
902 return nil, err
903 }
904 start += length
905 }
906 err := br[3].init(src[start:])
907 if err != nil {
908 return nil, err
909 }
910
911 // destination, offset to match first output
912 dstSize := cap(dst)
913 dst = dst[:dstSize]
914 out := dst
915 dstEvery := (dstSize + 3) / 4
916
917 const shift = 0
918 const tlSize = 1 << 8
919 const tlMask = tlSize - 1
920 single := d.dt.single[:tlSize]
921
922 // Use temp table to avoid bound checks/append penalty.
923 var buf [256]byte
924 var off uint8
925 var decoded int
926
927 // Decode 4 values from each decoder/loop.
928 const bufoff = 256 / 4
929 for {
930 if br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4 {
931 break
932 }
933
934 {
935 // Interleave 2 decodes.
936 const stream = 0
937 const stream2 = 1
938 br[stream].fillFast()
939 br[stream2].fillFast()
940
941 v := single[br[stream].peekByteFast()>>shift].entry
942 buf[off+bufoff*stream] = uint8(v >> 8)
943 br[stream].advance(uint8(v))
944
945 v2 := single[br[stream2].peekByteFast()>>shift].entry
946 buf[off+bufoff*stream2] = uint8(v2 >> 8)
947 br[stream2].advance(uint8(v2))
948
949 v = single[br[stream].peekByteFast()>>shift].entry
950 buf[off+bufoff*stream+1] = uint8(v >> 8)
951 br[stream].advance(uint8(v))
952
953 v2 = single[br[stream2].peekByteFast()>>shift].entry
954 buf[off+bufoff*stream2+1] = uint8(v2 >> 8)
955 br[stream2].advance(uint8(v2))
956
957 v = single[br[stream].peekByteFast()>>shift].entry
958 buf[off+bufoff*stream+2] = uint8(v >> 8)
959 br[stream].advance(uint8(v))
960
961 v2 = single[br[stream2].peekByteFast()>>shift].entry
962 buf[off+bufoff*stream2+2] = uint8(v2 >> 8)
963 br[stream2].advance(uint8(v2))
964
965 v = single[br[stream].peekByteFast()>>shift].entry
966 buf[off+bufoff*stream+3] = uint8(v >> 8)
967 br[stream].advance(uint8(v))
968
969 v2 = single[br[stream2].peekByteFast()>>shift].entry
970 buf[off+bufoff*stream2+3] = uint8(v2 >> 8)
971 br[stream2].advance(uint8(v2))
972 }
973
974 {
975 const stream = 2
976 const stream2 = 3
977 br[stream].fillFast()
978 br[stream2].fillFast()
979
980 v := single[br[stream].peekByteFast()>>shift].entry
981 buf[off+bufoff*stream] = uint8(v >> 8)
982 br[stream].advance(uint8(v))
983
984 v2 := single[br[stream2].peekByteFast()>>shift].entry
985 buf[off+bufoff*stream2] = uint8(v2 >> 8)
986 br[stream2].advance(uint8(v2))
987
988 v = single[br[stream].peekByteFast()>>shift].entry
989 buf[off+bufoff*stream+1] = uint8(v >> 8)
990 br[stream].advance(uint8(v))
991
992 v2 = single[br[stream2].peekByteFast()>>shift].entry
993 buf[off+bufoff*stream2+1] = uint8(v2 >> 8)
994 br[stream2].advance(uint8(v2))
995
996 v = single[br[stream].peekByteFast()>>shift].entry
997 buf[off+bufoff*stream+2] = uint8(v >> 8)
998 br[stream].advance(uint8(v))
999
1000 v2 = single[br[stream2].peekByteFast()>>shift].entry
1001 buf[off+bufoff*stream2+2] = uint8(v2 >> 8)
1002 br[stream2].advance(uint8(v2))
1003
1004 v = single[br[stream].peekByteFast()>>shift].entry
1005 buf[off+bufoff*stream+3] = uint8(v >> 8)
1006 br[stream].advance(uint8(v))
1007
1008 v2 = single[br[stream2].peekByteFast()>>shift].entry
1009 buf[off+bufoff*stream2+3] = uint8(v2 >> 8)
1010 br[stream2].advance(uint8(v2))
1011 }
1012
1013 off += 4
1014
1015 if off == bufoff {
1016 if bufoff > dstEvery {
1017 return nil, errors.New("corruption detected: stream overrun 1")
1018 }
1019 copy(out, buf[:bufoff])
1020 copy(out[dstEvery:], buf[bufoff:bufoff*2])
1021 copy(out[dstEvery*2:], buf[bufoff*2:bufoff*3])
1022 copy(out[dstEvery*3:], buf[bufoff*3:bufoff*4])
1023 off = 0
1024 out = out[bufoff:]
1025 decoded += 256
1026 // There must at least be 3 buffers left.
1027 if len(out) < dstEvery*3 {
1028 return nil, errors.New("corruption detected: stream overrun 2")
1029 }
1030 }
1031 }
1032 if off > 0 {
1033 ioff := int(off)
1034 if len(out) < dstEvery*3+ioff {
1035 return nil, errors.New("corruption detected: stream overrun 3")
1036 }
1037 copy(out, buf[:off])
1038 copy(out[dstEvery:dstEvery+ioff], buf[bufoff:bufoff*2])
1039 copy(out[dstEvery*2:dstEvery*2+ioff], buf[bufoff*2:bufoff*3])
1040 copy(out[dstEvery*3:dstEvery*3+ioff], buf[bufoff*3:bufoff*4])
1041 decoded += int(off) * 4
1042 out = out[off:]
1043 }
1044
1045 // Decode remaining.
1046 for i := range br {
1047 offset := dstEvery * i
1048 br := &br[i]
1049 bitsLeft := int(br.off*8) + int(64-br.bitsRead)
1050 for bitsLeft > 0 {
1051 if br.finished() {
1052 return nil, io.ErrUnexpectedEOF
1053 }
1054 if br.bitsRead >= 56 {
1055 if br.off >= 4 {
1056 v := br.in[br.off-4:]
1057 v = v[:4]
1058 low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
1059 br.value |= uint64(low) << (br.bitsRead - 32)
1060 br.bitsRead -= 32
1061 br.off -= 4
1062 } else {
1063 for br.off > 0 {
1064 br.value |= uint64(br.in[br.off-1]) << (br.bitsRead - 8)
1065 br.bitsRead -= 8
1066 br.off--
1067 }
1068 }
1069 }
1070 // end inline...
1071 if offset >= len(out) {
1072 return nil, errors.New("corruption detected: stream overrun 4")
1073 }
1074
1075 // Read value and increment offset.
1076 v := single[br.peekByteFast()>>shift].entry
1077 nBits := uint8(v)
1078 br.advance(nBits)
1079 bitsLeft -= int(nBits)
1080 out[offset] = uint8(v >> 8)
1081 offset++
1082 }
1083 decoded += offset - dstEvery*i
1084 err = br.close()
1085 if err != nil {
1086 return nil, err
1087 }
1088 }
1089 if dstSize != decoded {
1090 return nil, errors.New("corruption detected: short output block")
1091 }
1092 return dst, nil
1093}
1094
1095// matches will compare a decoding table to a coding table.
1096// Errors are written to the writer.
1097// Nothing will be written if table is ok.
1098func (s *Scratch) matches(ct cTable, w io.Writer) {
1099 if s == nil || len(s.dt.single) == 0 {
1100 return
1101 }
1102 dt := s.dt.single[:1<<s.actualTableLog]
1103 tablelog := s.actualTableLog
1104 ok := 0
1105 broken := 0
1106 for sym, enc := range ct {
1107 errs := 0
1108 broken++
1109 if enc.nBits == 0 {
1110 for _, dec := range dt {
1111 if uint8(dec.entry>>8) == byte(sym) {
1112 fmt.Fprintf(w, "symbol %x has decoder, but no encoder\n", sym)
1113 errs++
1114 break
1115 }
1116 }
1117 if errs == 0 {
1118 broken--
1119 }
1120 continue
1121 }
1122 // Unused bits in input
1123 ub := tablelog - enc.nBits
1124 top := enc.val << ub
1125 // decoder looks at top bits.
1126 dec := dt[top]
1127 if uint8(dec.entry) != enc.nBits {
1128 fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", sym, enc.nBits, uint8(dec.entry))
1129 errs++
1130 }
1131 if uint8(dec.entry>>8) != uint8(sym) {
1132 fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", sym, sym, uint8(dec.entry>>8))
1133 errs++
1134 }
1135 if errs > 0 {
1136 fmt.Fprintf(w, "%d errros in base, stopping\n", errs)
1137 continue
1138 }
1139 // Ensure that all combinations are covered.
1140 for i := uint16(0); i < (1 << ub); i++ {
1141 vval := top | i
1142 dec := dt[vval]
1143 if uint8(dec.entry) != enc.nBits {
1144 fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", vval, enc.nBits, uint8(dec.entry))
1145 errs++
1146 }
1147 if uint8(dec.entry>>8) != uint8(sym) {
1148 fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", vval, sym, uint8(dec.entry>>8))
1149 errs++
1150 }
1151 if errs > 20 {
1152 fmt.Fprintf(w, "%d errros, stopping\n", errs)
1153 break
1154 }
1155 }
1156 if errs == 0 {
1157 ok++
1158 broken--
1159 }
1160 }
1161 if broken > 0 {
1162 fmt.Fprintf(w, "%d broken, %d ok\n", broken, ok)
1163 }
1164}