blob: 693c5f05d2d403d480ea8690a0d998c94c23f49a [file] [log] [blame]
khenaidoo106c61a2021-08-11 18:05:46 -04001// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8 "bytes"
9 "encoding/hex"
10 "errors"
11 "hash"
12 "io"
13 "sync"
14
15 "github.com/klauspost/compress/zstd/internal/xxhash"
16)
17
// frameDec holds the state for decoding zstd frames, one frame at a time:
// reset parses a frame header, after which the frame's blocks are decoded
// either synchronously (runDecoder) or through the async pipeline
// (initAsync, next, startDecoder).
type frameDec struct {
	// o holds the decoder options; lowMem, concurrent and maxDecodedSize
	// are the fields consulted in this file.
	o decoderOptions

	// crc accumulates the checksum of the decoded output for frames that
	// carry one. reset allocates it on first use and resets it for every
	// frame that has a checksum.
	crc hash.Hash64

	// offset tracks the decoded output position; it is only advanced when
	// debug is enabled and is only used for logging.
	offset int64

	// WindowSize is the frame's window size as determined by reset.
	WindowSize uint64

	// maxWindowSize is the maximum window size to support.
	// It should never be bigger than max-int.
	maxWindowSize uint64

	// In order queue of blocks being decoded.
	decoding chan *blockDec

	// Frame history passed between blocks
	history history

	// rawInput is the input the frame header was read from; block data and
	// the trailing checksum are read from it as well.
	rawInput byteBuffer

	// Byte buffer that can be reused for small input blocks.
	bBuf byteBuf

	// FrameContentSize is the decoded size declared in the frame header,
	// or 0 when the header does not carry one.
	FrameContentSize uint64

	// frameDone is marked Done by startDecoder when it returns.
	frameDone sync.WaitGroup

	// DictionaryID is the dictionary ID from the frame header, or nil when
	// the header has none or it is 0.
	DictionaryID *uint32

	// HasCheckSum reports whether the frame ends with a 4-byte checksum.
	HasCheckSum bool

	// SingleSegment is the Single_Segment_flag of the frame header.
	SingleSegment bool

	// asyncRunning indicates whether the async routine processes input on 'decoding'.
	// asyncRunningMu guards asyncRunning.
	asyncRunningMu sync.Mutex
	asyncRunning   bool
}
51
const (
	// MinWindowSize is the smallest Window_Size a frame may use: 1 KiB.
	// Frames whose window works out smaller are rejected.
	MinWindowSize = 1024

	// MaxWindowSize is the largest window this decoder accepts: 512 MiB.
	// newFrameDec caps a decoder's window limit at this value.
	MaxWindowSize = 512 << 20
)
57
var (
	// frameMagic is the magic number that starts every zstd frame, in
	// stream byte order (0xFD2FB528 read as a little-endian uint32).
	frameMagic = []byte{0x28, 0xb5, 0x2f, 0xfd}

	// skippableFrameMagic holds bytes 1-3 of a skippable frame's magic
	// number. Byte 0 may be any value 0x50-0x5F, so reset checks its high
	// nibble separately.
	skippableFrameMagic = []byte{0x2a, 0x4d, 0x18}
)
62
63func newFrameDec(o decoderOptions) *frameDec {
64 d := frameDec{
65 o: o,
66 maxWindowSize: MaxWindowSize,
67 }
68 if d.maxWindowSize > o.maxDecodedSize {
69 d.maxWindowSize = o.maxDecodedSize
70 }
71 return &d
72}
73
74// reset will read the frame header and prepare for block decoding.
75// If nothing can be read from the input, io.EOF will be returned.
76// Any other error indicated that the stream contained data, but
77// there was a problem.
78func (d *frameDec) reset(br byteBuffer) error {
79 d.HasCheckSum = false
80 d.WindowSize = 0
81 var b []byte
82 for {
83 b = br.readSmall(4)
84 if b == nil {
85 return io.EOF
86 }
87 if !bytes.Equal(b[1:4], skippableFrameMagic) || b[0]&0xf0 != 0x50 {
88 if debug {
89 println("Not skippable", hex.EncodeToString(b), hex.EncodeToString(skippableFrameMagic))
90 }
91 // Break if not skippable frame.
92 break
93 }
94 // Read size to skip
95 b = br.readSmall(4)
96 if b == nil {
97 println("Reading Frame Size EOF")
98 return io.ErrUnexpectedEOF
99 }
100 n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
101 println("Skipping frame with", n, "bytes.")
102 err := br.skipN(int(n))
103 if err != nil {
104 if debug {
105 println("Reading discarded frame", err)
106 }
107 return err
108 }
109 }
110 if !bytes.Equal(b, frameMagic) {
111 println("Got magic numbers: ", b, "want:", frameMagic)
112 return ErrMagicMismatch
113 }
114
115 // Read Frame_Header_Descriptor
116 fhd, err := br.readByte()
117 if err != nil {
118 println("Reading Frame_Header_Descriptor", err)
119 return err
120 }
121 d.SingleSegment = fhd&(1<<5) != 0
122
123 if fhd&(1<<3) != 0 {
124 return errors.New("reserved bit set on frame header")
125 }
126
127 // Read Window_Descriptor
128 // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
129 d.WindowSize = 0
130 if !d.SingleSegment {
131 wd, err := br.readByte()
132 if err != nil {
133 println("Reading Window_Descriptor", err)
134 return err
135 }
136 printf("raw: %x, mantissa: %d, exponent: %d\n", wd, wd&7, wd>>3)
137 windowLog := 10 + (wd >> 3)
138 windowBase := uint64(1) << windowLog
139 windowAdd := (windowBase / 8) * uint64(wd&0x7)
140 d.WindowSize = windowBase + windowAdd
141 }
142
143 // Read Dictionary_ID
144 // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary_id
145 d.DictionaryID = nil
146 if size := fhd & 3; size != 0 {
147 if size == 3 {
148 size = 4
149 }
150 b = br.readSmall(int(size))
151 if b == nil {
152 if debug {
153 println("Reading Dictionary_ID", io.ErrUnexpectedEOF)
154 }
155 return io.ErrUnexpectedEOF
156 }
157 var id uint32
158 switch size {
159 case 1:
160 id = uint32(b[0])
161 case 2:
162 id = uint32(b[0]) | (uint32(b[1]) << 8)
163 case 4:
164 id = uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
165 }
166 if debug {
167 println("Dict size", size, "ID:", id)
168 }
169 if id > 0 {
170 // ID 0 means "sorry, no dictionary anyway".
171 // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary-format
172 d.DictionaryID = &id
173 }
174 }
175
176 // Read Frame_Content_Size
177 // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame_content_size
178 var fcsSize int
179 v := fhd >> 6
180 switch v {
181 case 0:
182 if d.SingleSegment {
183 fcsSize = 1
184 }
185 default:
186 fcsSize = 1 << v
187 }
188 d.FrameContentSize = 0
189 if fcsSize > 0 {
190 b := br.readSmall(fcsSize)
191 if b == nil {
192 println("Reading Frame content", io.ErrUnexpectedEOF)
193 return io.ErrUnexpectedEOF
194 }
195 switch fcsSize {
196 case 1:
197 d.FrameContentSize = uint64(b[0])
198 case 2:
199 // When FCS_Field_Size is 2, the offset of 256 is added.
200 d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) + 256
201 case 4:
202 d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24)
203 case 8:
204 d1 := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
205 d2 := uint32(b[4]) | (uint32(b[5]) << 8) | (uint32(b[6]) << 16) | (uint32(b[7]) << 24)
206 d.FrameContentSize = uint64(d1) | (uint64(d2) << 32)
207 }
208 if debug {
209 println("field size bits:", v, "fcsSize:", fcsSize, "FrameContentSize:", d.FrameContentSize, hex.EncodeToString(b[:fcsSize]), "singleseg:", d.SingleSegment, "window:", d.WindowSize)
210 }
211 }
212 // Move this to shared.
213 d.HasCheckSum = fhd&(1<<2) != 0
214 if d.HasCheckSum {
215 if d.crc == nil {
216 d.crc = xxhash.New()
217 }
218 d.crc.Reset()
219 }
220
221 if d.WindowSize == 0 && d.SingleSegment {
222 // We may not need window in this case.
223 d.WindowSize = d.FrameContentSize
224 if d.WindowSize < MinWindowSize {
225 d.WindowSize = MinWindowSize
226 }
227 }
228
229 if d.WindowSize > d.maxWindowSize {
230 printf("window size %d > max %d\n", d.WindowSize, d.maxWindowSize)
231 return ErrWindowSizeExceeded
232 }
233 // The minimum Window_Size is 1 KB.
234 if d.WindowSize < MinWindowSize {
235 println("got window size: ", d.WindowSize)
236 return ErrWindowSizeTooSmall
237 }
238 d.history.windowSize = int(d.WindowSize)
239 if d.o.lowMem && d.history.windowSize < maxBlockSize {
240 d.history.maxSize = d.history.windowSize * 2
241 } else {
242 d.history.maxSize = d.history.windowSize + maxBlockSize
243 }
244 // history contains input - maybe we do something
245 d.rawInput = br
246 return nil
247}
248
249// next will start decoding the next block from stream.
250func (d *frameDec) next(block *blockDec) error {
251 if debug {
252 printf("decoding new block %p:%p", block, block.data)
253 }
254 err := block.reset(d.rawInput, d.WindowSize)
255 if err != nil {
256 println("block error:", err)
257 // Signal the frame decoder we have a problem.
258 d.sendErr(block, err)
259 return err
260 }
261 block.input <- struct{}{}
262 if debug {
263 println("next block:", block)
264 }
265 d.asyncRunningMu.Lock()
266 defer d.asyncRunningMu.Unlock()
267 if !d.asyncRunning {
268 return nil
269 }
270 if block.Last {
271 // We indicate the frame is done by sending io.EOF
272 d.decoding <- block
273 return io.EOF
274 }
275 d.decoding <- block
276 return nil
277}
278
279// sendEOF will queue an error block on the frame.
280// This will cause the frame decoder to return when it encounters the block.
281// Returns true if the decoder was added.
282func (d *frameDec) sendErr(block *blockDec, err error) bool {
283 d.asyncRunningMu.Lock()
284 defer d.asyncRunningMu.Unlock()
285 if !d.asyncRunning {
286 return false
287 }
288
289 println("sending error", err.Error())
290 block.sendErr(err)
291 d.decoding <- block
292 return true
293}
294
295// checkCRC will check the checksum if the frame has one.
296// Will return ErrCRCMismatch if crc check failed, otherwise nil.
297func (d *frameDec) checkCRC() error {
298 if !d.HasCheckSum {
299 return nil
300 }
301 var tmp [4]byte
302 got := d.crc.Sum64()
303 // Flip to match file order.
304 tmp[0] = byte(got >> 0)
305 tmp[1] = byte(got >> 8)
306 tmp[2] = byte(got >> 16)
307 tmp[3] = byte(got >> 24)
308
309 // We can overwrite upper tmp now
310 want := d.rawInput.readSmall(4)
311 if want == nil {
312 println("CRC missing?")
313 return io.ErrUnexpectedEOF
314 }
315
316 if !bytes.Equal(tmp[:], want) {
317 if debug {
318 println("CRC Check Failed:", tmp[:], "!=", want)
319 }
320 return ErrCRCMismatch
321 }
322 if debug {
323 println("CRC ok", tmp[:])
324 }
325 return nil
326}
327
328func (d *frameDec) initAsync() {
329 if !d.o.lowMem && !d.SingleSegment {
330 // set max extra size history to 10MB.
331 d.history.maxSize = d.history.windowSize + maxBlockSize*5
332 }
333 // re-alloc if more than one extra block size.
334 if d.o.lowMem && cap(d.history.b) > d.history.maxSize+maxBlockSize {
335 d.history.b = make([]byte, 0, d.history.maxSize)
336 }
337 if cap(d.history.b) < d.history.maxSize {
338 d.history.b = make([]byte, 0, d.history.maxSize)
339 }
340 if cap(d.decoding) < d.o.concurrent {
341 d.decoding = make(chan *blockDec, d.o.concurrent)
342 }
343 if debug {
344 h := d.history
345 printf("history init. len: %d, cap: %d", len(h.b), cap(h.b))
346 }
347 d.asyncRunningMu.Lock()
348 d.asyncRunning = true
349 d.asyncRunningMu.Unlock()
350}
351
// startDecoder takes the frame's blocks in order from d.decoding, hands each
// one the frame history, and sends each block's result on output.
// When the frame has a checksum, it is updated with every result and
// verified after the last block.
// The decoder will stop as soon as an error occurs or at end of frame.
// On return it marks the async decoder as stopped, drains any blocks still
// queued on d.decoding (forwarding their results to output), and calls
// d.frameDone.Done().
func (d *frameDec) startDecoder(output chan decodeOutput) {
	// written counts decoded bytes, for the single-segment size check.
	written := int64(0)

	defer func() {
		// Stop next/sendErr from queueing further blocks.
		d.asyncRunningMu.Lock()
		d.asyncRunning = false
		d.asyncRunningMu.Unlock()

		// Drain the currently decoding.
		// NOTE(review): history.error is presumably checked by the block
		// decoders so that drained blocks bail out; confirm in blockDec.
		d.history.error = true
	flushdone:
		for {
			select {
			case b := <-d.decoding:
				b.history <- &d.history
				output <- <-b.result
			default:
				break flushdone
			}
		}
		println("frame decoder done, signalling done")
		d.frameDone.Done()
	}()
	// Get decoder for first block.
	block := <-d.decoding
	block.history <- &d.history
	for {
		var next *blockDec
		// Get result
		r := <-block.result
		if r.err != nil {
			println("Result contained error", r.err)
			output <- r
			return
		}
		if debug {
			println("got result, from ", d.offset, "to", d.offset+int64(len(r.b)))
			d.offset += int64(len(r.b))
		}
		if !block.Last {
			// Send history to next block, if one is already queued, so it can
			// start decoding while this result is processed.
			select {
			case next = <-d.decoding:
				if debug {
					println("Sending ", len(d.history.b), "bytes as history")
				}
				next.history <- &d.history
			default:
				// Wait until we have sent the block, so
				// other decoders can potentially get the decoder.
				next = nil
			}
		}

		// Add checksum, async to decoding.
		// NOTE(review): the early returns below can happen after `next` was
		// taken off d.decoding and given the history; the deferred drain
		// does not see it, so its result is never read. Verify this cannot
		// leave that block's decoder blocked.
		if d.HasCheckSum {
			n, err := d.crc.Write(r.b)
			if err != nil {
				r.err = err
				if n != len(r.b) {
					r.err = io.ErrShortWrite
				}
				output <- r
				return
			}
		}
		written += int64(len(r.b))
		// A single-segment frame must not produce more than its declared size.
		if d.SingleSegment && uint64(written) > d.FrameContentSize {
			println("runDecoder: single segment and", uint64(written), ">", d.FrameContentSize)
			r.err = ErrFrameSizeExceeded
			output <- r
			return
		}
		if block.Last {
			// The checksum result travels with the final block's output.
			r.err = d.checkCRC()
			output <- r
			return
		}
		output <- r
		if next == nil {
			// There was no decoder available, we wait for one now that we have sent to the writer.
			if debug {
				println("Sending ", len(d.history.b), " bytes as history")
			}
			next = <-d.decoding
			next.history <- &d.history
		}
		block = next
	}
}
446
447// runDecoder will create a sync decoder that will decode a block of data.
448func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
449 saved := d.history.b
450
451 // We use the history for output to avoid copying it.
452 d.history.b = dst
453 // Store input length, so we only check new data.
454 crcStart := len(dst)
455 var err error
456 for {
457 err = dec.reset(d.rawInput, d.WindowSize)
458 if err != nil {
459 break
460 }
461 if debug {
462 println("next block:", dec)
463 }
464 err = dec.decodeBuf(&d.history)
465 if err != nil || dec.Last {
466 break
467 }
468 if uint64(len(d.history.b)) > d.o.maxDecodedSize {
469 err = ErrDecoderSizeExceeded
470 break
471 }
472 if d.SingleSegment && uint64(len(d.history.b)) > d.o.maxDecodedSize {
473 println("runDecoder: single segment and", uint64(len(d.history.b)), ">", d.o.maxDecodedSize)
474 err = ErrFrameSizeExceeded
475 break
476 }
477 }
478 dst = d.history.b
479 if err == nil {
480 if d.HasCheckSum {
481 var n int
482 n, err = d.crc.Write(dst[crcStart:])
483 if err == nil {
484 if n != len(dst)-crcStart {
485 err = io.ErrShortWrite
486 } else {
487 err = d.checkCRC()
488 }
489 }
490 }
491 }
492 d.history.b = saved
493 return dst, err
494}