blob: 3e161ea15e2654dfe828f8ebd86c44a8a9b7d9c4 [file] [log] [blame]
Dinesh Belwalkare63f7f92019-11-22 23:11:16 +00001// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8 "errors"
9 "fmt"
10 "io"
11 "sync"
12
13 "github.com/klauspost/compress/huff0"
14)
15
// blockType is the 2-bit Block_Type field from a frame block header.
type blockType uint8

//go:generate stringer -type=blockType,literalsBlockType,seqCompMode,tableIndex

// Block types, in the order defined by the zstd format
// (blockTypeReserved is invalid on the wire and reused
// internally to signal errors, see sendErr).
const (
	blockTypeRaw blockType = iota
	blockTypeRLE
	blockTypeCompressed
	blockTypeReserved
)
26
// literalsBlockType is the 2-bit Literals_Block_Type field from the
// literals section header of a compressed block.
type literalsBlockType uint8

// Literals block types, in the order defined by the zstd format.
const (
	literalsBlockRaw literalsBlockType = iota
	literalsBlockRLE
	literalsBlockCompressed
	literalsBlockTreeless
)
35
const (
	// maxCompressedBlockSize is the biggest allowed compressed block size (128KB)
	maxCompressedBlockSize = 128 << 10

	// Maximum possible block size (all Raw+Uncompressed).
	maxBlockSize = (1 << 21) - 1

	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#literals_section_header
	maxCompressedLiteralSize = 1 << 18
	// Maximum regenerated size accepted for an RLE literals section.
	maxRLELiteralSize = 1 << 20
	// Maximum match length.
	maxMatchLen = 131074
	// Maximum number of sequences in a block; the sum of the two
	// largest sequence-count header encodings (see decodeCompressed).
	maxSequences = 0x7f00 + 0xffff

	// We support slightly less than the reference decoder to be able to
	// use ints on 32 bit archs.
	maxOffsetBits = 30
)
53
var (
	// huffDecoderPool recycles Huffman scratch buffers between blocks.
	huffDecoderPool = sync.Pool{New: func() interface{} {
		return &huff0.Scratch{}
	}}

	// fseDecoderPool recycles FSE table decoders between blocks.
	fseDecoderPool = sync.Pool{New: func() interface{} {
		return &fseDecoder{}
	}}
)
63
// blockDec decodes a single frame block.
// Each blockDec owns a goroutine (startDecoder) that is signaled via the
// input channel, consumes history from the history channel and delivers
// its output on the result channel.
type blockDec struct {
	// Raw source data of the block.
	data        []byte
	dataStorage []byte

	// Destination of the decoded data.
	dst []byte

	// Buffer for literals data.
	literalBuf []byte

	// Window size of the block.
	WindowSize uint64
	Type       blockType
	RLESize    uint32

	// Is this the last block of a frame?
	Last bool

	// Use less memory
	lowMem bool
	// history delivers the history this block must build on.
	history chan *history
	// input signals startDecoder that data/Type are ready to decode.
	input chan struct{}
	// result carries the decoded output (or error) from startDecoder.
	result chan decodeOutput
	// sequenceBuf is reusable scratch space for decoded sequences.
	sequenceBuf []seq
	tmp         [4]byte
	// err is reported for blockTypeReserved blocks (set by sendErr).
	err error
}
92
93func (b *blockDec) String() string {
94 if b == nil {
95 return "<nil>"
96 }
97 return fmt.Sprintf("Steam Size: %d, Type: %v, Last: %t, Window: %d", len(b.data), b.Type, b.Last, b.WindowSize)
98}
99
100func newBlockDec(lowMem bool) *blockDec {
101 b := blockDec{
102 lowMem: lowMem,
103 result: make(chan decodeOutput, 1),
104 input: make(chan struct{}, 1),
105 history: make(chan *history, 1),
106 }
107 go b.startDecoder()
108 return &b
109}
110
// reset will reset the block.
// Input must be a start of a block and will be at the end of the block when returned.
// It reads the 3-byte block header, validates the size against the window
// and maxCompressedBlockSize, and reads the block payload into b.data.
func (b *blockDec) reset(br byteBuffer, windowSize uint64) error {
	b.WindowSize = windowSize
	// Block header: Last_Block (1 bit), Block_Type (2 bits), Block_Size (21 bits).
	tmp := br.readSmall(3)
	if tmp == nil {
		if debug {
			println("Reading block header:", io.ErrUnexpectedEOF)
		}
		return io.ErrUnexpectedEOF
	}
	bh := uint32(tmp[0]) | (uint32(tmp[1]) << 8) | (uint32(tmp[2]) << 16)
	b.Last = bh&1 != 0
	b.Type = blockType((bh >> 1) & 3)
	// find size.
	cSize := int(bh >> 3)
	switch b.Type {
	case blockTypeReserved:
		return ErrReservedBlockType
	case blockTypeRLE:
		// For RLE blocks the header size field is the regenerated size;
		// the payload is a single byte.
		b.RLESize = uint32(cSize)
		cSize = 1
	case blockTypeCompressed:
		if debug {
			println("Data size on stream:", cSize)
		}
		b.RLESize = 0
		if cSize > maxCompressedBlockSize || uint64(cSize) > b.WindowSize {
			if debug {
				printf("compressed block too big: csize:%d block: %+v\n", uint64(cSize), b)
			}
			return ErrCompressedSizeTooBig
		}
	default:
		b.RLESize = 0
	}

	// Read block data.
	if cap(b.dataStorage) < cSize {
		if b.lowMem {
			b.dataStorage = make([]byte, 0, cSize)
		} else {
			// Allocate the maximum once so later blocks can reuse it.
			b.dataStorage = make([]byte, 0, maxBlockSize)
		}
	}
	if cap(b.dst) <= maxBlockSize {
		b.dst = make([]byte, 0, maxBlockSize+1)
	}
	var err error
	b.data, err = br.readBig(cSize, b.dataStorage)
	if err != nil {
		if debug {
			println("Reading block:", err)
		}
		return err
	}
	return nil
}
169
// sendErr will make the decoder report err for this block.
// The block is marked last and as the reserved type, which the
// decode paths use to return b.err (was misnamed "sendEOF" in the doc).
func (b *blockDec) sendErr(err error) {
	b.Last = true
	b.Type = blockTypeReserved
	b.err = err
	// Wake startDecoder so it emits the error on the result channel.
	b.input <- struct{}{}
}
177
// Close will release resources.
// Closed blockDec cannot be reset.
func (b *blockDec) Close() {
	// Closing input terminates the startDecoder goroutine's range loop.
	close(b.input)
	close(b.history)
	close(b.result)
}
185
// startDecoder runs the block's asynchronous decode loop: it waits for a
// signal on b.input, decodes the prepared block, consumes the history from
// b.history and delivers the output on b.result.
// The goroutine exits when b.input is closed (see Close).
// (Doc previously referred to this as "decodeAsync".)
func (b *blockDec) startDecoder() {
	for range b.input {
		//println("blockDec: Got block input")
		switch b.Type {
		case blockTypeRLE:
			if cap(b.dst) < int(b.RLESize) {
				if b.lowMem {
					b.dst = make([]byte, b.RLESize)
				} else {
					b.dst = make([]byte, maxBlockSize)
				}
			}
			o := decodeOutput{
				d:   b,
				b:   b.dst[:b.RLESize],
				err: nil,
			}
			// Fill the output with the single repeated byte.
			v := b.data[0]
			for i := range o.b {
				o.b[i] = v
			}
			hist := <-b.history
			hist.append(o.b)
			b.result <- o
		case blockTypeRaw:
			// Raw blocks are passed through unchanged.
			o := decodeOutput{
				d:   b,
				b:   b.data,
				err: nil,
			}
			hist := <-b.history
			hist.append(o.b)
			b.result <- o
		case blockTypeCompressed:
			b.dst = b.dst[:0]
			// decodeCompressed(nil) fetches the history itself when needed.
			err := b.decodeCompressed(nil)
			o := decodeOutput{
				d:   b,
				b:   b.dst,
				err: err,
			}
			if debug {
				println("Decompressed to", len(b.dst), "bytes, error:", err)
			}
			b.result <- o
		case blockTypeReserved:
			// Used for returning errors.
			<-b.history
			b.result <- decodeOutput{
				d:   b,
				b:   nil,
				err: b.err,
			}
		default:
			panic("Invalid block type")
		}
		if debug {
			println("blockDec: Finished block")
		}
	}
}
249
// decodeBuf decodes the block synchronously into the supplied history,
// bypassing the input/result channels.
// (Doc previously referred to this as "decodeAsync".)
func (b *blockDec) decodeBuf(hist *history) error {
	switch b.Type {
	case blockTypeRLE:
		if cap(b.dst) < int(b.RLESize) {
			if b.lowMem {
				b.dst = make([]byte, b.RLESize)
			} else {
				b.dst = make([]byte, maxBlockSize)
			}
		}
		b.dst = b.dst[:b.RLESize]
		// Fill the output with the single repeated byte.
		v := b.data[0]
		for i := range b.dst {
			b.dst[i] = v
		}
		hist.appendKeep(b.dst)
		return nil
	case blockTypeRaw:
		hist.appendKeep(b.data)
		return nil
	case blockTypeCompressed:
		// Temporarily swap hist.b into b.dst so decodeCompressed appends
		// directly onto the accumulated history; swap back afterwards.
		saved := b.dst
		b.dst = hist.b
		hist.b = nil
		err := b.decodeCompressed(hist)
		if debug {
			println("Decompressed to total", len(b.dst), "bytes, error:", err)
		}
		hist.b = b.dst
		b.dst = saved
		return err
	case blockTypeReserved:
		// Used for returning errors.
		return b.err
	default:
		panic("Invalid block type")
	}
}
290
// decodeCompressed will start decompressing a block.
// If no history is supplied the decoder will decode as much as possible
// before fetching history from blockDec.history.
// It parses the literals section, the sequences section header and tables,
// then executes the sequences against the history to produce b.dst.
func (b *blockDec) decodeCompressed(hist *history) error {
	in := b.data
	delayedHistory := hist == nil

	if delayedHistory {
		// We must always grab history.
		defer func() {
			if hist == nil {
				<-b.history
			}
		}()
	}
	// There must be at least one byte for Literals_Block_Type and one for Sequences_Section_Header
	if len(in) < 2 {
		return ErrBlockTooSmall
	}
	litType := literalsBlockType(in[0] & 3)
	var litRegenSize int
	var litCompSize int
	sizeFormat := (in[0] >> 2) & 3
	var fourStreams bool
	switch litType {
	case literalsBlockRaw, literalsBlockRLE:
		switch sizeFormat {
		case 0, 2:
			// Regenerated_Size uses 5 bits (0-31). Literals_Section_Header uses 1 byte.
			litRegenSize = int(in[0] >> 3)
			in = in[1:]
		case 1:
			// Regenerated_Size uses 12 bits (0-4095). Literals_Section_Header uses 2 bytes.
			litRegenSize = int(in[0]>>4) + (int(in[1]) << 4)
			in = in[2:]
		case 3:
			// Regenerated_Size uses 20 bits (0-1048575). Literals_Section_Header uses 3 bytes.
			if len(in) < 3 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			litRegenSize = int(in[0]>>4) + (int(in[1]) << 4) + (int(in[2]) << 12)
			in = in[3:]
		}
	case literalsBlockCompressed, literalsBlockTreeless:
		switch sizeFormat {
		case 0, 1:
			// Both Regenerated_Size and Compressed_Size use 10 bits (0-1023).
			if len(in) < 3 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12)
			litRegenSize = int(n & 1023)
			litCompSize = int(n >> 10)
			// sizeFormat 1 means 4 Huffman streams; 0 means a single stream.
			fourStreams = sizeFormat == 1
			in = in[3:]
		case 2:
			fourStreams = true
			if len(in) < 4 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			// Both sizes use 14 bits; header is 4 bytes.
			n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) + (uint64(in[3]) << 20)
			litRegenSize = int(n & 16383)
			litCompSize = int(n >> 14)
			in = in[4:]
		case 3:
			fourStreams = true
			if len(in) < 5 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			// Both sizes use 18 bits; header is 5 bytes.
			n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) + (uint64(in[3]) << 20) + (uint64(in[4]) << 28)
			litRegenSize = int(n & 262143)
			litCompSize = int(n >> 18)
			in = in[5:]
		}
	}
	if debug {
		println("literals type:", litType, "litRegenSize:", litRegenSize, "litCompSize", litCompSize)
	}
	var literals []byte
	var huff *huff0.Scratch
	switch litType {
	case literalsBlockRaw:
		if len(in) < litRegenSize {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litRegenSize)
			return ErrBlockTooSmall
		}
		literals = in[:litRegenSize]
		in = in[litRegenSize:]
		//printf("Found %d uncompressed literals\n", litRegenSize)
	case literalsBlockRLE:
		if len(in) < 1 {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", 1)
			return ErrBlockTooSmall
		}
		if cap(b.literalBuf) < litRegenSize {
			if b.lowMem {
				b.literalBuf = make([]byte, litRegenSize)
			} else {
				if litRegenSize > maxCompressedLiteralSize {
					// Exceptional
					b.literalBuf = make([]byte, litRegenSize)
				} else {
					b.literalBuf = make([]byte, litRegenSize, maxCompressedLiteralSize)

				}
			}
		}
		// Expand the single RLE byte to the regenerated size.
		literals = b.literalBuf[:litRegenSize]
		v := in[0]
		for i := range literals {
			literals[i] = v
		}
		in = in[1:]
		if debug {
			printf("Found %d RLE compressed literals\n", litRegenSize)
		}
	case literalsBlockTreeless:
		if len(in) < litCompSize {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litCompSize)
			return ErrBlockTooSmall
		}
		// Store compressed literals, so we defer decoding until we get history.
		literals = in[:litCompSize]
		in = in[litCompSize:]
		if debug {
			printf("Found %d compressed literals\n", litCompSize)
		}
	case literalsBlockCompressed:
		if len(in) < litCompSize {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litCompSize)
			return ErrBlockTooSmall
		}
		literals = in[:litCompSize]
		in = in[litCompSize:]

		huff = huffDecoderPool.Get().(*huff0.Scratch)
		var err error
		// Ensure we have space to store it.
		if cap(b.literalBuf) < litRegenSize {
			if b.lowMem {
				b.literalBuf = make([]byte, 0, litRegenSize)
			} else {
				b.literalBuf = make([]byte, 0, maxCompressedLiteralSize)
			}
		}
		if huff == nil {
			huff = &huff0.Scratch{}
		}
		huff.Out = b.literalBuf[:0]
		// Read the Huffman table; the remainder is the compressed streams.
		huff, literals, err = huff0.ReadTable(literals, huff)
		if err != nil {
			println("reading huffman table:", err)
			return err
		}
		// Use our out buffer.
		huff.Out = b.literalBuf[:0]
		huff.MaxDecodedSize = litRegenSize
		if fourStreams {
			literals, err = huff.Decompress4X(literals, litRegenSize)
		} else {
			literals, err = huff.Decompress1X(literals)
		}
		if err != nil {
			println("decoding compressed literals:", err)
			return err
		}
		// Make sure we don't leak our literals buffer
		huff.Out = nil
		if len(literals) != litRegenSize {
			return fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals))
		}
		if debug {
			printf("Decompressed %d literals into %d bytes\n", litCompSize, litRegenSize)
		}
	}

	// Decode Sequences
	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#sequences-section
	if len(in) < 1 {
		return ErrBlockTooSmall
	}
	seqHeader := in[0]
	nSeqs := 0
	switch {
	case seqHeader == 0:
		in = in[1:]
	case seqHeader < 128:
		// Single-byte count.
		nSeqs = int(seqHeader)
		in = in[1:]
	case seqHeader < 255:
		// Two-byte count: ((byte0-128) << 8) + byte1.
		if len(in) < 2 {
			return ErrBlockTooSmall
		}
		nSeqs = int(seqHeader-128)<<8 | int(in[1])
		in = in[2:]
	case seqHeader == 255:
		// Three-byte count: 0x7f00 + little-endian 16-bit value.
		if len(in) < 3 {
			return ErrBlockTooSmall
		}
		nSeqs = 0x7f00 + int(in[1]) + (int(in[2]) << 8)
		in = in[3:]
	}
	// Allocate sequences
	if cap(b.sequenceBuf) < nSeqs {
		if b.lowMem {
			b.sequenceBuf = make([]seq, nSeqs)
		} else {
			// Allocate max
			b.sequenceBuf = make([]seq, nSeqs, maxSequences)
		}
	} else {
		// Reuse buffer
		b.sequenceBuf = b.sequenceBuf[:nSeqs]
	}
	var seqs = &sequenceDecs{}
	if nSeqs > 0 {
		if len(in) < 1 {
			return ErrBlockTooSmall
		}
		br := byteReader{b: in, off: 0}
		compMode := br.Uint8()
		br.advance(1)
		if debug {
			printf("Compression modes: 0b%b", compMode)
		}
		// The three tables (literal lengths, offsets, match lengths) each
		// have a 2-bit compression mode packed into compMode, high bits first.
		for i := uint(0); i < 3; i++ {
			mode := seqCompMode((compMode >> (6 - i*2)) & 3)
			if debug {
				println("Table", tableIndex(i), "is", mode)
			}
			var seq *sequenceDec
			switch tableIndex(i) {
			case tableLiteralLengths:
				seq = &seqs.litLengths
			case tableOffsets:
				seq = &seqs.offsets
			case tableMatchLengths:
				seq = &seqs.matchLengths
			default:
				panic("unknown table")
			}
			switch mode {
			case compModePredefined:
				// Use the predefined FSE table from the spec.
				seq.fse = &fsePredef[i]
			case compModeRLE:
				// A single symbol repeats for all sequences.
				if br.remain() < 1 {
					return ErrBlockTooSmall
				}
				v := br.Uint8()
				br.advance(1)
				dec := fseDecoderPool.Get().(*fseDecoder)
				symb, err := decSymbolValue(v, symbolTableX[i])
				if err != nil {
					printf("RLE Transform table (%v) error: %v", tableIndex(i), err)
					return err
				}
				dec.setRLE(symb)
				seq.fse = dec
				if debug {
					printf("RLE set to %+v, code: %v", symb, v)
				}
			case compModeFSE:
				// Read a full FSE table from the stream.
				println("Reading table for", tableIndex(i))
				dec := fseDecoderPool.Get().(*fseDecoder)
				err := dec.readNCount(&br, uint16(maxTableSymbol[i]))
				if err != nil {
					println("Read table error:", err)
					return err
				}
				err = dec.transform(symbolTableX[i])
				if err != nil {
					println("Transform table error:", err)
					return err
				}
				if debug {
					println("Read table ok", "symbolLen:", dec.symbolLen)
				}
				seq.fse = dec
			case compModeRepeat:
				// Reuse the table from the previous block (merged below).
				seq.repeat = true
			}
			if br.overread() {
				return io.ErrUnexpectedEOF
			}
		}
		in = br.unread()
	}

	// Wait for history.
	// All time spent after this is critical since it is strictly sequential.
	if hist == nil {
		hist = <-b.history
		if hist.error {
			return ErrDecoderClosed
		}
	}

	// Decode treeless literal block.
	if litType == literalsBlockTreeless {
		// TODO: We could send the history early WITHOUT the stream history.
		// This would allow decoding treeless literals before the byte history is available.
		// Silesia stats: Treeless 4393, with: 32775, total: 37168, 11% treeless.
		// So not much obvious gain here.

		if hist.huffTree == nil {
			return errors.New("literal block was treeless, but no history was defined")
		}
		// Ensure we have space to store it.
		if cap(b.literalBuf) < litRegenSize {
			if b.lowMem {
				b.literalBuf = make([]byte, 0, litRegenSize)
			} else {
				b.literalBuf = make([]byte, 0, maxCompressedLiteralSize)
			}
		}
		var err error
		// Use our out buffer.
		huff = hist.huffTree
		huff.Out = b.literalBuf[:0]
		huff.MaxDecodedSize = litRegenSize
		if fourStreams {
			literals, err = huff.Decompress4X(literals, litRegenSize)
		} else {
			literals, err = huff.Decompress1X(literals)
		}
		// Make sure we don't leak our literals buffer
		huff.Out = nil
		if err != nil {
			println("decompressing literals:", err)
			return err
		}
		if len(literals) != litRegenSize {
			return fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals))
		}
	} else {
		// Not treeless: any previously stored tree is superseded; recycle it.
		if hist.huffTree != nil && huff != nil {
			huffDecoderPool.Put(hist.huffTree)
			hist.huffTree = nil
		}
	}
	// Store the (possibly new) tree in history for following treeless blocks.
	if huff != nil {
		huff.Out = nil
		hist.huffTree = huff
	}
	if debug {
		println("Final literals:", len(literals), "and", nSeqs, "sequences.")
	}

	if nSeqs == 0 {
		// Decompressed content is defined entirely as Literals Section content.
		b.dst = append(b.dst, literals...)
		if delayedHistory {
			hist.append(literals)
		}
		return nil
	}

	seqs, err := seqs.mergeHistory(&hist.decoders)
	if err != nil {
		return err
	}
	if debug {
		println("History merged ok")
	}
	br := &bitReader{}
	if err := br.init(in); err != nil {
		return err
	}

	// TODO: Investigate if sending history without decoders are faster.
	// This would allow the sequences to be decoded async and only have to construct stream history.
	// If only recent offsets were not transferred, this would be an obvious win.
	// Also, if first 3 sequences don't reference recent offsets, all sequences can be decoded.

	if err := seqs.initialize(br, hist, literals, b.dst); err != nil {
		println("initializing sequences:", err)
		return err
	}

	err = seqs.decode(nSeqs, br, hist.b)
	if err != nil {
		return err
	}
	if !br.finished() {
		return fmt.Errorf("%d extra bits on block, should be 0", br.remain())
	}

	err = br.close()
	if err != nil {
		printf("Closing sequences: %v, %+v\n", err, *br)
	}
	if len(b.data) > maxCompressedBlockSize {
		return fmt.Errorf("compressed block size too large (%d)", len(b.data))
	}
	// Set output and release references.
	b.dst = seqs.out
	seqs.out, seqs.literals, seqs.hist = nil, nil, nil

	if !delayedHistory {
		// If we don't have delayed history, no need to update.
		hist.recentOffsets = seqs.prevOffset
		return nil
	}
	if b.Last {
		// if last block we don't care about history.
		println("Last block, no history returned")
		hist.b = hist.b[:0]
		return nil
	}
	hist.append(b.dst)
	hist.recentOffsets = seqs.prevOffset
	if debug {
		println("Finished block with literals:", len(literals), "and", nSeqs, "sequences.")
	}

	return nil
}