blob: b51d922bda6918af16e1d161520a5fd1e0080aa1 [file] [log] [blame]
khenaidood948f772021-08-11 17:49:24 -04001// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8 "errors"
9 "fmt"
10 "io"
11 "sync"
12
13 "github.com/klauspost/compress/huff0"
14 "github.com/klauspost/compress/zstd/internal/xxhash"
15)
16
// blockType is the 2-bit Block_Type field of a zstd block header.
type blockType uint8

//go:generate stringer -type=blockType,literalsBlockType,seqCompMode,tableIndex

const (
	blockTypeRaw blockType = iota
	blockTypeRLE
	blockTypeCompressed
	blockTypeReserved
)
27
// literalsBlockType is the 2-bit Literals_Block_Type field of the
// literals section header inside a compressed block.
type literalsBlockType uint8

const (
	literalsBlockRaw literalsBlockType = iota
	literalsBlockRLE
	literalsBlockCompressed
	literalsBlockTreeless
)
36
const (
	// maxCompressedBlockSize is the biggest allowed compressed block size (128KB)
	maxCompressedBlockSize = 128 << 10

	// Maximum possible block size (all Raw+Uncompressed).
	maxBlockSize = (1 << 21) - 1

	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#literals_section_header
	maxCompressedLiteralSize = 1 << 18
	maxRLELiteralSize        = 1 << 20
	maxMatchLen              = 131074
	// maxSequences matches the largest value encodable by the
	// sequence count header (0xff prefix + 2 following bytes).
	maxSequences = 0x7f00 + 0xffff

	// We support slightly less than the reference decoder to be able to
	// use ints on 32 bit archs.
	maxOffsetBits = 30
)
54
var (
	// huffDecoderPool recycles huff0 scratch buffers used for
	// decoding compressed literals.
	huffDecoderPool = sync.Pool{New: func() interface{} {
		return &huff0.Scratch{}
	}}

	// fseDecoderPool recycles FSE table decoders used for the
	// sequence section tables.
	fseDecoderPool = sync.Pool{New: func() interface{} {
		return &fseDecoder{}
	}}
)
64
// blockDec decodes a single zstd block.
// A blockDec is reused across blocks via reset, and may decode
// asynchronously on its own goroutine (see startDecoder).
type blockDec struct {
	// Raw source data of the block.
	data        []byte
	dataStorage []byte

	// Destination of the decoded data.
	dst []byte

	// Buffer for literals data.
	literalBuf []byte

	// Window size of the block.
	WindowSize uint64

	// history delivers the stream history to the decode goroutine.
	history chan *history
	// input signals that a block is ready; closing it stops startDecoder.
	input chan struct{}
	// result carries the decoded output (or error) back to the caller.
	result      chan decodeOutput
	sequenceBuf []seq
	// err is returned for blockTypeReserved blocks (see sendErr).
	err   error
	decWG sync.WaitGroup

	// Frame to use for singlethreaded decoding.
	// Should not be used by the decoder itself since parent may be another frame.
	localFrame *frameDec

	// Block is RLE, this is the size.
	RLESize uint32
	tmp     [4]byte

	Type blockType

	// Is this the last block of a frame?
	Last bool

	// Use less memory
	lowMem bool
}
102
103func (b *blockDec) String() string {
104 if b == nil {
105 return "<nil>"
106 }
107 return fmt.Sprintf("Steam Size: %d, Type: %v, Last: %t, Window: %d", len(b.data), b.Type, b.Last, b.WindowSize)
108}
109
110func newBlockDec(lowMem bool) *blockDec {
111 b := blockDec{
112 lowMem: lowMem,
113 result: make(chan decodeOutput, 1),
114 input: make(chan struct{}, 1),
115 history: make(chan *history, 1),
116 }
117 b.decWG.Add(1)
118 go b.startDecoder()
119 return &b
120}
121
// reset will reset the block.
// Input must be a start of a block and will be at the end of the block when returned.
func (b *blockDec) reset(br byteBuffer, windowSize uint64) error {
	b.WindowSize = windowSize
	// Block_Header is 3 bytes: Last_Block (1 bit), Block_Type (2 bits),
	// Block_Size (21 bits), little endian.
	tmp := br.readSmall(3)
	if tmp == nil {
		if debug {
			println("Reading block header:", io.ErrUnexpectedEOF)
		}
		return io.ErrUnexpectedEOF
	}
	bh := uint32(tmp[0]) | (uint32(tmp[1]) << 8) | (uint32(tmp[2]) << 16)
	b.Last = bh&1 != 0
	b.Type = blockType((bh >> 1) & 3)
	// find size.
	cSize := int(bh >> 3)
	maxSize := maxBlockSize
	switch b.Type {
	case blockTypeReserved:
		return ErrReservedBlockType
	case blockTypeRLE:
		// For RLE blocks the size field is the regenerated size;
		// only a single payload byte follows on the stream.
		b.RLESize = uint32(cSize)
		if b.lowMem {
			maxSize = cSize
		}
		cSize = 1
	case blockTypeCompressed:
		if debug {
			println("Data size on stream:", cSize)
		}
		b.RLESize = 0
		maxSize = maxCompressedBlockSize
		if windowSize < maxCompressedBlockSize && b.lowMem {
			maxSize = int(windowSize)
		}
		// Reject blocks bigger than the format limit or the frame window.
		if cSize > maxCompressedBlockSize || uint64(cSize) > b.WindowSize {
			if debug {
				printf("compressed block too big: csize:%d block: %+v\n", uint64(cSize), b)
			}
			return ErrCompressedSizeTooBig
		}
	case blockTypeRaw:
		b.RLESize = 0
		// We do not need a destination for raw blocks.
		maxSize = -1
	default:
		panic("Invalid block type")
	}

	// Read block data.
	if cap(b.dataStorage) < cSize {
		if b.lowMem {
			b.dataStorage = make([]byte, 0, cSize)
		} else {
			// Allocate the maximum once so later blocks can reuse it.
			b.dataStorage = make([]byte, 0, maxBlockSize)
		}
	}
	if cap(b.dst) <= maxSize {
		b.dst = make([]byte, 0, maxSize+1)
	}
	var err error
	b.data, err = br.readBig(cSize, b.dataStorage)
	if err != nil {
		if debug {
			println("Reading block:", err, "(", cSize, ")", len(b.data))
			printf("%T", br)
		}
		return err
	}
	return nil
}
193
// sendErr will make the decoder report the given error on this frame.
// The block is marked last and typed reserved so startDecoder takes the
// error-return path when the input signal is received.
func (b *blockDec) sendErr(err error) {
	b.Last = true
	b.Type = blockTypeReserved
	b.err = err
	b.input <- struct{}{}
}
201
// Close will release resources.
// Closed blockDec cannot be reset.
func (b *blockDec) Close() {
	// Closing input ends the startDecoder loop.
	close(b.input)
	close(b.history)
	close(b.result)
	// Wait for the decode goroutine to exit.
	b.decWG.Wait()
}
210
// startDecoder runs the asynchronous block decode loop.
// For each signal on b.input it decodes the current block, receives the
// stream history on b.history, and delivers output on b.result.
// The loop exits when b.input is closed (see Close).
func (b *blockDec) startDecoder() {
	defer b.decWG.Done()
	for range b.input {
		//println("blockDec: Got block input")
		switch b.Type {
		case blockTypeRLE:
			if cap(b.dst) < int(b.RLESize) {
				if b.lowMem {
					b.dst = make([]byte, b.RLESize)
				} else {
					b.dst = make([]byte, maxBlockSize)
				}
			}
			o := decodeOutput{
				d:   b,
				b:   b.dst[:b.RLESize],
				err: nil,
			}
			// Repeat the single payload byte across the whole output.
			v := b.data[0]
			for i := range o.b {
				o.b[i] = v
			}
			hist := <-b.history
			hist.append(o.b)
			b.result <- o
		case blockTypeRaw:
			// Raw block data is used verbatim.
			o := decodeOutput{
				d:   b,
				b:   b.data,
				err: nil,
			}
			hist := <-b.history
			hist.append(o.b)
			b.result <- o
		case blockTypeCompressed:
			b.dst = b.dst[:0]
			// Passing nil history: decodeCompressed fetches it from
			// b.history itself once it is needed.
			err := b.decodeCompressed(nil)
			o := decodeOutput{
				d:   b,
				b:   b.dst,
				err: err,
			}
			if debug {
				println("Decompressed to", len(b.dst), "bytes, error:", err)
			}
			b.result <- o
		case blockTypeReserved:
			// Used for returning errors.
			<-b.history
			b.result <- decodeOutput{
				d:   b,
				b:   nil,
				err: b.err,
			}
		default:
			panic("Invalid block type")
		}
		if debug {
			println("blockDec: Finished block")
		}
	}
}
275
// decodeBuf decodes the block synchronously, appending output to hist.
// The history is supplied directly and is not fetched from the channel.
func (b *blockDec) decodeBuf(hist *history) error {
	switch b.Type {
	case blockTypeRLE:
		if cap(b.dst) < int(b.RLESize) {
			if b.lowMem {
				b.dst = make([]byte, b.RLESize)
			} else {
				b.dst = make([]byte, maxBlockSize)
			}
		}
		b.dst = b.dst[:b.RLESize]
		// Repeat the single payload byte across the whole output.
		v := b.data[0]
		for i := range b.dst {
			b.dst[i] = v
		}
		hist.appendKeep(b.dst)
		return nil
	case blockTypeRaw:
		hist.appendKeep(b.data)
		return nil
	case blockTypeCompressed:
		// Decode directly into the history buffer; restore b.dst after.
		saved := b.dst
		b.dst = hist.b
		hist.b = nil
		err := b.decodeCompressed(hist)
		if debug {
			println("Decompressed to total", len(b.dst), "bytes, hash:", xxhash.Sum64(b.dst), "error:", err)
		}
		hist.b = b.dst
		b.dst = saved
		return err
	case blockTypeReserved:
		// Used for returning errors.
		return b.err
	default:
		panic("Invalid block type")
	}
}
316
// decodeCompressed will start decompressing a block.
// If no history is supplied the decoder will decodeAsync as much as possible
// before fetching from blockDec.history
func (b *blockDec) decodeCompressed(hist *history) error {
	in := b.data
	delayedHistory := hist == nil

	if delayedHistory {
		// We must always grab history.
		defer func() {
			if hist == nil {
				<-b.history
			}
		}()
	}
	// There must be at least one byte for Literals_Block_Type and one for Sequences_Section_Header
	if len(in) < 2 {
		return ErrBlockTooSmall
	}
	// Literals_Section_Header: type in the low 2 bits, size format in the next 2.
	litType := literalsBlockType(in[0] & 3)
	var litRegenSize int
	var litCompSize int
	sizeFormat := (in[0] >> 2) & 3
	var fourStreams bool
	switch litType {
	case literalsBlockRaw, literalsBlockRLE:
		switch sizeFormat {
		case 0, 2:
			// Regenerated_Size uses 5 bits (0-31). Literals_Section_Header uses 1 byte.
			litRegenSize = int(in[0] >> 3)
			in = in[1:]
		case 1:
			// Regenerated_Size uses 12 bits (0-4095). Literals_Section_Header uses 2 bytes.
			litRegenSize = int(in[0]>>4) + (int(in[1]) << 4)
			in = in[2:]
		case 3:
			// Regenerated_Size uses 20 bits (0-1048575). Literals_Section_Header uses 3 bytes.
			if len(in) < 3 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			litRegenSize = int(in[0]>>4) + (int(in[1]) << 4) + (int(in[2]) << 12)
			in = in[3:]
		}
	case literalsBlockCompressed, literalsBlockTreeless:
		switch sizeFormat {
		case 0, 1:
			// Both Regenerated_Size and Compressed_Size use 10 bits (0-1023).
			if len(in) < 3 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12)
			litRegenSize = int(n & 1023)
			litCompSize = int(n >> 10)
			fourStreams = sizeFormat == 1
			in = in[3:]
		case 2:
			// 14-bit sizes, 4-byte header, always four streams.
			fourStreams = true
			if len(in) < 4 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) + (uint64(in[3]) << 20)
			litRegenSize = int(n & 16383)
			litCompSize = int(n >> 14)
			in = in[4:]
		case 3:
			// 18-bit sizes, 5-byte header, always four streams.
			fourStreams = true
			if len(in) < 5 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) + (uint64(in[3]) << 20) + (uint64(in[4]) << 28)
			litRegenSize = int(n & 262143)
			litCompSize = int(n >> 18)
			in = in[5:]
		}
	}
	if debug {
		println("literals type:", litType, "litRegenSize:", litRegenSize, "litCompSize:", litCompSize, "sizeFormat:", sizeFormat, "4X:", fourStreams)
	}
	var literals []byte
	var huff *huff0.Scratch
	switch litType {
	case literalsBlockRaw:
		if len(in) < litRegenSize {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litRegenSize)
			return ErrBlockTooSmall
		}
		literals = in[:litRegenSize]
		in = in[litRegenSize:]
		//printf("Found %d uncompressed literals\n", litRegenSize)
	case literalsBlockRLE:
		if len(in) < 1 {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", 1)
			return ErrBlockTooSmall
		}
		if cap(b.literalBuf) < litRegenSize {
			if b.lowMem {
				b.literalBuf = make([]byte, litRegenSize)
			} else {
				if litRegenSize > maxCompressedLiteralSize {
					// Exceptional
					b.literalBuf = make([]byte, litRegenSize)
				} else {
					b.literalBuf = make([]byte, litRegenSize, maxCompressedLiteralSize)

				}
			}
		}
		// Repeat the single literal byte litRegenSize times.
		literals = b.literalBuf[:litRegenSize]
		v := in[0]
		for i := range literals {
			literals[i] = v
		}
		in = in[1:]
		if debug {
			printf("Found %d RLE compressed literals\n", litRegenSize)
		}
	case literalsBlockTreeless:
		if len(in) < litCompSize {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litCompSize)
			return ErrBlockTooSmall
		}
		// Store compressed literals, so we defer decoding until we get history.
		literals = in[:litCompSize]
		in = in[litCompSize:]
		if debug {
			printf("Found %d compressed literals\n", litCompSize)
		}
	case literalsBlockCompressed:
		if len(in) < litCompSize {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litCompSize)
			return ErrBlockTooSmall
		}
		literals = in[:litCompSize]
		in = in[litCompSize:]
		huff = huffDecoderPool.Get().(*huff0.Scratch)
		var err error
		// Ensure we have space to store it.
		if cap(b.literalBuf) < litRegenSize {
			if b.lowMem {
				b.literalBuf = make([]byte, 0, litRegenSize)
			} else {
				b.literalBuf = make([]byte, 0, maxCompressedLiteralSize)
			}
		}
		if huff == nil {
			huff = &huff0.Scratch{}
		}
		// The Huffman table is embedded before the compressed streams.
		huff, literals, err = huff0.ReadTable(literals, huff)
		if err != nil {
			println("reading huffman table:", err)
			return err
		}
		// Use our out buffer.
		if fourStreams {
			literals, err = huff.Decoder().Decompress4X(b.literalBuf[:0:litRegenSize], literals)
		} else {
			literals, err = huff.Decoder().Decompress1X(b.literalBuf[:0:litRegenSize], literals)
		}
		if err != nil {
			println("decoding compressed literals:", err)
			return err
		}
		// Make sure we don't leak our literals buffer
		if len(literals) != litRegenSize {
			return fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals))
		}
		if debug {
			printf("Decompressed %d literals into %d bytes\n", litCompSize, litRegenSize)
		}
	}

	// Decode Sequences
	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#sequences-section
	if len(in) < 1 {
		return ErrBlockTooSmall
	}
	// Number_of_Sequences header: 1-3 bytes depending on the first byte.
	seqHeader := in[0]
	nSeqs := 0
	switch {
	case seqHeader == 0:
		in = in[1:]
	case seqHeader < 128:
		nSeqs = int(seqHeader)
		in = in[1:]
	case seqHeader < 255:
		if len(in) < 2 {
			return ErrBlockTooSmall
		}
		nSeqs = int(seqHeader-128)<<8 | int(in[1])
		in = in[2:]
	case seqHeader == 255:
		if len(in) < 3 {
			return ErrBlockTooSmall
		}
		nSeqs = 0x7f00 + int(in[1]) + (int(in[2]) << 8)
		in = in[3:]
	}
	// Allocate sequences
	if cap(b.sequenceBuf) < nSeqs {
		if b.lowMem {
			b.sequenceBuf = make([]seq, nSeqs)
		} else {
			// Allocate max
			b.sequenceBuf = make([]seq, nSeqs, maxSequences)
		}
	} else {
		// Reuse buffer
		b.sequenceBuf = b.sequenceBuf[:nSeqs]
	}
	var seqs = &sequenceDecs{}
	if nSeqs > 0 {
		if len(in) < 1 {
			return ErrBlockTooSmall
		}
		br := byteReader{b: in, off: 0}
		// One Symbol_Compression_Mode byte: 2 bits per table
		// (literal lengths, offsets, match lengths).
		compMode := br.Uint8()
		br.advance(1)
		if debug {
			printf("Compression modes: 0b%b", compMode)
		}
		for i := uint(0); i < 3; i++ {
			mode := seqCompMode((compMode >> (6 - i*2)) & 3)
			if debug {
				println("Table", tableIndex(i), "is", mode)
			}
			var seq *sequenceDec
			switch tableIndex(i) {
			case tableLiteralLengths:
				seq = &seqs.litLengths
			case tableOffsets:
				seq = &seqs.offsets
			case tableMatchLengths:
				seq = &seqs.matchLengths
			default:
				panic("unknown table")
			}
			switch mode {
			case compModePredefined:
				// Use the format-defined default distribution.
				seq.fse = &fsePredef[i]
			case compModeRLE:
				// Single symbol repeated for every sequence.
				if br.remain() < 1 {
					return ErrBlockTooSmall
				}
				v := br.Uint8()
				br.advance(1)
				dec := fseDecoderPool.Get().(*fseDecoder)
				symb, err := decSymbolValue(v, symbolTableX[i])
				if err != nil {
					printf("RLE Transform table (%v) error: %v", tableIndex(i), err)
					return err
				}
				dec.setRLE(symb)
				seq.fse = dec
				if debug {
					printf("RLE set to %+v, code: %v", symb, v)
				}
			case compModeFSE:
				// Read a full FSE table from the stream.
				println("Reading table for", tableIndex(i))
				dec := fseDecoderPool.Get().(*fseDecoder)
				err := dec.readNCount(&br, uint16(maxTableSymbol[i]))
				if err != nil {
					println("Read table error:", err)
					return err
				}
				err = dec.transform(symbolTableX[i])
				if err != nil {
					println("Transform table error:", err)
					return err
				}
				if debug {
					println("Read table ok", "symbolLen:", dec.symbolLen)
				}
				seq.fse = dec
			case compModeRepeat:
				// Reuse the table from the previous block (merged below).
				seq.repeat = true
			}
			if br.overread() {
				return io.ErrUnexpectedEOF
			}
		}
		in = br.unread()
	}

	// Wait for history.
	// All time spent after this is critical since it is strictly sequential.
	if hist == nil {
		hist = <-b.history
		if hist.error {
			return ErrDecoderClosed
		}
	}

	// Decode treeless literal block.
	if litType == literalsBlockTreeless {
		// TODO: We could send the history early WITHOUT the stream history.
		// This would allow decoding treeless literals before the byte history is available.
		// Silencia stats: Treeless 4393, with: 32775, total: 37168, 11% treeless.
		// So not much obvious gain here.

		if hist.huffTree == nil {
			return errors.New("literal block was treeless, but no history was defined")
		}
		// Ensure we have space to store it.
		if cap(b.literalBuf) < litRegenSize {
			if b.lowMem {
				b.literalBuf = make([]byte, 0, litRegenSize)
			} else {
				b.literalBuf = make([]byte, 0, maxCompressedLiteralSize)
			}
		}
		var err error
		// Use our out buffer.
		huff = hist.huffTree
		if fourStreams {
			literals, err = huff.Decoder().Decompress4X(b.literalBuf[:0:litRegenSize], literals)
		} else {
			literals, err = huff.Decoder().Decompress1X(b.literalBuf[:0:litRegenSize], literals)
		}
		// Make sure we don't leak our literals buffer
		if err != nil {
			println("decompressing literals:", err)
			return err
		}
		if len(literals) != litRegenSize {
			return fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals))
		}
	} else {
		// A new table replaces the old; recycle the old one unless it
		// belongs to a dictionary.
		if hist.huffTree != nil && huff != nil {
			if hist.dict == nil || hist.dict.litEnc != hist.huffTree {
				huffDecoderPool.Put(hist.huffTree)
			}
			hist.huffTree = nil
		}
	}
	if huff != nil {
		// Keep the table for possible treeless blocks that follow.
		hist.huffTree = huff
	}
	if debug {
		println("Final literals:", len(literals), "hash:", xxhash.Sum64(literals), "and", nSeqs, "sequences.")
	}

	if nSeqs == 0 {
		// Decompressed content is defined entirely as Literals Section content.
		b.dst = append(b.dst, literals...)
		if delayedHistory {
			hist.append(literals)
		}
		return nil
	}

	// Bring forward repeated tables/offsets from the previous block.
	seqs, err := seqs.mergeHistory(&hist.decoders)
	if err != nil {
		return err
	}
	if debug {
		println("History merged ok")
	}
	br := &bitReader{}
	if err := br.init(in); err != nil {
		return err
	}

	// TODO: Investigate if sending history without decoders are faster.
	// This would allow the sequences to be decoded async and only have to construct stream history.
	// If only recent offsets were not transferred, this would be an obvious win.
	// Also, if first 3 sequences don't reference recent offsets, all sequences can be decoded.

	hbytes := hist.b
	if len(hbytes) > hist.windowSize {
		// Only the last windowSize bytes can be referenced by matches.
		hbytes = hbytes[len(hbytes)-hist.windowSize:]
		// We do not need history any more.
		if hist.dict != nil {
			hist.dict.content = nil
		}
	}

	if err := seqs.initialize(br, hist, literals, b.dst); err != nil {
		println("initializing sequences:", err)
		return err
	}

	err = seqs.decode(nSeqs, br, hbytes)
	if err != nil {
		return err
	}
	if !br.finished() {
		return fmt.Errorf("%d extra bits on block, should be 0", br.remain())
	}

	err = br.close()
	if err != nil {
		printf("Closing sequences: %v, %+v\n", err, *br)
	}
	if len(b.data) > maxCompressedBlockSize {
		return fmt.Errorf("compressed block size too large (%d)", len(b.data))
	}
	// Set output and release references.
	b.dst = seqs.out
	seqs.out, seqs.literals, seqs.hist = nil, nil, nil

	if !delayedHistory {
		// If we don't have delayed history, no need to update.
		hist.recentOffsets = seqs.prevOffset
		return nil
	}
	if b.Last {
		// if last block we don't care about history.
		println("Last block, no history returned")
		hist.b = hist.b[:0]
		return nil
	}
	hist.append(b.dst)
	hist.recentOffsets = seqs.prevOffset
	if debug {
		println("Finished block with literals:", len(literals), "and", nSeqs, "sequences.")
	}

	return nil
}