// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.
5package zstd
6
7import (
8 "errors"
9 "fmt"
10 "io"
11 "sync"
12
13 "github.com/klauspost/compress/huff0"
14 "github.com/klauspost/compress/zstd/internal/xxhash"
15)
16
// blockType is the 2-bit Block_Type field of a zstd block header.
type blockType uint8

//go:generate stringer -type=blockType,literalsBlockType,seqCompMode,tableIndex

const (
	blockTypeRaw blockType = iota
	blockTypeRLE
	blockTypeCompressed
	blockTypeReserved
)

// literalsBlockType is the 2-bit Literals_Block_Type field of the
// literals section header.
type literalsBlockType uint8

const (
	literalsBlockRaw literalsBlockType = iota
	literalsBlockRLE
	literalsBlockCompressed
	literalsBlockTreeless
)
36
const (
	// maxCompressedBlockSize is the biggest allowed compressed block size (128KB)
	maxCompressedBlockSize = 128 << 10

	// Maximum possible block size (all Raw+Uncompressed).
	maxBlockSize = (1 << 21) - 1

	// Limits on literals sizes; see
	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#literals_section_header
	maxCompressedLiteralSize = 1 << 18
	maxRLELiteralSize        = 1 << 20
	maxMatchLen              = 131074
	maxSequences             = 0x7f00 + 0xffff

	// We support slightly less than the reference decoder to be able to
	// use ints on 32 bit archs.
	maxOffsetBits = 30
)
54
var (
	// huffDecoderPool reuses huff0 scratch structures between blocks to
	// avoid reallocating Huffman tables.
	huffDecoderPool = sync.Pool{New: func() interface{} {
		return &huff0.Scratch{}
	}}

	// fseDecoderPool reuses FSE decoders between blocks.
	fseDecoderPool = sync.Pool{New: func() interface{} {
		return &fseDecoder{}
	}}
)
64
// blockDec holds the state needed to decode a single zstd block.
type blockDec struct {
	// Raw source data of the block.
	data        []byte
	dataStorage []byte // backing storage reused by data across resets

	// Destination of the decoded data.
	dst []byte

	// Buffer for literals data.
	literalBuf []byte

	// Window size of the block.
	WindowSize uint64

	// Channels used by the async startDecoder goroutine:
	// history delivers the stream history, input signals a block is ready
	// to decode, and result returns the decoded output.
	history chan *history
	input   chan struct{}
	result  chan decodeOutput
	err     error
	decWG   sync.WaitGroup

	// Frame to use for singlethreaded decoding.
	// Should not be used by the decoder itself since parent may be another frame.
	localFrame *frameDec

	// Block is RLE, this is the size.
	RLESize uint32
	tmp     [4]byte

	Type blockType

	// Is this the last block of a frame?
	Last bool

	// Use less memory
	lowMem bool
}
101
102func (b *blockDec) String() string {
103 if b == nil {
104 return "<nil>"
105 }
106 return fmt.Sprintf("Steam Size: %d, Type: %v, Last: %t, Window: %d", len(b.data), b.Type, b.Last, b.WindowSize)
107}
108
109func newBlockDec(lowMem bool) *blockDec {
110 b := blockDec{
111 lowMem: lowMem,
112 result: make(chan decodeOutput, 1),
113 input: make(chan struct{}, 1),
114 history: make(chan *history, 1),
115 }
116 b.decWG.Add(1)
117 go b.startDecoder()
118 return &b
119}
120
// reset will reset the block.
// Input must be a start of a block and will be at the end of the block when returned.
func (b *blockDec) reset(br byteBuffer, windowSize uint64) error {
	b.WindowSize = windowSize
	// The block header is always 3 bytes:
	// 1 bit Last_Block, 2 bits Block_Type, 21 bits Block_Size.
	tmp, err := br.readSmall(3)
	if err != nil {
		println("Reading block header:", err)
		return err
	}
	bh := uint32(tmp[0]) | (uint32(tmp[1]) << 8) | (uint32(tmp[2]) << 16)
	b.Last = bh&1 != 0
	b.Type = blockType((bh >> 1) & 3)
	// find size.
	cSize := int(bh >> 3)
	maxSize := maxBlockSize
	switch b.Type {
	case blockTypeReserved:
		return ErrReservedBlockType
	case blockTypeRLE:
		// For RLE blocks the header size field is the regenerated size;
		// the payload is a single repeated byte.
		b.RLESize = uint32(cSize)
		if b.lowMem {
			maxSize = cSize
		}
		cSize = 1
	case blockTypeCompressed:
		if debugDecoder {
			println("Data size on stream:", cSize)
		}
		b.RLESize = 0
		maxSize = maxCompressedBlockSize
		if windowSize < maxCompressedBlockSize && b.lowMem {
			// Output can never exceed the window, so don't over-allocate.
			maxSize = int(windowSize)
		}
		if cSize > maxCompressedBlockSize || uint64(cSize) > b.WindowSize {
			if debugDecoder {
				printf("compressed block too big: csize:%d block: %+v\n", uint64(cSize), b)
			}
			return ErrCompressedSizeTooBig
		}
	case blockTypeRaw:
		b.RLESize = 0
		// We do not need a destination for raw blocks.
		maxSize = -1
	default:
		panic("Invalid block type")
	}

	// Read block data, reusing dataStorage when it is big enough.
	if cap(b.dataStorage) < cSize {
		if b.lowMem || cSize > maxCompressedBlockSize {
			b.dataStorage = make([]byte, 0, cSize)
		} else {
			b.dataStorage = make([]byte, 0, maxCompressedBlockSize)
		}
	}
	if cap(b.dst) <= maxSize {
		b.dst = make([]byte, 0, maxSize+1)
	}
	b.data, err = br.readBig(cSize, b.dataStorage)
	if err != nil {
		if debugDecoder {
			println("Reading block:", err, "(", cSize, ")", len(b.data))
			printf("%T", br)
		}
		return err
	}
	return nil
}
189
190// sendEOF will make the decoder send EOF on this frame.
191func (b *blockDec) sendErr(err error) {
192 b.Last = true
193 b.Type = blockTypeReserved
194 b.err = err
195 b.input <- struct{}{}
196}
197
// Close will release resources.
// Closed blockDec cannot be reset.
func (b *blockDec) Close() {
	// Closing input terminates the startDecoder loop; wait for it to exit.
	close(b.input)
	close(b.history)
	close(b.result)
	b.decWG.Wait()
}
206
// startDecoder runs in its own goroutine (started by newBlockDec) and
// decodes one block for each signal received on b.input.
// Output and history are separated: the decoded data is sent on b.result,
// while the stream history is received on b.history and updated in place.
// The loop exits when b.input is closed (see Close).
func (b *blockDec) startDecoder() {
	defer b.decWG.Done()
	for range b.input {
		//println("blockDec: Got block input")
		switch b.Type {
		case blockTypeRLE:
			if cap(b.dst) < int(b.RLESize) {
				if b.lowMem {
					b.dst = make([]byte, b.RLESize)
				} else {
					b.dst = make([]byte, maxBlockSize)
				}
			}
			o := decodeOutput{
				d:   b,
				b:   b.dst[:b.RLESize],
				err: nil,
			}
			// Fill the output with the single repeated byte.
			v := b.data[0]
			for i := range o.b {
				o.b[i] = v
			}
			hist := <-b.history
			hist.append(o.b)
			b.result <- o
		case blockTypeRaw:
			// Raw blocks are passed through without copying.
			o := decodeOutput{
				d:   b,
				b:   b.data,
				err: nil,
			}
			hist := <-b.history
			hist.append(o.b)
			b.result <- o
		case blockTypeCompressed:
			b.dst = b.dst[:0]
			// Passing nil makes decodeCompressed fetch history itself,
			// as late as possible.
			err := b.decodeCompressed(nil)
			o := decodeOutput{
				d:   b,
				b:   b.dst,
				err: err,
			}
			if debugDecoder {
				println("Decompressed to", len(b.dst), "bytes, error:", err)
			}
			b.result <- o
		case blockTypeReserved:
			// Used for returning errors.
			<-b.history
			b.result <- decodeOutput{
				d:   b,
				b:   nil,
				err: b.err,
			}
		default:
			panic("Invalid block type")
		}
		if debugDecoder {
			println("blockDec: Finished block")
		}
	}
}
271
// decodeBuf decodes the block synchronously, appending output to the
// supplied history. Unlike the async path, history is provided directly
// and is never fetched from the channel.
func (b *blockDec) decodeBuf(hist *history) error {
	switch b.Type {
	case blockTypeRLE:
		if cap(b.dst) < int(b.RLESize) {
			if b.lowMem {
				b.dst = make([]byte, b.RLESize)
			} else {
				b.dst = make([]byte, maxBlockSize)
			}
		}
		b.dst = b.dst[:b.RLESize]
		// Fill the output with the single repeated byte.
		v := b.data[0]
		for i := range b.dst {
			b.dst[i] = v
		}
		hist.appendKeep(b.dst)
		return nil
	case blockTypeRaw:
		hist.appendKeep(b.data)
		return nil
	case blockTypeCompressed:
		// Decode directly onto the end of the history buffer, then swap
		// the buffers back so b.dst keeps its own storage.
		saved := b.dst
		b.dst = hist.b
		hist.b = nil
		err := b.decodeCompressed(hist)
		if debugDecoder {
			println("Decompressed to total", len(b.dst), "bytes, hash:", xxhash.Sum64(b.dst), "error:", err)
		}
		hist.b = b.dst
		b.dst = saved
		return err
	case blockTypeReserved:
		// Used for returning errors.
		return b.err
	default:
		panic("Invalid block type")
	}
}
312
// decodeCompressed will start decompressing a block.
// If no history is supplied the decoder will decode as much as possible
// before fetching from blockDec.history. The layout parsed here follows
// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
func (b *blockDec) decodeCompressed(hist *history) error {
	in := b.data
	delayedHistory := hist == nil

	if delayedHistory {
		// We must always grab history, even on error, so the frame's
		// history channel stays balanced.
		defer func() {
			if hist == nil {
				<-b.history
			}
		}()
	}
	// There must be at least one byte for Literals_Block_Type and one for Sequences_Section_Header
	if len(in) < 2 {
		return ErrBlockTooSmall
	}
	litType := literalsBlockType(in[0] & 3)
	var litRegenSize int
	var litCompSize int
	sizeFormat := (in[0] >> 2) & 3
	var fourStreams bool
	switch litType {
	case literalsBlockRaw, literalsBlockRLE:
		switch sizeFormat {
		case 0, 2:
			// Regenerated_Size uses 5 bits (0-31). Literals_Section_Header uses 1 byte.
			litRegenSize = int(in[0] >> 3)
			in = in[1:]
		case 1:
			// Regenerated_Size uses 12 bits (0-4095). Literals_Section_Header uses 2 bytes.
			litRegenSize = int(in[0]>>4) + (int(in[1]) << 4)
			in = in[2:]
		case 3:
			// Regenerated_Size uses 20 bits (0-1048575). Literals_Section_Header uses 3 bytes.
			if len(in) < 3 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			litRegenSize = int(in[0]>>4) + (int(in[1]) << 4) + (int(in[2]) << 12)
			in = in[3:]
		}
	case literalsBlockCompressed, literalsBlockTreeless:
		switch sizeFormat {
		case 0, 1:
			// Both Regenerated_Size and Compressed_Size use 10 bits (0-1023).
			if len(in) < 3 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12)
			litRegenSize = int(n & 1023)
			litCompSize = int(n >> 10)
			fourStreams = sizeFormat == 1
			in = in[3:]
		case 2:
			// 14 bits each for Regenerated_Size and Compressed_Size.
			fourStreams = true
			if len(in) < 4 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) + (uint64(in[3]) << 20)
			litRegenSize = int(n & 16383)
			litCompSize = int(n >> 14)
			in = in[4:]
		case 3:
			// 18 bits each for Regenerated_Size and Compressed_Size.
			fourStreams = true
			if len(in) < 5 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) + (uint64(in[3]) << 20) + (uint64(in[4]) << 28)
			litRegenSize = int(n & 262143)
			litCompSize = int(n >> 18)
			in = in[5:]
		}
	}
	if debugDecoder {
		println("literals type:", litType, "litRegenSize:", litRegenSize, "litCompSize:", litCompSize, "sizeFormat:", sizeFormat, "4X:", fourStreams)
	}
	var literals []byte
	var huff *huff0.Scratch
	switch litType {
	case literalsBlockRaw:
		if len(in) < litRegenSize {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litRegenSize)
			return ErrBlockTooSmall
		}
		literals = in[:litRegenSize]
		in = in[litRegenSize:]
		//printf("Found %d uncompressed literals\n", litRegenSize)
	case literalsBlockRLE:
		if len(in) < 1 {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", 1)
			return ErrBlockTooSmall
		}
		if cap(b.literalBuf) < litRegenSize {
			if b.lowMem {
				b.literalBuf = make([]byte, litRegenSize)
			} else {
				if litRegenSize > maxCompressedLiteralSize {
					// Exceptional
					b.literalBuf = make([]byte, litRegenSize)
				} else {
					b.literalBuf = make([]byte, litRegenSize, maxCompressedLiteralSize)

				}
			}
		}
		// Expand the single byte into litRegenSize literals.
		literals = b.literalBuf[:litRegenSize]
		v := in[0]
		for i := range literals {
			literals[i] = v
		}
		in = in[1:]
		if debugDecoder {
			printf("Found %d RLE compressed literals\n", litRegenSize)
		}
	case literalsBlockTreeless:
		if len(in) < litCompSize {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litCompSize)
			return ErrBlockTooSmall
		}
		// Store compressed literals, so we defer decoding until we get history.
		literals = in[:litCompSize]
		in = in[litCompSize:]
		if debugDecoder {
			printf("Found %d compressed literals\n", litCompSize)
		}
	case literalsBlockCompressed:
		if len(in) < litCompSize {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litCompSize)
			return ErrBlockTooSmall
		}
		literals = in[:litCompSize]
		in = in[litCompSize:]
		huff = huffDecoderPool.Get().(*huff0.Scratch)
		var err error
		// Ensure we have space to store it.
		if cap(b.literalBuf) < litRegenSize {
			if b.lowMem {
				b.literalBuf = make([]byte, 0, litRegenSize)
			} else {
				b.literalBuf = make([]byte, 0, maxCompressedLiteralSize)
			}
		}
		if huff == nil {
			huff = &huff0.Scratch{}
		}
		huff, literals, err = huff0.ReadTable(literals, huff)
		if err != nil {
			println("reading huffman table:", err)
			return err
		}
		// Use our out buffer.
		if fourStreams {
			literals, err = huff.Decoder().Decompress4X(b.literalBuf[:0:litRegenSize], literals)
		} else {
			literals, err = huff.Decoder().Decompress1X(b.literalBuf[:0:litRegenSize], literals)
		}
		if err != nil {
			println("decoding compressed literals:", err)
			return err
		}
		// Make sure we don't leak our literals buffer
		if len(literals) != litRegenSize {
			return fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals))
		}
		if debugDecoder {
			printf("Decompressed %d literals into %d bytes\n", litCompSize, litRegenSize)
		}
	}

	// Decode Sequences
	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#sequences-section
	if len(in) < 1 {
		return ErrBlockTooSmall
	}
	seqHeader := in[0]
	nSeqs := 0
	switch {
	case seqHeader == 0:
		// No sequences in this block.
		in = in[1:]
	case seqHeader < 128:
		nSeqs = int(seqHeader)
		in = in[1:]
	case seqHeader < 255:
		if len(in) < 2 {
			return ErrBlockTooSmall
		}
		nSeqs = int(seqHeader-128)<<8 | int(in[1])
		in = in[2:]
	case seqHeader == 255:
		if len(in) < 3 {
			return ErrBlockTooSmall
		}
		nSeqs = 0x7f00 + int(in[1]) + (int(in[2]) << 8)
		in = in[3:]
	}

	var seqs = &sequenceDecs{}
	if nSeqs > 0 {
		if len(in) < 1 {
			return ErrBlockTooSmall
		}
		br := byteReader{b: in, off: 0}
		// One byte holds the compression mode (2 bits each) for the
		// literal length, offset and match length tables.
		compMode := br.Uint8()
		br.advance(1)
		if debugDecoder {
			printf("Compression modes: 0b%b", compMode)
		}
		for i := uint(0); i < 3; i++ {
			mode := seqCompMode((compMode >> (6 - i*2)) & 3)
			if debugDecoder {
				println("Table", tableIndex(i), "is", mode)
			}
			var seq *sequenceDec
			switch tableIndex(i) {
			case tableLiteralLengths:
				seq = &seqs.litLengths
			case tableOffsets:
				seq = &seqs.offsets
			case tableMatchLengths:
				seq = &seqs.matchLengths
			default:
				panic("unknown table")
			}
			switch mode {
			case compModePredefined:
				// Use the predefined table from the specification.
				seq.fse = &fsePredef[i]
			case compModeRLE:
				// A single symbol is repeated for every sequence.
				if br.remain() < 1 {
					return ErrBlockTooSmall
				}
				v := br.Uint8()
				br.advance(1)
				dec := fseDecoderPool.Get().(*fseDecoder)
				symb, err := decSymbolValue(v, symbolTableX[i])
				if err != nil {
					printf("RLE Transform table (%v) error: %v", tableIndex(i), err)
					return err
				}
				dec.setRLE(symb)
				seq.fse = dec
				if debugDecoder {
					printf("RLE set to %+v, code: %v", symb, v)
				}
			case compModeFSE:
				// Read a full FSE table from the stream.
				println("Reading table for", tableIndex(i))
				dec := fseDecoderPool.Get().(*fseDecoder)
				err := dec.readNCount(&br, uint16(maxTableSymbol[i]))
				if err != nil {
					println("Read table error:", err)
					return err
				}
				err = dec.transform(symbolTableX[i])
				if err != nil {
					println("Transform table error:", err)
					return err
				}
				if debugDecoder {
					println("Read table ok", "symbolLen:", dec.symbolLen)
				}
				seq.fse = dec
			case compModeRepeat:
				// Reuse the table from the previous block (merged below).
				seq.repeat = true
			}
			if br.overread() {
				return io.ErrUnexpectedEOF
			}
		}
		in = br.unread()
	}

	// Wait for history.
	// All time spent after this is critical since it is strictly sequential.
	if hist == nil {
		hist = <-b.history
		if hist.error {
			return ErrDecoderClosed
		}
	}

	// Decode treeless literal block.
	if litType == literalsBlockTreeless {
		// TODO: We could send the history early WITHOUT the stream history.
		// This would allow decoding treeless literals before the byte history is available.
		// Silencia stats: Treeless 4393, with: 32775, total: 37168, 11% treeless.
		// So not much obvious gain here.

		if hist.huffTree == nil {
			return errors.New("literal block was treeless, but no history was defined")
		}
		// Ensure we have space to store it.
		if cap(b.literalBuf) < litRegenSize {
			if b.lowMem {
				b.literalBuf = make([]byte, 0, litRegenSize)
			} else {
				b.literalBuf = make([]byte, 0, maxCompressedLiteralSize)
			}
		}
		var err error
		// Use our out buffer.
		huff = hist.huffTree
		if fourStreams {
			literals, err = huff.Decoder().Decompress4X(b.literalBuf[:0:litRegenSize], literals)
		} else {
			literals, err = huff.Decoder().Decompress1X(b.literalBuf[:0:litRegenSize], literals)
		}
		// Make sure we don't leak our literals buffer
		if err != nil {
			println("decompressing literals:", err)
			return err
		}
		if len(literals) != litRegenSize {
			return fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals))
		}
	} else {
		// Return an unused history table to the pool, unless it belongs
		// to a dictionary.
		if hist.huffTree != nil && huff != nil {
			if hist.dict == nil || hist.dict.litEnc != hist.huffTree {
				huffDecoderPool.Put(hist.huffTree)
			}
			hist.huffTree = nil
		}
	}
	if huff != nil {
		// Keep the table for a possible treeless block later.
		hist.huffTree = huff
	}
	if debugDecoder {
		println("Final literals:", len(literals), "hash:", xxhash.Sum64(literals), "and", nSeqs, "sequences.")
	}

	if nSeqs == 0 {
		// Decompressed content is defined entirely as Literals Section content.
		b.dst = append(b.dst, literals...)
		if delayedHistory {
			hist.append(literals)
		}
		return nil
	}

	// Merge repeat tables from the previous block into the new decoders.
	seqs, err := seqs.mergeHistory(&hist.decoders)
	if err != nil {
		return err
	}
	if debugDecoder {
		println("History merged ok")
	}
	br := &bitReader{}
	if err := br.init(in); err != nil {
		return err
	}

	// TODO: Investigate if sending history without decoders are faster.
	// This would allow the sequences to be decoded async and only have to construct stream history.
	// If only recent offsets were not transferred, this would be an obvious win.
	// Also, if first 3 sequences don't reference recent offsets, all sequences can be decoded.

	hbytes := hist.b
	if len(hbytes) > hist.windowSize {
		// Trim history to the window size before executing sequences.
		hbytes = hbytes[len(hbytes)-hist.windowSize:]
		// We do not need history any more.
		if hist.dict != nil {
			hist.dict.content = nil
		}
	}

	if err := seqs.initialize(br, hist, literals, b.dst); err != nil {
		println("initializing sequences:", err)
		return err
	}

	err = seqs.decode(nSeqs, br, hbytes)
	if err != nil {
		return err
	}
	if !br.finished() {
		return fmt.Errorf("%d extra bits on block, should be 0", br.remain())
	}

	err = br.close()
	if err != nil {
		printf("Closing sequences: %v, %+v\n", err, *br)
	}
	if len(b.data) > maxCompressedBlockSize {
		return fmt.Errorf("compressed block size too large (%d)", len(b.data))
	}
	// Set output and release references.
	b.dst = seqs.out
	seqs.out, seqs.literals, seqs.hist = nil, nil, nil

	if !delayedHistory {
		// If we don't have delayed history, no need to update.
		hist.recentOffsets = seqs.prevOffset
		return nil
	}
	if b.Last {
		// if last block we don't care about history.
		println("Last block, no history returned")
		hist.b = hist.b[:0]
		return nil
	}
	hist.append(b.dst)
	hist.recentOffsets = seqs.prevOffset
	if debugDecoder {
		println("Finished block with literals:", len(literals), "and", nSeqs, "sequences.")
	}

	return nil
}