blob: 40790747a3789688f2e35439111c053437c106a8 [file] [log] [blame]
Scott Bakered4efab2020-01-13 19:12:25 -08001// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8 "bytes"
9 "encoding/hex"
10 "errors"
11 "hash"
12 "io"
13 "sync"
14
15 "github.com/klauspost/compress/zstd/internal/xxhash"
16)
17
// frameDec holds the state for decoding one zstd frame. reset parses the
// frame header into the exported fields; the frame's blocks are then decoded
// either synchronously with runDecoder or through the async path
// (initAsync, next, startDecoder).
type frameDec struct {
	o decoderOptions
	// crc accumulates an xxhash of the decoded output; it is only used
	// when HasCheckSum is set.
	crc hash.Hash64
	// frameDone has Done called when startDecoder returns.
	frameDone sync.WaitGroup
	// offset is a running output offset that is only advanced when debug is set.
	offset int64

	// Frame header fields, populated by reset.
	WindowSize       uint64
	DictionaryID     uint32
	FrameContentSize uint64
	HasCheckSum      bool
	SingleSegment    bool

	// maxWindowSize is the maximum window size to support.
	// should never be bigger than max-int.
	maxWindowSize uint64

	// In order queue of blocks being decoded.
	decoding chan *blockDec

	// Frame history passed between blocks
	history history

	// rawInput is the buffer the frame header was read from (set by reset);
	// block headers and the trailing checksum are read from it as well.
	rawInput byteBuffer

	// Byte buffer that can be reused for small input blocks.
	bBuf byteBuf

	// asyncRunning indicates whether the async routine processes input on 'decoding'.
	asyncRunning bool
	// asyncRunningMu guards asyncRunning.
	asyncRunningMu sync.Mutex
}
49
const (
	// MinWindowSize is the smallest Window_Size a frame may use (1 KiB).
	// reset rejects frames whose window is smaller.
	MinWindowSize = 1024

	// MaxWindowSize is the largest Window_Size a frame decoder accepts
	// (1 GiB). newFrameDec may lower this limit further.
	MaxWindowSize = 1 << 30
)
55
var (
	// frameMagic is the zstd frame magic number 0xFD2FB528 as it appears
	// in the stream (little-endian byte order).
	frameMagic = []byte("\x28\xb5\x2f\xfd")

	// skippableFrameMagic holds bytes 1-3 of a skippable frame's magic
	// number. Byte 0 may be anything in 0x50-0x5F, so skippable frames
	// start with 0x184D2A50 through 0x184D2A5F.
	skippableFrameMagic = []byte("\x2a\x4d\x18")
)
60
61func newFrameDec(o decoderOptions) *frameDec {
62 d := frameDec{
63 o: o,
64 maxWindowSize: MaxWindowSize,
65 }
66 if d.maxWindowSize > o.maxDecodedSize {
67 d.maxWindowSize = o.maxDecodedSize
68 }
69 return &d
70}
71
72// reset will read the frame header and prepare for block decoding.
73// If nothing can be read from the input, io.EOF will be returned.
74// Any other error indicated that the stream contained data, but
75// there was a problem.
76func (d *frameDec) reset(br byteBuffer) error {
77 d.HasCheckSum = false
78 d.WindowSize = 0
79 var b []byte
80 for {
81 b = br.readSmall(4)
82 if b == nil {
83 return io.EOF
84 }
85 if !bytes.Equal(b[1:4], skippableFrameMagic) || b[0]&0xf0 != 0x50 {
86 if debug {
87 println("Not skippable", hex.EncodeToString(b), hex.EncodeToString(skippableFrameMagic))
88 }
89 // Break if not skippable frame.
90 break
91 }
92 // Read size to skip
93 b = br.readSmall(4)
94 if b == nil {
95 println("Reading Frame Size EOF")
96 return io.ErrUnexpectedEOF
97 }
98 n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
99 println("Skipping frame with", n, "bytes.")
100 err := br.skipN(int(n))
101 if err != nil {
102 if debug {
103 println("Reading discarded frame", err)
104 }
105 return err
106 }
107 }
108 if !bytes.Equal(b, frameMagic) {
109 println("Got magic numbers: ", b, "want:", frameMagic)
110 return ErrMagicMismatch
111 }
112
113 // Read Frame_Header_Descriptor
114 fhd, err := br.readByte()
115 if err != nil {
116 println("Reading Frame_Header_Descriptor", err)
117 return err
118 }
119 d.SingleSegment = fhd&(1<<5) != 0
120
121 if fhd&(1<<3) != 0 {
122 return errors.New("Reserved bit set on frame header")
123 }
124
125 // Read Window_Descriptor
126 // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
127 d.WindowSize = 0
128 if !d.SingleSegment {
129 wd, err := br.readByte()
130 if err != nil {
131 println("Reading Window_Descriptor", err)
132 return err
133 }
134 printf("raw: %x, mantissa: %d, exponent: %d\n", wd, wd&7, wd>>3)
135 windowLog := 10 + (wd >> 3)
136 windowBase := uint64(1) << windowLog
137 windowAdd := (windowBase / 8) * uint64(wd&0x7)
138 d.WindowSize = windowBase + windowAdd
139 }
140
141 // Read Dictionary_ID
142 // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary_id
143 d.DictionaryID = 0
144 if size := fhd & 3; size != 0 {
145 if size == 3 {
146 size = 4
147 }
148 b = br.readSmall(int(size))
149 if b == nil {
150 if debug {
151 println("Reading Dictionary_ID", io.ErrUnexpectedEOF)
152 }
153 return io.ErrUnexpectedEOF
154 }
155 switch size {
156 case 1:
157 d.DictionaryID = uint32(b[0])
158 case 2:
159 d.DictionaryID = uint32(b[0]) | (uint32(b[1]) << 8)
160 case 4:
161 d.DictionaryID = uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
162 }
163 if debug {
164 println("Dict size", size, "ID:", d.DictionaryID)
165 }
166 if d.DictionaryID != 0 {
167 return ErrUnknownDictionary
168 }
169 }
170
171 // Read Frame_Content_Size
172 // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame_content_size
173 var fcsSize int
174 v := fhd >> 6
175 switch v {
176 case 0:
177 if d.SingleSegment {
178 fcsSize = 1
179 }
180 default:
181 fcsSize = 1 << v
182 }
183 d.FrameContentSize = 0
184 if fcsSize > 0 {
185 b := br.readSmall(fcsSize)
186 if b == nil {
187 println("Reading Frame content", io.ErrUnexpectedEOF)
188 return io.ErrUnexpectedEOF
189 }
190 switch fcsSize {
191 case 1:
192 d.FrameContentSize = uint64(b[0])
193 case 2:
194 // When FCS_Field_Size is 2, the offset of 256 is added.
195 d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) + 256
196 case 4:
197 d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24)
198 case 8:
199 d1 := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
200 d2 := uint32(b[4]) | (uint32(b[5]) << 8) | (uint32(b[6]) << 16) | (uint32(b[7]) << 24)
201 d.FrameContentSize = uint64(d1) | (uint64(d2) << 32)
202 }
203 if debug {
204 println("field size bits:", v, "fcsSize:", fcsSize, "FrameContentSize:", d.FrameContentSize, hex.EncodeToString(b[:fcsSize]), "singleseg:", d.SingleSegment, "window:", d.WindowSize)
205 }
206 }
207 // Move this to shared.
208 d.HasCheckSum = fhd&(1<<2) != 0
209 if d.HasCheckSum {
210 if d.crc == nil {
211 d.crc = xxhash.New()
212 }
213 d.crc.Reset()
214 }
215
216 if d.WindowSize == 0 && d.SingleSegment {
217 // We may not need window in this case.
218 d.WindowSize = d.FrameContentSize
219 if d.WindowSize < MinWindowSize {
220 d.WindowSize = MinWindowSize
221 }
222 }
223
224 if d.WindowSize > d.maxWindowSize {
225 printf("window size %d > max %d\n", d.WindowSize, d.maxWindowSize)
226 return ErrWindowSizeExceeded
227 }
228 // The minimum Window_Size is 1 KB.
229 if d.WindowSize < MinWindowSize {
230 println("got window size: ", d.WindowSize)
231 return ErrWindowSizeTooSmall
232 }
233 d.history.windowSize = int(d.WindowSize)
234 d.history.maxSize = d.history.windowSize + maxBlockSize
235 // history contains input - maybe we do something
236 d.rawInput = br
237 return nil
238}
239
240// next will start decoding the next block from stream.
241func (d *frameDec) next(block *blockDec) error {
242 if debug {
243 printf("decoding new block %p:%p", block, block.data)
244 }
245 err := block.reset(d.rawInput, d.WindowSize)
246 if err != nil {
247 println("block error:", err)
248 // Signal the frame decoder we have a problem.
249 d.sendErr(block, err)
250 return err
251 }
252 block.input <- struct{}{}
253 if debug {
254 println("next block:", block)
255 }
256 d.asyncRunningMu.Lock()
257 defer d.asyncRunningMu.Unlock()
258 if !d.asyncRunning {
259 return nil
260 }
261 if block.Last {
262 // We indicate the frame is done by sending io.EOF
263 d.decoding <- block
264 return io.EOF
265 }
266 d.decoding <- block
267 return nil
268}
269
270// sendEOF will queue an error block on the frame.
271// This will cause the frame decoder to return when it encounters the block.
272// Returns true if the decoder was added.
273func (d *frameDec) sendErr(block *blockDec, err error) bool {
274 d.asyncRunningMu.Lock()
275 defer d.asyncRunningMu.Unlock()
276 if !d.asyncRunning {
277 return false
278 }
279
280 println("sending error", err.Error())
281 block.sendErr(err)
282 d.decoding <- block
283 return true
284}
285
286// checkCRC will check the checksum if the frame has one.
287// Will return ErrCRCMismatch if crc check failed, otherwise nil.
288func (d *frameDec) checkCRC() error {
289 if !d.HasCheckSum {
290 return nil
291 }
292 var tmp [4]byte
293 got := d.crc.Sum64()
294 // Flip to match file order.
295 tmp[0] = byte(got >> 0)
296 tmp[1] = byte(got >> 8)
297 tmp[2] = byte(got >> 16)
298 tmp[3] = byte(got >> 24)
299
300 // We can overwrite upper tmp now
301 want := d.rawInput.readSmall(4)
302 if want == nil {
303 println("CRC missing?")
304 return io.ErrUnexpectedEOF
305 }
306
307 if !bytes.Equal(tmp[:], want) {
308 if debug {
309 println("CRC Check Failed:", tmp[:], "!=", want)
310 }
311 return ErrCRCMismatch
312 }
313 if debug {
314 println("CRC ok", tmp[:])
315 }
316 return nil
317}
318
319func (d *frameDec) initAsync() {
320 if !d.o.lowMem && !d.SingleSegment {
321 // set max extra size history to 20MB.
322 d.history.maxSize = d.history.windowSize + maxBlockSize*10
323 }
324 // re-alloc if more than one extra block size.
325 if d.o.lowMem && cap(d.history.b) > d.history.maxSize+maxBlockSize {
326 d.history.b = make([]byte, 0, d.history.maxSize)
327 }
328 if cap(d.history.b) < d.history.maxSize {
329 d.history.b = make([]byte, 0, d.history.maxSize)
330 }
331 if cap(d.decoding) < d.o.concurrent {
332 d.decoding = make(chan *blockDec, d.o.concurrent)
333 }
334 if debug {
335 h := d.history
336 printf("history init. len: %d, cap: %d", len(h.b), cap(h.b))
337 }
338 d.asyncRunningMu.Lock()
339 d.asyncRunning = true
340 d.asyncRunningMu.Unlock()
341}
342
343// startDecoder will start decoding blocks and write them to the writer.
344// The decoder will stop as soon as an error occurs or at end of frame.
345// When the frame has finished decoding the *bufio.Reader
346// containing the remaining input will be sent on frameDec.frameDone.
347func (d *frameDec) startDecoder(output chan decodeOutput) {
348 // TODO: Init to dictionary
349 d.history.reset()
350 written := int64(0)
351
352 defer func() {
353 d.asyncRunningMu.Lock()
354 d.asyncRunning = false
355 d.asyncRunningMu.Unlock()
356
357 // Drain the currently decoding.
358 d.history.error = true
359 flushdone:
360 for {
361 select {
362 case b := <-d.decoding:
363 b.history <- &d.history
364 output <- <-b.result
365 default:
366 break flushdone
367 }
368 }
369 println("frame decoder done, signalling done")
370 d.frameDone.Done()
371 }()
372 // Get decoder for first block.
373 block := <-d.decoding
374 block.history <- &d.history
375 for {
376 var next *blockDec
377 // Get result
378 r := <-block.result
379 if r.err != nil {
380 println("Result contained error", r.err)
381 output <- r
382 return
383 }
384 if debug {
385 println("got result, from ", d.offset, "to", d.offset+int64(len(r.b)))
386 d.offset += int64(len(r.b))
387 }
388 if !block.Last {
389 // Send history to next block
390 select {
391 case next = <-d.decoding:
392 if debug {
393 println("Sending ", len(d.history.b), "bytes as history")
394 }
395 next.history <- &d.history
396 default:
397 // Wait until we have sent the block, so
398 // other decoders can potentially get the decoder.
399 next = nil
400 }
401 }
402
403 // Add checksum, async to decoding.
404 if d.HasCheckSum {
405 n, err := d.crc.Write(r.b)
406 if err != nil {
407 r.err = err
408 if n != len(r.b) {
409 r.err = io.ErrShortWrite
410 }
411 output <- r
412 return
413 }
414 }
415 written += int64(len(r.b))
416 if d.SingleSegment && uint64(written) > d.FrameContentSize {
417 println("runDecoder: single segment and", uint64(written), ">", d.FrameContentSize)
418 r.err = ErrFrameSizeExceeded
419 output <- r
420 return
421 }
422 if block.Last {
423 r.err = d.checkCRC()
424 output <- r
425 return
426 }
427 output <- r
428 if next == nil {
429 // There was no decoder available, we wait for one now that we have sent to the writer.
430 if debug {
431 println("Sending ", len(d.history.b), " bytes as history")
432 }
433 next = <-d.decoding
434 next.history <- &d.history
435 }
436 block = next
437 }
438}
439
440// runDecoder will create a sync decoder that will decode a block of data.
441func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
442 // TODO: Init to dictionary
443 d.history.reset()
444 saved := d.history.b
445
446 // We use the history for output to avoid copying it.
447 d.history.b = dst
448 // Store input length, so we only check new data.
449 crcStart := len(dst)
450 var err error
451 for {
452 err = dec.reset(d.rawInput, d.WindowSize)
453 if err != nil {
454 break
455 }
456 if debug {
457 println("next block:", dec)
458 }
459 err = dec.decodeBuf(&d.history)
460 if err != nil || dec.Last {
461 break
462 }
463 if uint64(len(d.history.b)) > d.o.maxDecodedSize {
464 err = ErrDecoderSizeExceeded
465 break
466 }
467 if d.SingleSegment && uint64(len(d.history.b)) > d.o.maxDecodedSize {
468 println("runDecoder: single segment and", uint64(len(d.history.b)), ">", d.o.maxDecodedSize)
469 err = ErrFrameSizeExceeded
470 break
471 }
472 }
473 dst = d.history.b
474 if err == nil {
475 if d.HasCheckSum {
476 var n int
477 n, err = d.crc.Write(dst[crcStart:])
478 if err == nil {
479 if n != len(dst)-crcStart {
480 err = io.ErrShortWrite
481 } else {
482 err = d.checkCRC()
483 }
484 }
485 }
486 }
487 d.history.b = saved
488 return dst, err
489}