// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8 "bytes"
9 "encoding/hex"
10 "errors"
11 "hash"
12 "io"
13 "sync"
14
15 "github.com/klauspost/compress/zstd/internal/xxhash"
16)
17
// frameDec holds the state needed to decode a single zstd frame:
// the values parsed from the frame header, the shared history window,
// and the plumbing used to hand blocks to the asynchronous decoder.
type frameDec struct {
	// o contains the decoder options this frame decoder was created with.
	o decoderOptions
	// crc accumulates the hash of the decoded output.
	// It is created lazily by reset when a frame has a checksum.
	crc hash.Hash64
	// frameDone has Done called on it when startDecoder exits.
	frameDone sync.WaitGroup
	// offset tracks the number of decoded bytes; only updated for debug output.
	offset int64

	// Values parsed from the frame header by reset.
	WindowSize       uint64
	DictionaryID     uint32
	FrameContentSize uint64
	HasCheckSum      bool
	SingleSegment    bool

	// maxWindowSize is the maximum window size to support.
	// should never be bigger than max-int.
	maxWindowSize uint64

	// In order queue of blocks being decoded.
	decoding chan *blockDec

	// Frame history passed between blocks
	history history

	// rawInput is the input the frame is read from. It is set by reset,
	// and block headers and the trailing checksum are read from it.
	rawInput byteBuffer

	// Byte buffer that can be reused for small input blocks.
	bBuf byteBuf

	// asyncRunning indicates whether the async routine processes input on 'decoding'.
	// It is guarded by asyncRunningMu.
	asyncRunning   bool
	asyncRunningMu sync.Mutex
}
49
// minWindowSize is the smallest Window_Size a frame may declare (1 KB).
const minWindowSize = 1 << 10
54
// frameMagic is the 4 byte sequence that starts every regular frame.
var frameMagic = []byte{0x28, 0xb5, 0x2f, 0xfd}

