| // Copyright 2019+ Klaus Post. All rights reserved. |
| // License information can be found in the LICENSE file. |
| // Based on work by Yann Collet, released under BSD License. |
| |
| package zstd |
| |
| import ( |
| "bytes" |
| "encoding/hex" |
| "errors" |
| "hash" |
| "io" |
| "sync" |
| |
| "github.com/klauspost/compress/zstd/internal/xxhash" |
| ) |
| |
// frameDec holds the state for decoding one zstd frame at a time: the parsed
// frame header plus the history and queues shared by the frame's blocks.
type frameDec struct {
	// o holds the decoder options given to newFrameDec.
	o decoderOptions
	// crc accumulates an xxhash of the decoded output; reset creates it
	// lazily the first time a frame declares a checksum.
	crc hash.Hash64
	// frameDone is marked Done when startDecoder returns.
	frameDone sync.WaitGroup
	// offset tracks the output offset; it is only advanced for debug logging.
	offset int64

	// Frame header fields, filled in by reset.
	WindowSize   uint64
	DictionaryID uint32
	// FrameContentSize is 0 when the header has no Frame_Content_Size field.
	FrameContentSize uint64
	HasCheckSum      bool
	SingleSegment    bool

	// maxWindowSize is the maximum window size to support.
	// It should never be bigger than max-int.
	maxWindowSize uint64

	// In order queue of blocks being decoded.
	decoding chan *blockDec

	// Frame history passed between blocks
	history history

	// rawInput is the input the frame header was read from (set by reset);
	// blocks and the trailing checksum are read from it.
	rawInput byteBuffer

	// Byte buffer that can be reused for small input blocks.
	bBuf byteBuf

	// asyncRunning indicates whether the async routine processes input on 'decoding'.
	// It is guarded by asyncRunningMu.
	asyncRunning   bool
	asyncRunningMu sync.Mutex
}
| |
// Window size limits enforced when reading a frame header.
const (
	// MinWindowSize is the smallest Window_Size the format allows (1 KB).
	MinWindowSize = 1024
	// MaxWindowSize is the largest window this decoder accepts by default (1 GB).
	MaxWindowSize = 1024 * 1024 * 1024
)
| |
var (
	// frameMagic is the zstd frame magic number 0xFD2FB528 in stream
	// (little-endian) byte order.
	frameMagic = []byte("\x28\xb5\x2f\xfd")
	// skippableFrameMagic holds the upper three bytes of the skippable frame
	// magic numbers 0x184D2A50-0x184D2A5F in stream order; the first byte
	// of such a magic is checked separately against 0x5?.
	skippableFrameMagic = []byte("\x2a\x4d\x18")
)
| |
| func newFrameDec(o decoderOptions) *frameDec { |
| d := frameDec{ |
| o: o, |
| maxWindowSize: MaxWindowSize, |
| } |
| if d.maxWindowSize > o.maxDecodedSize { |
| d.maxWindowSize = o.maxDecodedSize |
| } |
| return &d |
| } |
| |
// reset reads a frame header from br and prepares d for decoding the
// frame's blocks.
//
// Skippable frames in front of the header are skipped. The header fields
// (SingleSegment, WindowSize, DictionaryID, FrameContentSize, HasCheckSum)
// are parsed into d. The window size is validated against MinWindowSize
// and d.maxWindowSize, the history limits are set from it, and br is
// stored as d.rawInput.
//
// If nothing can be read from the input, io.EOF will be returned.
// Any other error indicates that the stream contained data, but
// there was a problem:
//   - io.ErrUnexpectedEOF for a truncated header,
//   - ErrMagicMismatch for a wrong magic number,
//   - ErrUnknownDictionary for a non-zero dictionary ID,
//   - ErrWindowSizeExceeded or ErrWindowSizeTooSmall for an out-of-range window,
//   - an error returned by br (readByte/skipN), or a reserved-bit error.
// Fields of d may be partially updated when an error is returned.
func (d *frameDec) reset(br byteBuffer) error {
	d.HasCheckSum = false
	d.WindowSize = 0
	var b []byte
	// Skip skippable frames until a non-skippable magic number is found.
	for {
		b = br.readSmall(4)
		if b == nil {
			return io.EOF
		}
		// Skippable magic numbers are 0x184D2A50-0x184D2A5F; in stream order
		// the first byte is 0x5? and the remaining three are fixed.
		if !bytes.Equal(b[1:4], skippableFrameMagic) || b[0]&0xf0 != 0x50 {
			if debug {
				println("Not skippable", hex.EncodeToString(b), hex.EncodeToString(skippableFrameMagic))
			}
			// Break if not skippable frame.
			break
		}
		// Read size to skip (4 bytes, little-endian).
		b = br.readSmall(4)
		if b == nil {
			println("Reading Frame Size EOF")
			return io.ErrUnexpectedEOF
		}
		n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
		println("Skipping frame with", n, "bytes.")
		// NOTE(review): int(n) is negative on 32-bit platforms when n >= 1<<31;
		// confirm skipN rejects negative counts.
		err := br.skipN(int(n))
		if err != nil {
			if debug {
				println("Reading discarded frame", err)
			}
			return err
		}
	}
	if !bytes.Equal(b, frameMagic) {
		println("Got magic numbers: ", b, "want:", frameMagic)
		return ErrMagicMismatch
	}

	// Read Frame_Header_Descriptor
	fhd, err := br.readByte()
	if err != nil {
		println("Reading Frame_Header_Descriptor", err)
		return err
	}
	// Bit 5: Single_Segment_flag.
	d.SingleSegment = fhd&(1<<5) != 0

	// Bit 3 is reserved and must be zero.
	if fhd&(1<<3) != 0 {
		return errors.New("Reserved bit set on frame header")
	}

	// Read Window_Descriptor
	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
	// Single-segment frames have no Window_Descriptor; the window is derived
	// from the content size further down.
	d.WindowSize = 0
	if !d.SingleSegment {
		wd, err := br.readByte()
		if err != nil {
			println("Reading Window_Descriptor", err)
			return err
		}
		printf("raw: %x, mantissa: %d, exponent: %d\n", wd, wd&7, wd>>3)
		// Upper 5 bits: exponent; lower 3 bits: mantissa in eighths of the base.
		windowLog := 10 + (wd >> 3)
		windowBase := uint64(1) << windowLog
		windowAdd := (windowBase / 8) * uint64(wd&0x7)
		d.WindowSize = windowBase + windowAdd
	}

	// Read Dictionary_ID
	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary_id
	// Bits 0-1 (Dictionary_ID_flag) select a field of 0, 1, 2 or 4 bytes.
	d.DictionaryID = 0
	if size := fhd & 3; size != 0 {
		if size == 3 {
			size = 4
		}
		b = br.readSmall(int(size))
		if b == nil {
			if debug {
				println("Reading Dictionary_ID", io.ErrUnexpectedEOF)
			}
			return io.ErrUnexpectedEOF
		}
		switch size {
		case 1:
			d.DictionaryID = uint32(b[0])
		case 2:
			d.DictionaryID = uint32(b[0]) | (uint32(b[1]) << 8)
		case 4:
			d.DictionaryID = uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
		}
		if debug {
			println("Dict size", size, "ID:", d.DictionaryID)
		}
		// Dictionaries are not supported: any non-zero ID is rejected.
		if d.DictionaryID != 0 {
			return ErrUnknownDictionary
		}
	}

	// Read Frame_Content_Size
	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame_content_size
	// Bits 6-7 (Frame_Content_Size_flag): a flag of 0 means a 1-byte field for
	// single-segment frames and no field otherwise; flags 1-3 mean 2, 4 or 8 bytes.
	var fcsSize int
	v := fhd >> 6
	switch v {
	case 0:
		if d.SingleSegment {
			fcsSize = 1
		}
	default:
		fcsSize = 1 << v
	}
	// FrameContentSize stays 0 when the field is absent.
	d.FrameContentSize = 0
	if fcsSize > 0 {
		b := br.readSmall(fcsSize)
		if b == nil {
			println("Reading Frame content", io.ErrUnexpectedEOF)
			return io.ErrUnexpectedEOF
		}
		switch fcsSize {
		case 1:
			d.FrameContentSize = uint64(b[0])
		case 2:
			// When FCS_Field_Size is 2, the offset of 256 is added.
			// ('|' and '+' share precedence, so the 256 is added to the combined value.)
			d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) + 256
		case 4:
			d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24)
		case 8:
			d1 := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
			d2 := uint32(b[4]) | (uint32(b[5]) << 8) | (uint32(b[6]) << 16) | (uint32(b[7]) << 24)
			d.FrameContentSize = uint64(d1) | (uint64(d2) << 32)
		}
		if debug {
			println("field size bits:", v, "fcsSize:", fcsSize, "FrameContentSize:", d.FrameContentSize, hex.EncodeToString(b[:fcsSize]), "singleseg:", d.SingleSegment, "window:", d.WindowSize)
		}
	}
	// Move this to shared.
	// Bit 2: Content_Checksum_flag. The hasher is reused across frames.
	d.HasCheckSum = fhd&(1<<2) != 0
	if d.HasCheckSum {
		if d.crc == nil {
			d.crc = xxhash.New()
		}
		d.crc.Reset()
	}

	// Single-segment frames use the content size as the window,
	// raised to at least MinWindowSize.
	if d.WindowSize == 0 && d.SingleSegment {
		// We may not need window in this case.
		d.WindowSize = d.FrameContentSize
		if d.WindowSize < MinWindowSize {
			d.WindowSize = MinWindowSize
		}
	}

	if d.WindowSize > d.maxWindowSize {
		printf("window size %d > max %d\n", d.WindowSize, d.maxWindowSize)
		return ErrWindowSizeExceeded
	}
	// The minimum Window_Size is 1 KB.
	if d.WindowSize < MinWindowSize {
		println("got window size: ", d.WindowSize)
		return ErrWindowSizeTooSmall
	}
	// History keeps the window plus room for one block.
	d.history.windowSize = int(d.WindowSize)
	d.history.maxSize = d.history.windowSize + maxBlockSize
	// history contains input - maybe we do something
	d.rawInput = br
	return nil
}
| |
| // next will start decoding the next block from stream. |
| func (d *frameDec) next(block *blockDec) error { |
| if debug { |
| printf("decoding new block %p:%p", block, block.data) |
| } |
| err := block.reset(d.rawInput, d.WindowSize) |
| if err != nil { |
| println("block error:", err) |
| // Signal the frame decoder we have a problem. |
| d.sendErr(block, err) |
| return err |
| } |
| block.input <- struct{}{} |
| if debug { |
| println("next block:", block) |
| } |
| d.asyncRunningMu.Lock() |
| defer d.asyncRunningMu.Unlock() |
| if !d.asyncRunning { |
| return nil |
| } |
| if block.Last { |
| // We indicate the frame is done by sending io.EOF |
| d.decoding <- block |
| return io.EOF |
| } |
| d.decoding <- block |
| return nil |
| } |
| |
| // sendEOF will queue an error block on the frame. |
| // This will cause the frame decoder to return when it encounters the block. |
| // Returns true if the decoder was added. |
| func (d *frameDec) sendErr(block *blockDec, err error) bool { |
| d.asyncRunningMu.Lock() |
| defer d.asyncRunningMu.Unlock() |
| if !d.asyncRunning { |
| return false |
| } |
| |
| println("sending error", err.Error()) |
| block.sendErr(err) |
| d.decoding <- block |
| return true |
| } |
| |
| // checkCRC will check the checksum if the frame has one. |
| // Will return ErrCRCMismatch if crc check failed, otherwise nil. |
| func (d *frameDec) checkCRC() error { |
| if !d.HasCheckSum { |
| return nil |
| } |
| var tmp [4]byte |
| got := d.crc.Sum64() |
| // Flip to match file order. |
| tmp[0] = byte(got >> 0) |
| tmp[1] = byte(got >> 8) |
| tmp[2] = byte(got >> 16) |
| tmp[3] = byte(got >> 24) |
| |
| // We can overwrite upper tmp now |
| want := d.rawInput.readSmall(4) |
| if want == nil { |
| println("CRC missing?") |
| return io.ErrUnexpectedEOF |
| } |
| |
| if !bytes.Equal(tmp[:], want) { |
| if debug { |
| println("CRC Check Failed:", tmp[:], "!=", want) |
| } |
| return ErrCRCMismatch |
| } |
| if debug { |
| println("CRC ok", tmp[:]) |
| } |
| return nil |
| } |
| |
| func (d *frameDec) initAsync() { |
| if !d.o.lowMem && !d.SingleSegment { |
| // set max extra size history to 20MB. |
| d.history.maxSize = d.history.windowSize + maxBlockSize*10 |
| } |
| // re-alloc if more than one extra block size. |
| if d.o.lowMem && cap(d.history.b) > d.history.maxSize+maxBlockSize { |
| d.history.b = make([]byte, 0, d.history.maxSize) |
| } |
| if cap(d.history.b) < d.history.maxSize { |
| d.history.b = make([]byte, 0, d.history.maxSize) |
| } |
| if cap(d.decoding) < d.o.concurrent { |
| d.decoding = make(chan *blockDec, d.o.concurrent) |
| } |
| if debug { |
| h := d.history |
| printf("history init. len: %d, cap: %d", len(h.b), cap(h.b)) |
| } |
| d.asyncRunningMu.Lock() |
| d.asyncRunning = true |
| d.asyncRunningMu.Unlock() |
| } |
| |
// startDecoder runs the asynchronous frame decoder. It takes blocks from
// d.decoding in order, hands each one the shared frame history, and sends
// every block result on output.
//
// It stops after the last block (whose result carries the checksum
// verdict from checkCRC). It also stops as soon as a result has an error,
// the checksum write fails, or a single-segment frame produces more than
// FrameContentSize bytes; in those cases the failing result is sent on
// output with err set.
//
// On exit it always marks the async routine as stopped and sets
// d.history.error. It then drains blocks still queued on d.decoding,
// forwarding their results to output, and calls d.frameDone.Done().
func (d *frameDec) startDecoder(output chan decodeOutput) {
	// TODO: Init to dictionary
	d.history.reset()
	// written counts output bytes of this frame, for the FrameContentSize check.
	written := int64(0)

	defer func() {
		d.asyncRunningMu.Lock()
		d.asyncRunning = false
		d.asyncRunningMu.Unlock()

		// Drain the currently decoding.
		// The history is flagged as errored before it is handed to the
		// remaining blocks (presumably so they do not use it; confirm in blockDec).
		d.history.error = true
	flushdone:
		for {
			select {
			case b := <-d.decoding:
				b.history <- &d.history
				output <- <-b.result
			default:
				break flushdone
			}
		}
		println("frame decoder done, signalling done")
		d.frameDone.Done()
	}()
	// Get decoder for first block.
	block := <-d.decoding
	block.history <- &d.history
	for {
		var next *blockDec
		// Get result
		r := <-block.result
		if r.err != nil {
			println("Result contained error", r.err)
			output <- r
			return
		}
		if debug {
			println("got result, from ", d.offset, "to", d.offset+int64(len(r.b)))
			d.offset += int64(len(r.b))
		}
		// NOTE(review): if `next` is taken here and one of the early returns
		// below fires, that block is no longer on d.decoding, so the deferred
		// drain never reads its result. Confirm this cannot stall its decoder.
		if !block.Last {
			// Send history to next block
			select {
			case next = <-d.decoding:
				if debug {
					println("Sending ", len(d.history.b), "bytes as history")
				}
				next.history <- &d.history
			default:
				// Wait until we have sent the block, so
				// other decoders can potentially get the decoder.
				next = nil
			}
		}

		// Add checksum, async to decoding.
		if d.HasCheckSum {
			n, err := d.crc.Write(r.b)
			if err != nil {
				r.err = err
				if n != len(r.b) {
					r.err = io.ErrShortWrite
				}
				output <- r
				return
			}
		}
		written += int64(len(r.b))
		// Single-segment frames always carry Frame_Content_Size, so it bounds the output.
		if d.SingleSegment && uint64(written) > d.FrameContentSize {
			println("runDecoder: single segment and", uint64(written), ">", d.FrameContentSize)
			r.err = ErrFrameSizeExceeded
			output <- r
			return
		}
		if block.Last {
			r.err = d.checkCRC()
			output <- r
			return
		}
		output <- r
		if next == nil {
			// There was no decoder available, we wait for one now that we have sent to the writer.
			if debug {
				println("Sending ", len(d.history.b), " bytes as history")
			}
			next = <-d.decoding
			next.history <- &d.history
		}
		block = next
	}
}
| |
| // runDecoder will create a sync decoder that will decode a block of data. |
| func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { |
| // TODO: Init to dictionary |
| d.history.reset() |
| saved := d.history.b |
| |
| // We use the history for output to avoid copying it. |
| d.history.b = dst |
| // Store input length, so we only check new data. |
| crcStart := len(dst) |
| var err error |
| for { |
| err = dec.reset(d.rawInput, d.WindowSize) |
| if err != nil { |
| break |
| } |
| if debug { |
| println("next block:", dec) |
| } |
| err = dec.decodeBuf(&d.history) |
| if err != nil || dec.Last { |
| break |
| } |
| if uint64(len(d.history.b)) > d.o.maxDecodedSize { |
| err = ErrDecoderSizeExceeded |
| break |
| } |
| if d.SingleSegment && uint64(len(d.history.b)) > d.o.maxDecodedSize { |
| println("runDecoder: single segment and", uint64(len(d.history.b)), ">", d.o.maxDecodedSize) |
| err = ErrFrameSizeExceeded |
| break |
| } |
| } |
| dst = d.history.b |
| if err == nil { |
| if d.HasCheckSum { |
| var n int |
| n, err = d.crc.Write(dst[crcStart:]) |
| if err == nil { |
| if n != len(dst)-crcStart { |
| err = io.ErrShortWrite |
| } else { |
| err = d.checkCRC() |
| } |
| } |
| } |
| } |
| d.history.b = saved |
| return dst, err |
| } |