blob: ed670bcc7ad475e681992809a2516d41054fd41b [file] [log] [blame]
Scott Bakered4efab2020-01-13 19:12:25 -08001// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8 "errors"
9 "fmt"
10 "io"
11 "sync"
12
13 "github.com/klauspost/compress/huff0"
14 "github.com/klauspost/compress/zstd/internal/xxhash"
15)
16
// blockType is the 2-bit Block_Type field of a zstd block header.
type blockType uint8

//go:generate stringer -type=blockType,literalsBlockType,seqCompMode,tableIndex

const (
	blockTypeRaw        blockType = iota // block stored uncompressed; copied verbatim
	blockTypeRLE                         // a single byte repeated RLESize times
	blockTypeCompressed                  // zstd-compressed block content
	blockTypeReserved                    // invalid on the wire; also used internally to carry errors (see sendErr)
)
27
// literalsBlockType is the 2-bit Literals_Block_Type field of a
// literals section header within a compressed block.
type literalsBlockType uint8

const (
	literalsBlockRaw        literalsBlockType = iota // literals stored as-is
	literalsBlockRLE                                 // one byte repeated Regenerated_Size times
	literalsBlockCompressed                          // Huffman compressed; table included in the block
	literalsBlockTreeless                            // Huffman compressed; reuses the previous block's table
)
36
const (
	// maxCompressedBlockSize is the biggest allowed compressed block size (128KB)
	maxCompressedBlockSize = 128 << 10

	// Maximum possible block size (all Raw+Uncompressed).
	maxBlockSize = (1 << 21) - 1

	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#literals_section_header
	maxCompressedLiteralSize = 1 << 18
	maxRLELiteralSize        = 1 << 20
	// maxMatchLen is the longest match a single sequence may produce.
	maxMatchLen = 131074
	// maxSequences is the largest sequence count a block header can encode
	// (0x7f00 + a 16-bit extension; see the sequences header decoding below).
	maxSequences = 0x7f00 + 0xffff

	// We support slightly less than the reference decoder to be able to
	// use ints on 32 bit archs.
	maxOffsetBits = 30
)
54
var (
	// huffDecoderPool recycles huff0 scratch structures (tables and
	// buffers) across blocks to reduce allocations.
	huffDecoderPool = sync.Pool{New: func() interface{} {
		return &huff0.Scratch{}
	}}

	// fseDecoderPool recycles FSE decoder tables across blocks.
	fseDecoderPool = sync.Pool{New: func() interface{} {
		return &fseDecoder{}
	}}
)
64
// blockDec holds the state for decoding a single zstd block.
// A blockDec can decode asynchronously: input signals that a block is ready,
// history delivers the stream history, and result carries the decoded output
// (see startDecoder), or synchronously via decodeBuf.
type blockDec struct {
	// Raw source data of the block.
	data        []byte
	dataStorage []byte

	// Destination of the decoded data.
	dst []byte

	// Buffer for literals data.
	literalBuf []byte

	// Window size of the block.
	WindowSize uint64
	Type       blockType
	RLESize    uint32

	// Is this the last block of a frame?
	Last bool

	// Use less memory
	lowMem bool
	// history receives the stream history for async decoding (startDecoder).
	history chan *history
	// input signals that data/Type are set and the block should be decoded.
	input chan struct{}
	// result delivers the decoded output from startDecoder.
	result chan decodeOutput
	// sequenceBuf is reusable scratch space for decoded sequences.
	sequenceBuf []seq
	// tmp is small scratch space; its use is not visible in this file.
	tmp [4]byte
	// err is reported via a blockTypeReserved block (see sendErr).
	err error
	// decWG tracks the startDecoder goroutine so Close can wait for it.
	decWG sync.WaitGroup
}
94
95func (b *blockDec) String() string {
96 if b == nil {
97 return "<nil>"
98 }
99 return fmt.Sprintf("Steam Size: %d, Type: %v, Last: %t, Window: %d", len(b.data), b.Type, b.Last, b.WindowSize)
100}
101
102func newBlockDec(lowMem bool) *blockDec {
103 b := blockDec{
104 lowMem: lowMem,
105 result: make(chan decodeOutput, 1),
106 input: make(chan struct{}, 1),
107 history: make(chan *history, 1),
108 }
109 b.decWG.Add(1)
110 go b.startDecoder()
111 return &b
112}
113
114// reset will reset the block.
115// Input must be a start of a block and will be at the end of the block when returned.
// reset will reset the block.
// Input must be a start of a block and will be at the end of the block when returned.
func (b *blockDec) reset(br byteBuffer, windowSize uint64) error {
	b.WindowSize = windowSize
	// Block header is 3 bytes: Last_Block (1 bit), Block_Type (2 bits),
	// Block_Size (21 bits), little-endian.
	tmp := br.readSmall(3)
	if tmp == nil {
		if debug {
			println("Reading block header:", io.ErrUnexpectedEOF)
		}
		return io.ErrUnexpectedEOF
	}
	bh := uint32(tmp[0]) | (uint32(tmp[1]) << 8) | (uint32(tmp[2]) << 16)
	b.Last = bh&1 != 0
	b.Type = blockType((bh >> 1) & 3)
	// find size.
	cSize := int(bh >> 3)
	switch b.Type {
	case blockTypeReserved:
		return ErrReservedBlockType
	case blockTypeRLE:
		// For RLE blocks the size field is the regenerated size;
		// only a single byte of payload follows the header.
		b.RLESize = uint32(cSize)
		cSize = 1
	case blockTypeCompressed:
		if debug {
			println("Data size on stream:", cSize)
		}
		b.RLESize = 0
		// A compressed block may not exceed the format limit nor the
		// frame's window size.
		if cSize > maxCompressedBlockSize || uint64(cSize) > b.WindowSize {
			if debug {
				printf("compressed block too big: csize:%d block: %+v\n", uint64(cSize), b)
			}
			return ErrCompressedSizeTooBig
		}
	default:
		b.RLESize = 0
	}

	// Read block data.
	// Grow the reusable storage only when needed; non-lowMem mode
	// allocates the maximum once so later blocks never reallocate.
	if cap(b.dataStorage) < cSize {
		if b.lowMem {
			b.dataStorage = make([]byte, 0, cSize)
		} else {
			b.dataStorage = make([]byte, 0, maxBlockSize)
		}
	}
	if cap(b.dst) <= maxBlockSize {
		b.dst = make([]byte, 0, maxBlockSize+1)
	}
	var err error
	b.data, err = br.readBig(cSize, b.dataStorage)
	if err != nil {
		if debug {
			println("Reading block:", err, "(", cSize, ")", len(b.data))
			printf("%T", br)
		}
		return err
	}
	return nil
}
173
174// sendEOF will make the decoder send EOF on this frame.
// sendErr makes the decoder emit err as the result of this block.
// The block is flagged as the last of the frame and typed as reserved,
// which startDecoder translates into an error result.
// (Previous comment said "sendEOF"; the function name is sendErr.)
func (b *blockDec) sendErr(err error) {
	b.Last = true
	b.Type = blockTypeReserved
	b.err = err
	b.input <- struct{}{}
}
181
182// Close will release resources.
183// Closed blockDec cannot be reset.
// Close will release resources.
// Closed blockDec cannot be reset.
func (b *blockDec) Close() {
	// Closing input terminates the range loop in startDecoder.
	close(b.input)
	close(b.history)
	close(b.result)
	// Wait for the decoder goroutine to exit.
	b.decWG.Wait()
}
190
191// decodeAsync will prepare decoding the block when it receives input.
192// This will separate output and history.
// startDecoder runs as a goroutine (started by newBlockDec) and decodes
// blocks as they are signaled on b.input. For each block it consumes the
// stream history from b.history, decodes according to b.Type and sends the
// outcome on b.result. It exits when b.input is closed (see Close).
// This separates output and history handling from the caller.
func (b *blockDec) startDecoder() {
	defer b.decWG.Done()
	for range b.input {
		//println("blockDec: Got block input")
		switch b.Type {
		case blockTypeRLE:
			// Ensure dst can hold the regenerated size, then fill it
			// with the single repeated byte.
			if cap(b.dst) < int(b.RLESize) {
				if b.lowMem {
					b.dst = make([]byte, b.RLESize)
				} else {
					b.dst = make([]byte, maxBlockSize)
				}
			}
			o := decodeOutput{
				d:   b,
				b:   b.dst[:b.RLESize],
				err: nil,
			}
			v := b.data[0]
			for i := range o.b {
				o.b[i] = v
			}
			hist := <-b.history
			hist.append(o.b)
			b.result <- o
		case blockTypeRaw:
			// Raw blocks are passed through without copying.
			o := decodeOutput{
				d:   b,
				b:   b.data,
				err: nil,
			}
			hist := <-b.history
			hist.append(o.b)
			b.result <- o
		case blockTypeCompressed:
			b.dst = b.dst[:0]
			// decodeCompressed(nil) fetches history from b.history itself.
			err := b.decodeCompressed(nil)
			o := decodeOutput{
				d:   b,
				b:   b.dst,
				err: err,
			}
			if debug {
				println("Decompressed to", len(b.dst), "bytes, error:", err)
			}
			b.result <- o
		case blockTypeReserved:
			// Used for returning errors.
			<-b.history
			b.result <- decodeOutput{
				d:   b,
				b:   nil,
				err: b.err,
			}
		default:
			panic("Invalid block type")
		}
		if debug {
			println("blockDec: Finished block")
		}
	}
}
255
256// decodeAsync will prepare decoding the block when it receives the history.
257// If history is provided, it will not fetch it from the channel.
// decodeBuf decodes the block synchronously, appending the output to the
// supplied history. Unlike startDecoder it never fetches history from the
// channel; hist must be provided by the caller.
func (b *blockDec) decodeBuf(hist *history) error {
	switch b.Type {
	case blockTypeRLE:
		if cap(b.dst) < int(b.RLESize) {
			if b.lowMem {
				b.dst = make([]byte, b.RLESize)
			} else {
				b.dst = make([]byte, maxBlockSize)
			}
		}
		b.dst = b.dst[:b.RLESize]
		v := b.data[0]
		for i := range b.dst {
			b.dst[i] = v
		}
		hist.appendKeep(b.dst)
		return nil
	case blockTypeRaw:
		hist.appendKeep(b.data)
		return nil
	case blockTypeCompressed:
		// Temporarily swap the history buffer into b.dst so
		// decodeCompressed appends directly onto the history,
		// then swap back afterwards.
		saved := b.dst
		b.dst = hist.b
		hist.b = nil
		err := b.decodeCompressed(hist)
		if debug {
			println("Decompressed to total", len(b.dst), "bytes, hash:", xxhash.Sum64(b.dst), "error:", err)
		}
		hist.b = b.dst
		b.dst = saved
		return err
	case blockTypeReserved:
		// Used for returning errors.
		return b.err
	default:
		panic("Invalid block type")
	}
}
296
297// decodeCompressed will start decompressing a block.
298// If no history is supplied the decoder will decodeAsync as much as possible
299// before fetching from blockDec.history
// decodeCompressed will start decompressing a block.
// If no history is supplied the decoder will decodeAsync as much as possible
// before fetching from blockDec.history.
// Layout decoded here: literals section (header, literals payload), then the
// sequences section (count, compression modes, FSE tables, bitstream).
func (b *blockDec) decodeCompressed(hist *history) error {
	in := b.data
	delayedHistory := hist == nil

	if delayedHistory {
		// We must always grab history.
		defer func() {
			if hist == nil {
				<-b.history
			}
		}()
	}
	// There must be at least one byte for Literals_Block_Type and one for Sequences_Section_Header
	if len(in) < 2 {
		return ErrBlockTooSmall
	}
	litType := literalsBlockType(in[0] & 3)
	var litRegenSize int
	var litCompSize int
	sizeFormat := (in[0] >> 2) & 3
	var fourStreams bool
	switch litType {
	case literalsBlockRaw, literalsBlockRLE:
		switch sizeFormat {
		case 0, 2:
			// Regenerated_Size uses 5 bits (0-31). Literals_Section_Header uses 1 byte.
			litRegenSize = int(in[0] >> 3)
			in = in[1:]
		case 1:
			// Regenerated_Size uses 12 bits (0-4095). Literals_Section_Header uses 2 bytes.
			litRegenSize = int(in[0]>>4) + (int(in[1]) << 4)
			in = in[2:]
		case 3:
			// Regenerated_Size uses 20 bits (0-1048575). Literals_Section_Header uses 3 bytes.
			if len(in) < 3 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			litRegenSize = int(in[0]>>4) + (int(in[1]) << 4) + (int(in[2]) << 12)
			in = in[3:]
		}
	case literalsBlockCompressed, literalsBlockTreeless:
		switch sizeFormat {
		case 0, 1:
			// Both Regenerated_Size and Compressed_Size use 10 bits (0-1023).
			if len(in) < 3 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12)
			litRegenSize = int(n & 1023)
			litCompSize = int(n >> 10)
			fourStreams = sizeFormat == 1
			in = in[3:]
		case 2:
			// 14-bit Regenerated_Size and Compressed_Size; always four streams.
			fourStreams = true
			if len(in) < 4 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) + (uint64(in[3]) << 20)
			litRegenSize = int(n & 16383)
			litCompSize = int(n >> 14)
			in = in[4:]
		case 3:
			// 18-bit Regenerated_Size and Compressed_Size; always four streams.
			fourStreams = true
			if len(in) < 5 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) + (uint64(in[3]) << 20) + (uint64(in[4]) << 28)
			litRegenSize = int(n & 262143)
			litCompSize = int(n >> 18)
			in = in[5:]
		}
	}
	if debug {
		println("literals type:", litType, "litRegenSize:", litRegenSize, "litCompSize:", litCompSize, "sizeFormat:", sizeFormat, "4X:", fourStreams)
	}
	var literals []byte
	var huff *huff0.Scratch
	switch litType {
	case literalsBlockRaw:
		if len(in) < litRegenSize {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litRegenSize)
			return ErrBlockTooSmall
		}
		literals = in[:litRegenSize]
		in = in[litRegenSize:]
		//printf("Found %d uncompressed literals\n", litRegenSize)
	case literalsBlockRLE:
		if len(in) < 1 {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", 1)
			return ErrBlockTooSmall
		}
		if cap(b.literalBuf) < litRegenSize {
			if b.lowMem {
				b.literalBuf = make([]byte, litRegenSize)
			} else {
				if litRegenSize > maxCompressedLiteralSize {
					// Exceptional
					b.literalBuf = make([]byte, litRegenSize)
				} else {
					b.literalBuf = make([]byte, litRegenSize, maxCompressedLiteralSize)

				}
			}
		}
		// Expand the single byte into the literals buffer.
		literals = b.literalBuf[:litRegenSize]
		v := in[0]
		for i := range literals {
			literals[i] = v
		}
		in = in[1:]
		if debug {
			printf("Found %d RLE compressed literals\n", litRegenSize)
		}
	case literalsBlockTreeless:
		if len(in) < litCompSize {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litCompSize)
			return ErrBlockTooSmall
		}
		// Store compressed literals, so we defer decoding until we get history.
		literals = in[:litCompSize]
		in = in[litCompSize:]
		if debug {
			printf("Found %d compressed literals\n", litCompSize)
		}
	case literalsBlockCompressed:
		if len(in) < litCompSize {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litCompSize)
			return ErrBlockTooSmall
		}
		literals = in[:litCompSize]
		in = in[litCompSize:]
		huff = huffDecoderPool.Get().(*huff0.Scratch)
		var err error
		// Ensure we have space to store it.
		if cap(b.literalBuf) < litRegenSize {
			if b.lowMem {
				b.literalBuf = make([]byte, 0, litRegenSize)
			} else {
				b.literalBuf = make([]byte, 0, maxCompressedLiteralSize)
			}
		}
		if huff == nil {
			huff = &huff0.Scratch{}
		}
		huff.Out = b.literalBuf[:0]
		// Read the Huffman table from the start of the literals payload.
		huff, literals, err = huff0.ReadTable(literals, huff)
		if err != nil {
			println("reading huffman table:", err)
			return err
		}
		// Use our out buffer.
		huff.Out = b.literalBuf[:0]
		huff.MaxDecodedSize = litRegenSize
		if fourStreams {
			literals, err = huff.Decompress4X(literals, litRegenSize)
		} else {
			literals, err = huff.Decompress1X(literals)
		}
		if err != nil {
			println("decoding compressed literals:", err)
			return err
		}
		// Make sure we don't leak our literals buffer
		huff.Out = nil
		if len(literals) != litRegenSize {
			return fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals))
		}
		if debug {
			printf("Decompressed %d literals into %d bytes\n", litCompSize, litRegenSize)
		}
	}

	// Decode Sequences
	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#sequences-section
	if len(in) < 1 {
		return ErrBlockTooSmall
	}
	seqHeader := in[0]
	nSeqs := 0
	switch {
	case seqHeader == 0:
		in = in[1:]
	case seqHeader < 128:
		nSeqs = int(seqHeader)
		in = in[1:]
	case seqHeader < 255:
		if len(in) < 2 {
			return ErrBlockTooSmall
		}
		nSeqs = int(seqHeader-128)<<8 | int(in[1])
		in = in[2:]
	case seqHeader == 255:
		if len(in) < 3 {
			return ErrBlockTooSmall
		}
		nSeqs = 0x7f00 + int(in[1]) + (int(in[2]) << 8)
		in = in[3:]
	}
	// Allocate sequences
	if cap(b.sequenceBuf) < nSeqs {
		if b.lowMem {
			b.sequenceBuf = make([]seq, nSeqs)
		} else {
			// Allocate max
			b.sequenceBuf = make([]seq, nSeqs, maxSequences)
		}
	} else {
		// Reuse buffer
		b.sequenceBuf = b.sequenceBuf[:nSeqs]
	}
	var seqs = &sequenceDecs{}
	if nSeqs > 0 {
		if len(in) < 1 {
			return ErrBlockTooSmall
		}
		br := byteReader{b: in, off: 0}
		compMode := br.Uint8()
		br.advance(1)
		if debug {
			printf("Compression modes: 0b%b", compMode)
		}
		// One 2-bit mode per table: literal lengths, offsets, match lengths.
		for i := uint(0); i < 3; i++ {
			mode := seqCompMode((compMode >> (6 - i*2)) & 3)
			if debug {
				println("Table", tableIndex(i), "is", mode)
			}
			var seq *sequenceDec
			switch tableIndex(i) {
			case tableLiteralLengths:
				seq = &seqs.litLengths
			case tableOffsets:
				seq = &seqs.offsets
			case tableMatchLengths:
				seq = &seqs.matchLengths
			default:
				panic("unknown table")
			}
			switch mode {
			case compModePredefined:
				// Use the format-defined default distribution.
				seq.fse = &fsePredef[i]
			case compModeRLE:
				if br.remain() < 1 {
					return ErrBlockTooSmall
				}
				v := br.Uint8()
				br.advance(1)
				dec := fseDecoderPool.Get().(*fseDecoder)
				symb, err := decSymbolValue(v, symbolTableX[i])
				if err != nil {
					printf("RLE Transform table (%v) error: %v", tableIndex(i), err)
					return err
				}
				dec.setRLE(symb)
				seq.fse = dec
				if debug {
					printf("RLE set to %+v, code: %v", symb, v)
				}
			case compModeFSE:
				// A full FSE table follows in the stream.
				println("Reading table for", tableIndex(i))
				dec := fseDecoderPool.Get().(*fseDecoder)
				err := dec.readNCount(&br, uint16(maxTableSymbol[i]))
				if err != nil {
					println("Read table error:", err)
					return err
				}
				err = dec.transform(symbolTableX[i])
				if err != nil {
					println("Transform table error:", err)
					return err
				}
				if debug {
					println("Read table ok", "symbolLen:", dec.symbolLen)
				}
				seq.fse = dec
			case compModeRepeat:
				// Reuse the table from the previous block (merged below).
				seq.repeat = true
			}
			if br.overread() {
				return io.ErrUnexpectedEOF
			}
		}
		in = br.unread()
	}

	// Wait for history.
	// All time spent after this is critical since it is strictly sequential.
	if hist == nil {
		hist = <-b.history
		if hist.error {
			return ErrDecoderClosed
		}
	}

	// Decode treeless literal block.
	if litType == literalsBlockTreeless {
		// TODO: We could send the history early WITHOUT the stream history.
		// This would allow decoding treeless literials before the byte history is available.
		// Silencia stats: Treeless 4393, with: 32775, total: 37168, 11% treeless.
		// So not much obvious gain here.

		if hist.huffTree == nil {
			return errors.New("literal block was treeless, but no history was defined")
		}
		// Ensure we have space to store it.
		if cap(b.literalBuf) < litRegenSize {
			if b.lowMem {
				b.literalBuf = make([]byte, 0, litRegenSize)
			} else {
				b.literalBuf = make([]byte, 0, maxCompressedLiteralSize)
			}
		}
		var err error
		// Use our out buffer.
		huff = hist.huffTree
		huff.Out = b.literalBuf[:0]
		huff.MaxDecodedSize = litRegenSize
		if fourStreams {
			literals, err = huff.Decompress4X(literals, litRegenSize)
		} else {
			literals, err = huff.Decompress1X(literals)
		}
		// Make sure we don't leak our literals buffer
		huff.Out = nil
		if err != nil {
			println("decompressing literals:", err)
			return err
		}
		if len(literals) != litRegenSize {
			return fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals))
		}
	} else {
		// Not treeless: a new table (if any) supersedes the old one,
		// so return the previous table to the pool.
		if hist.huffTree != nil && huff != nil {
			huffDecoderPool.Put(hist.huffTree)
			hist.huffTree = nil
		}
	}
	// Keep the current table in the history for possible treeless reuse.
	if huff != nil {
		huff.Out = nil
		hist.huffTree = huff
	}
	if debug {
		println("Final literals:", len(literals), "hash:", xxhash.Sum64(literals), "and", nSeqs, "sequences.")
	}

	if nSeqs == 0 {
		// Decompressed content is defined entirely as Literals Section content.
		b.dst = append(b.dst, literals...)
		if delayedHistory {
			hist.append(literals)
		}
		return nil
	}

	// Merge with the history's decoders so compModeRepeat tables resolve.
	seqs, err := seqs.mergeHistory(&hist.decoders)
	if err != nil {
		return err
	}
	if debug {
		println("History merged ok")
	}
	br := &bitReader{}
	if err := br.init(in); err != nil {
		return err
	}

	// TODO: Investigate if sending history without decoders are faster.
	// This would allow the sequences to be decoded async and only have to construct stream history.
	// If only recent offsets were not transferred, this would be an obvious win.
	// Also, if first 3 sequences don't reference recent offsets, all sequences can be decoded.

	if err := seqs.initialize(br, hist, literals, b.dst); err != nil {
		println("initializing sequences:", err)
		return err
	}

	err = seqs.decode(nSeqs, br, hist.b)
	if err != nil {
		return err
	}
	if !br.finished() {
		return fmt.Errorf("%d extra bits on block, should be 0", br.remain())
	}

	err = br.close()
	if err != nil {
		printf("Closing sequences: %v, %+v\n", err, *br)
	}
	if len(b.data) > maxCompressedBlockSize {
		return fmt.Errorf("compressed block size too large (%d)", len(b.data))
	}
	// Set output and release references.
	b.dst = seqs.out
	seqs.out, seqs.literals, seqs.hist = nil, nil, nil

	if !delayedHistory {
		// If we don't have delayed history, no need to update.
		hist.recentOffsets = seqs.prevOffset
		return nil
	}
	if b.Last {
		// if last block we don't care about history.
		println("Last block, no history returned")
		hist.b = hist.b[:0]
		return nil
	}
	hist.append(b.dst)
	hist.recentOffsets = seqs.prevOffset
	if debug {
		println("Finished block with literals:", len(literals), "and", nSeqs, "sequences.")
	}

	return nil
}