// skippableFrameMagic holds the last 3 bytes of a skippable frame magic.
// The first byte of a skippable frame is matched separately (0x50-0x5f).
var skippableFrameMagic = []byte{0x2a, 0x4d, 0x18}
59
60func newFrameDec(o decoderOptions) *frameDec {
61 d := frameDec{
62 o: o,
63 maxWindowSize: 1 << 30,
64 }
65 if d.maxWindowSize > o.maxDecodedSize {
66 d.maxWindowSize = o.maxDecodedSize
67 }
68 return &d
69}
70
71// reset will read the frame header and prepare for block decoding.
72// If nothing can be read from the input, io.EOF will be returned.
73// Any other error indicated that the stream contained data, but
74// there was a problem.
75func (d *frameDec) reset(br byteBuffer) error {
76 d.HasCheckSum = false
77 d.WindowSize = 0
78 var b []byte
79 for {
80 b = br.readSmall(4)
81 if b == nil {
82 return io.EOF
83 }
84 if !bytes.Equal(b[1:4], skippableFrameMagic) || b[0]&0xf0 != 0x50 {
85 if debug {
86 println("Not skippable", hex.EncodeToString(b), hex.EncodeToString(skippableFrameMagic))
87 }
88 // Break if not skippable frame.
89 break
90 }
91 // Read size to skip
92 b = br.readSmall(4)
93 if b == nil {
94 println("Reading Frame Size EOF")
95 return io.ErrUnexpectedEOF
96 }
97 n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
98 println("Skipping frame with", n, "bytes.")
99 err := br.skipN(int(n))
100 if err != nil {
101 if debug {
102 println("Reading discarded frame", err)
103 }
104 return err
105 }
106 }
107 if !bytes.Equal(b, frameMagic) {
108 println("Got magic numbers: ", b, "want:", frameMagic)
109 return ErrMagicMismatch
110 }
111
112 // Read Frame_Header_Descriptor
113 fhd, err := br.readByte()
114 if err != nil {
115 println("Reading Frame_Header_Descriptor", err)
116 return err
117 }
118 d.SingleSegment = fhd&(1<<5) != 0
119
120 if fhd&(1<<3) != 0 {
121 return errors.New("Reserved bit set on frame header")
122 }
123
124 // Read Window_Descriptor
125 // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
126 d.WindowSize = 0
127 if !d.SingleSegment {
128 wd, err := br.readByte()
129 if err != nil {
130 println("Reading Window_Descriptor", err)
131 return err
132 }
133 printf("raw: %x, mantissa: %d, exponent: %d\n", wd, wd&7, wd>>3)
134 windowLog := 10 + (wd >> 3)
135 windowBase := uint64(1) << windowLog
136 windowAdd := (windowBase / 8) * uint64(wd&0x7)
137 d.WindowSize = windowBase + windowAdd
138 }
139
140 // Read Dictionary_ID
141 // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary_id
142 d.DictionaryID = 0
143 if size := fhd & 3; size != 0 {
144 if size == 3 {
145 size = 4
146 }
147 b = br.readSmall(int(size))
148 if b == nil {
149 if debug {
150 println("Reading Dictionary_ID", io.ErrUnexpectedEOF)
151 }
152 return io.ErrUnexpectedEOF
153 }
154 switch size {
155 case 1:
156 d.DictionaryID = uint32(b[0])
157 case 2:
158 d.DictionaryID = uint32(b[0]) | (uint32(b[1]) << 8)
159 case 4:
160 d.DictionaryID = uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
161 }
162 if debug {
163 println("Dict size", size, "ID:", d.DictionaryID)
164 }
165 if d.DictionaryID != 0 {
166 return ErrUnknownDictionary
167 }
168 }
169
170 // Read Frame_Content_Size
171 // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame_content_size
172 var fcsSize int
173 v := fhd >> 6
174 switch v {
175 case 0:
176 if d.SingleSegment {
177 fcsSize = 1
178 }
179 default:
180 fcsSize = 1 << v
181 }
182 d.FrameContentSize = 0
183 if fcsSize > 0 {
184 b := br.readSmall(fcsSize)
185 if b == nil {
186 println("Reading Frame content", io.ErrUnexpectedEOF)
187 return io.ErrUnexpectedEOF
188 }
189 switch fcsSize {
190 case 1:
191 d.FrameContentSize = uint64(b[0])
192 case 2:
193 // When FCS_Field_Size is 2, the offset of 256 is added.
194 d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) + 256
195 case 4:
196 d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3] << 24))
197 case 8:
198 d1 := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
199 d2 := uint32(b[4]) | (uint32(b[5]) << 8) | (uint32(b[6]) << 16) | (uint32(b[7]) << 24)
200 d.FrameContentSize = uint64(d1) | (uint64(d2) << 32)
201 }
202 if debug {
203 println("field size bits:", v, "fcsSize:", fcsSize, "FrameContentSize:", d.FrameContentSize, hex.EncodeToString(b[:fcsSize]))
204 }
205 }
206 // Move this to shared.
207 d.HasCheckSum = fhd&(1<<2) != 0
208 if d.HasCheckSum {
209 if d.crc == nil {
210 d.crc = xxhash.New()
211 }
212 d.crc.Reset()
213 }
214
215 if d.WindowSize == 0 && d.SingleSegment {
216 // We may not need window in this case.
217 d.WindowSize = d.FrameContentSize
218 if d.WindowSize < minWindowSize {
219 d.WindowSize = minWindowSize
220 }
221 }
222
223 if d.WindowSize > d.maxWindowSize {
224 printf("window size %d > max %d\n", d.WindowSize, d.maxWindowSize)
225 return ErrWindowSizeExceeded
226 }
227 // The minimum Window_Size is 1 KB.
228 if d.WindowSize < minWindowSize {
229 println("got window size: ", d.WindowSize)
230 return ErrWindowSizeTooSmall
231 }
232 d.history.windowSize = int(d.WindowSize)
233 d.history.maxSize = d.history.windowSize + maxBlockSize
234 // history contains input - maybe we do something
235 d.rawInput = br
236 return nil
237}
238
239// next will start decoding the next block from stream.
240func (d *frameDec) next(block *blockDec) error {
241 if debug {
242 printf("decoding new block %p:%p", block, block.data)
243 }
244 err := block.reset(d.rawInput, d.WindowSize)
245 if err != nil {
246 println("block error:", err)
247 // Signal the frame decoder we have a problem.
248 d.sendErr(block, err)
249 return err
250 }
251 block.input <- struct{}{}
252 if debug {
253 println("next block:", block)
254 }
255 d.asyncRunningMu.Lock()
256 defer d.asyncRunningMu.Unlock()
257 if !d.asyncRunning {
258 return nil
259 }
260 if block.Last {
261 // We indicate the frame is done by sending io.EOF
262 d.decoding <- block
263 return io.EOF
264 }
265 d.decoding <- block
266 return nil
267}
268
269// sendEOF will queue an error block on the frame.
270// This will cause the frame decoder to return when it encounters the block.
271// Returns true if the decoder was added.
272func (d *frameDec) sendErr(block *blockDec, err error) bool {
273 d.asyncRunningMu.Lock()
274 defer d.asyncRunningMu.Unlock()
275 if !d.asyncRunning {
276 return false
277 }
278
279 println("sending error", err.Error())
280 block.sendErr(err)
281 d.decoding <- block
282 return true
283}
284
285// checkCRC will check the checksum if the frame has one.
286// Will return ErrCRCMismatch if crc check failed, otherwise nil.
287func (d *frameDec) checkCRC() error {
288 if !d.HasCheckSum {
289 return nil
290 }
291 var tmp [4]byte
292 got := d.crc.Sum64()
293 // Flip to match file order.
294 tmp[0] = byte(got >> 0)
295 tmp[1] = byte(got >> 8)
296 tmp[2] = byte(got >> 16)
297 tmp[3] = byte(got >> 24)
298
299 // We can overwrite upper tmp now
300 want := d.rawInput.readSmall(4)
301 if want == nil {
302 println("CRC missing?")
303 return io.ErrUnexpectedEOF
304 }
305
306 if !bytes.Equal(tmp[:], want) {
307 if debug {
308 println("CRC Check Failed:", tmp[:], "!=", want)
309 }
310 return ErrCRCMismatch
311 }
312 println("CRC ok")
313 return nil
314}
315
316func (d *frameDec) initAsync() {
317 if !d.o.lowMem && !d.SingleSegment {
318 // set max extra size history to 20MB.
319 d.history.maxSize = d.history.windowSize + maxBlockSize*10
320 }
321 // re-alloc if more than one extra block size.
322 if d.o.lowMem && cap(d.history.b) > d.history.maxSize+maxBlockSize {
323 d.history.b = make([]byte, 0, d.history.maxSize)
324 }
325 if cap(d.history.b) < d.history.maxSize {
326 d.history.b = make([]byte, 0, d.history.maxSize)
327 }
328 if cap(d.decoding) < d.o.concurrent {
329 d.decoding = make(chan *blockDec, d.o.concurrent)
330 }
331 if debug {
332 h := d.history
333 printf("history init. len: %d, cap: %d", len(h.b), cap(h.b))
334 }
335 d.asyncRunningMu.Lock()
336 d.asyncRunning = true
337 d.asyncRunningMu.Unlock()
338}
339
340// startDecoder will start decoding blocks and write them to the writer.
341// The decoder will stop as soon as an error occurs or at end of frame.
342// When the frame has finished decoding the *bufio.Reader
343// containing the remaining input will be sent on frameDec.frameDone.
344func (d *frameDec) startDecoder(output chan decodeOutput) {
345 // TODO: Init to dictionary
346 d.history.reset()
347 written := int64(0)
348
349 defer func() {
350 d.asyncRunningMu.Lock()
351 d.asyncRunning = false
352 d.asyncRunningMu.Unlock()
353
354 // Drain the currently decoding.
355 d.history.error = true
356 flushdone:
357 for {
358 select {
359 case b := <-d.decoding:
360 b.history <- &d.history
361 output <- <-b.result
362 default:
363 break flushdone
364 }
365 }
366 println("frame decoder done, signalling done")
367 d.frameDone.Done()
368 }()
369 // Get decoder for first block.
370 block := <-d.decoding
371 block.history <- &d.history
372 for {
373 var next *blockDec
374 // Get result
375 r := <-block.result
376 if r.err != nil {
377 println("Result contained error", r.err)
378 output <- r
379 return
380 }
381 if debug {
382 println("got result, from ", d.offset, "to", d.offset+int64(len(r.b)))
383 d.offset += int64(len(r.b))
384 }
385 if !block.Last {
386 // Send history to next block
387 select {
388 case next = <-d.decoding:
389 if debug {
390 println("Sending ", len(d.history.b), "bytes as history")
391 }
392 next.history <- &d.history
393 default:
394 // Wait until we have sent the block, so
395 // other decoders can potentially get the decoder.
396 next = nil
397 }
398 }
399
400 // Add checksum, async to decoding.
401 if d.HasCheckSum {
402 n, err := d.crc.Write(r.b)
403 if err != nil {
404 r.err = err
405 if n != len(r.b) {
406 r.err = io.ErrShortWrite
407 }
408 output <- r
409 return
410 }
411 }
412 written += int64(len(r.b))
413 if d.SingleSegment && uint64(written) > d.FrameContentSize {
414 r.err = ErrFrameSizeExceeded
415 output <- r
416 return
417 }
418 if block.Last {
419 r.err = d.checkCRC()
420 output <- r
421 return
422 }
423 output <- r
424 if next == nil {
425 // There was no decoder available, we wait for one now that we have sent to the writer.
426 if debug {
427 println("Sending ", len(d.history.b), " bytes as history")
428 }
429 next = <-d.decoding
430 next.history <- &d.history
431 }
432 block = next
433 }
434}
435
436// runDecoder will create a sync decoder that will decode a block of data.
437func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
438 // TODO: Init to dictionary
439 d.history.reset()
440 saved := d.history.b
441
442 // We use the history for output to avoid copying it.
443 d.history.b = dst
444 // Store input length, so we only check new data.
445 crcStart := len(dst)
446 var err error
447 for {
448 err = dec.reset(d.rawInput, d.WindowSize)
449 if err != nil {
450 break
451 }
452 if debug {
453 println("next block:", dec)
454 }
455 err = dec.decodeBuf(&d.history)
456 if err != nil || dec.Last {
457 break
458 }
459 if uint64(len(d.history.b)) > d.o.maxDecodedSize {
460 err = ErrDecoderSizeExceeded
461 break
462 }
463 if d.SingleSegment && uint64(len(d.history.b)) > d.o.maxDecodedSize {
464 err = ErrFrameSizeExceeded
465 break
466 }
467 }
468 dst = d.history.b
469 if err == nil {
470 if d.HasCheckSum {
471 var n int
472 n, err = d.crc.Write(dst[crcStart:])
473 if err == nil {
474 if n != len(dst)-crcStart {
475 err = io.ErrShortWrite
476 }
477 }
478 err = d.checkCRC()
479 }
480 }
481 d.history.b = saved
482 return dst, err
483}