blob: 1dd39e63b7e83f8e894aaeb90e8e63cfc94936cc [file] [log] [blame]
khenaidoo7d3c5582021-08-11 18:09:44 -04001// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8 "errors"
9 "fmt"
10 "io"
11)
12
// seq describes one zstd sequence: litLen literal bytes followed by a match
// of matchLen bytes at offset. String prints matchLen with zstdMinMatch added
// and treats offsets 1-3 as repeat codes, 0 as invalid, and larger values as
// a new offset biased by 3.
type seq struct {
	litLen, matchLen, offset uint32

	// Codes are stored here for the encoder
	// so they only have to be looked up once.
	llCode, mlCode, ofCode uint8
}
22
23func (s seq) String() string {
24 if s.offset <= 3 {
25 if s.offset == 0 {
26 return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset: INVALID (0)")
27 }
28 return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset:", s.offset, " (repeat)")
29 }
30 return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset:", s.offset-3, " (new)")
31}
32
// seqCompMode is the compression mode used for one of the sequence symbol
// streams (literal lengths, offsets, match lengths).
type seqCompMode uint8

// Sequence compression modes. The numeric values 0-3 match the
// Symbol_Compression_Modes field values of the zstd format (RFC 8878).
const (
	compModePredefined seqCompMode = 0
	compModeRLE        seqCompMode = 1
	compModeFSE        seqCompMode = 2
	compModeRepeat     seqCompMode = 3
)
41
42type sequenceDec struct {
43 // decoder keeps track of the current state and updates it from the bitstream.
44 fse *fseDecoder
45 state fseState
46 repeat bool
47}
48
49// init the state of the decoder with input from stream.
50func (s *sequenceDec) init(br *bitReader) error {
51 if s.fse == nil {
52 return errors.New("sequence decoder not defined")
53 }
54 s.state.init(br, s.fse.actualTableLog, s.fse.dt[:1<<s.fse.actualTableLog])
55 return nil
56}
57
58// sequenceDecs contains all 3 sequence decoders and their state.
59type sequenceDecs struct {
60 litLengths sequenceDec
61 offsets sequenceDec
62 matchLengths sequenceDec
63 prevOffset [3]int
64 hist []byte
65 dict []byte
66 literals []byte
67 out []byte
68 windowSize int
69 maxBits uint8
70}
71
72// initialize all 3 decoders from the stream input.
73func (s *sequenceDecs) initialize(br *bitReader, hist *history, literals, out []byte) error {
74 if err := s.litLengths.init(br); err != nil {
75 return errors.New("litLengths:" + err.Error())
76 }
77 if err := s.offsets.init(br); err != nil {
78 return errors.New("offsets:" + err.Error())
79 }
80 if err := s.matchLengths.init(br); err != nil {
81 return errors.New("matchLengths:" + err.Error())
82 }
83 s.literals = literals
84 s.hist = hist.b
85 s.prevOffset = hist.recentOffsets
86 s.maxBits = s.litLengths.fse.maxBits + s.offsets.fse.maxBits + s.matchLengths.fse.maxBits
87 s.windowSize = hist.windowSize
88 s.out = out
89 s.dict = nil
90 if hist.dict != nil {
91 s.dict = hist.dict.content
92 }
93 return nil
94}
95
96// decode sequences from the stream with the provided history.
97func (s *sequenceDecs) decode(seqs int, br *bitReader, hist []byte) error {
98 startSize := len(s.out)
99 // Grab full sizes tables, to avoid bounds checks.
100 llTable, mlTable, ofTable := s.litLengths.fse.dt[:maxTablesize], s.matchLengths.fse.dt[:maxTablesize], s.offsets.fse.dt[:maxTablesize]
101 llState, mlState, ofState := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state
102
103 for i := seqs - 1; i >= 0; i-- {
104 if br.overread() {
105 printf("reading sequence %d, exceeded available data\n", seqs-i)
106 return io.ErrUnexpectedEOF
107 }
108 var ll, mo, ml int
109 if br.off > 4+((maxOffsetBits+16+16)>>3) {
110 // inlined function:
111 // ll, mo, ml = s.nextFast(br, llState, mlState, ofState)
112
113 // Final will not read from stream.
114 var llB, mlB, moB uint8
115 ll, llB = llState.final()
116 ml, mlB = mlState.final()
117 mo, moB = ofState.final()
118
119 // extra bits are stored in reverse order.
120 br.fillFast()
121 mo += br.getBits(moB)
122 if s.maxBits > 32 {
123 br.fillFast()
124 }
125 ml += br.getBits(mlB)
126 ll += br.getBits(llB)
127
128 if moB > 1 {
129 s.prevOffset[2] = s.prevOffset[1]
130 s.prevOffset[1] = s.prevOffset[0]
131 s.prevOffset[0] = mo
132 } else {
133 // mo = s.adjustOffset(mo, ll, moB)
134 // Inlined for rather big speedup
135 if ll == 0 {
136 // There is an exception though, when current sequence's literals_length = 0.
137 // In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2,
138 // an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte.
139 mo++
140 }
141
142 if mo == 0 {
143 mo = s.prevOffset[0]
144 } else {
145 var temp int
146 if mo == 3 {
147 temp = s.prevOffset[0] - 1
148 } else {
149 temp = s.prevOffset[mo]
150 }
151
152 if temp == 0 {
153 // 0 is not valid; input is corrupted; force offset to 1
154 println("temp was 0")
155 temp = 1
156 }
157
158 if mo != 1 {
159 s.prevOffset[2] = s.prevOffset[1]
160 }
161 s.prevOffset[1] = s.prevOffset[0]
162 s.prevOffset[0] = temp
163 mo = temp
164 }
165 }
166 br.fillFast()
167 } else {
168 ll, mo, ml = s.next(br, llState, mlState, ofState)
169 br.fill()
170 }
171
172 if debugSequences {
173 println("Seq", seqs-i-1, "Litlen:", ll, "mo:", mo, "(abs) ml:", ml)
174 }
175
176 if ll > len(s.literals) {
177 return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", ll, len(s.literals))
178 }
179 size := ll + ml + len(s.out)
180 if size-startSize > maxBlockSize {
181 return fmt.Errorf("output (%d) bigger than max block size", size)
182 }
183 if size > cap(s.out) {
184 // Not enough size, which can happen under high volume block streaming conditions
185 // but could be if destination slice is too small for sync operations.
186 // over-allocating here can create a large amount of GC pressure so we try to keep
187 // it as contained as possible
188 used := len(s.out) - startSize
189 addBytes := 256 + ll + ml + used>>2
190 // Clamp to max block size.
191 if used+addBytes > maxBlockSize {
192 addBytes = maxBlockSize - used
193 }
194 s.out = append(s.out, make([]byte, addBytes)...)
195 s.out = s.out[:len(s.out)-addBytes]
196 }
197 if ml > maxMatchLen {
198 return fmt.Errorf("match len (%d) bigger than max allowed length", ml)
199 }
200
201 // Add literals
202 s.out = append(s.out, s.literals[:ll]...)
203 s.literals = s.literals[ll:]
204 out := s.out
205
206 if mo == 0 && ml > 0 {
207 return fmt.Errorf("zero matchoff and matchlen (%d) > 0", ml)
208 }
209
210 if mo > len(s.out)+len(hist) || mo > s.windowSize {
211 if len(s.dict) == 0 {
212 return fmt.Errorf("match offset (%d) bigger than current history (%d)", mo, len(s.out)+len(hist))
213 }
214
215 // we may be in dictionary.
216 dictO := len(s.dict) - (mo - (len(s.out) + len(hist)))
217 if dictO < 0 || dictO >= len(s.dict) {
218 return fmt.Errorf("match offset (%d) bigger than current history (%d)", mo, len(s.out)+len(hist))
219 }
220 end := dictO + ml
221 if end > len(s.dict) {
222 out = append(out, s.dict[dictO:]...)
223 mo -= len(s.dict) - dictO
224 ml -= len(s.dict) - dictO
225 } else {
226 out = append(out, s.dict[dictO:end]...)
227 mo = 0
228 ml = 0
229 }
230 }
231
232 // Copy from history.
233 // TODO: Blocks without history could be made to ignore this completely.
234 if v := mo - len(s.out); v > 0 {
235 // v is the start position in history from end.
236 start := len(s.hist) - v
237 if ml > v {
238 // Some goes into current block.
239 // Copy remainder of history
240 out = append(out, s.hist[start:]...)
241 mo -= v
242 ml -= v
243 } else {
244 out = append(out, s.hist[start:start+ml]...)
245 ml = 0
246 }
247 }
248 // We must be in current buffer now
249 if ml > 0 {
250 start := len(s.out) - mo
251 if ml <= len(s.out)-start {
252 // No overlap
253 out = append(out, s.out[start:start+ml]...)
254 } else {
255 // Overlapping copy
256 // Extend destination slice and copy one byte at the time.
257 out = out[:len(out)+ml]
258 src := out[start : start+ml]
259 // Destination is the space we just added.
260 dst := out[len(out)-ml:]
261 dst = dst[:len(src)]
262 for i := range src {
263 dst[i] = src[i]
264 }
265 }
266 }
267 s.out = out
268 if i == 0 {
269 // This is the last sequence, so we shouldn't update state.
270 break
271 }
272
273 // Manually inlined, ~ 5-20% faster
274 // Update all 3 states at once. Approx 20% faster.
275 nBits := llState.nbBits() + mlState.nbBits() + ofState.nbBits()
276 if nBits == 0 {
277 llState = llTable[llState.newState()&maxTableMask]
278 mlState = mlTable[mlState.newState()&maxTableMask]
279 ofState = ofTable[ofState.newState()&maxTableMask]
280 } else {
281 bits := br.getBitsFast(nBits)
282 lowBits := uint16(bits >> ((ofState.nbBits() + mlState.nbBits()) & 31))
283 llState = llTable[(llState.newState()+lowBits)&maxTableMask]
284
285 lowBits = uint16(bits >> (ofState.nbBits() & 31))
286 lowBits &= bitMask[mlState.nbBits()&15]
287 mlState = mlTable[(mlState.newState()+lowBits)&maxTableMask]
288
289 lowBits = uint16(bits) & bitMask[ofState.nbBits()&15]
290 ofState = ofTable[(ofState.newState()+lowBits)&maxTableMask]
291 }
292 }
293
294 // Add final literals
295 s.out = append(s.out, s.literals...)
296 return nil
297}
298
299// update states, at least 27 bits must be available.
300func (s *sequenceDecs) update(br *bitReader) {
301 // Max 8 bits
302 s.litLengths.state.next(br)
303 // Max 9 bits
304 s.matchLengths.state.next(br)
305 // Max 8 bits
306 s.offsets.state.next(br)
307}
308
// bitMask[n] has the low n bits set, so bitMask[0] == 0 and
// bitMask[15] == 0x7fff. It isolates one state's share of a combined bit read.
var bitMask = [16]uint16{
	0x0000, 0x0001, 0x0003, 0x0007,
	0x000f, 0x001f, 0x003f, 0x007f,
	0x00ff, 0x01ff, 0x03ff, 0x07ff,
	0x0fff, 0x1fff, 0x3fff, 0x7fff,
}
316
317// update states, at least 27 bits must be available.
318func (s *sequenceDecs) updateAlt(br *bitReader) {
319 // Update all 3 states at once. Approx 20% faster.
320 a, b, c := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state
321
322 nBits := a.nbBits() + b.nbBits() + c.nbBits()
323 if nBits == 0 {
324 s.litLengths.state.state = s.litLengths.state.dt[a.newState()]
325 s.matchLengths.state.state = s.matchLengths.state.dt[b.newState()]
326 s.offsets.state.state = s.offsets.state.dt[c.newState()]
327 return
328 }
329 bits := br.getBitsFast(nBits)
330 lowBits := uint16(bits >> ((c.nbBits() + b.nbBits()) & 31))
331 s.litLengths.state.state = s.litLengths.state.dt[a.newState()+lowBits]
332
333 lowBits = uint16(bits >> (c.nbBits() & 31))
334 lowBits &= bitMask[b.nbBits()&15]
335 s.matchLengths.state.state = s.matchLengths.state.dt[b.newState()+lowBits]
336
337 lowBits = uint16(bits) & bitMask[c.nbBits()&15]
338 s.offsets.state.state = s.offsets.state.dt[c.newState()+lowBits]
339}
340
341// nextFast will return new states when there are at least 4 unused bytes left on the stream when done.
342func (s *sequenceDecs) nextFast(br *bitReader, llState, mlState, ofState decSymbol) (ll, mo, ml int) {
343 // Final will not read from stream.
344 ll, llB := llState.final()
345 ml, mlB := mlState.final()
346 mo, moB := ofState.final()
347
348 // extra bits are stored in reverse order.
349 br.fillFast()
350 mo += br.getBits(moB)
351 if s.maxBits > 32 {
352 br.fillFast()
353 }
354 ml += br.getBits(mlB)
355 ll += br.getBits(llB)
356
357 if moB > 1 {
358 s.prevOffset[2] = s.prevOffset[1]
359 s.prevOffset[1] = s.prevOffset[0]
360 s.prevOffset[0] = mo
361 return
362 }
363 // mo = s.adjustOffset(mo, ll, moB)
364 // Inlined for rather big speedup
365 if ll == 0 {
366 // There is an exception though, when current sequence's literals_length = 0.
367 // In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2,
368 // an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte.
369 mo++
370 }
371
372 if mo == 0 {
373 mo = s.prevOffset[0]
374 return
375 }
376 var temp int
377 if mo == 3 {
378 temp = s.prevOffset[0] - 1
379 } else {
380 temp = s.prevOffset[mo]
381 }
382
383 if temp == 0 {
384 // 0 is not valid; input is corrupted; force offset to 1
385 println("temp was 0")
386 temp = 1
387 }
388
389 if mo != 1 {
390 s.prevOffset[2] = s.prevOffset[1]
391 }
392 s.prevOffset[1] = s.prevOffset[0]
393 s.prevOffset[0] = temp
394 mo = temp
395 return
396}
397
398func (s *sequenceDecs) next(br *bitReader, llState, mlState, ofState decSymbol) (ll, mo, ml int) {
399 // Final will not read from stream.
400 ll, llB := llState.final()
401 ml, mlB := mlState.final()
402 mo, moB := ofState.final()
403
404 // extra bits are stored in reverse order.
405 br.fill()
406 if s.maxBits <= 32 {
407 mo += br.getBits(moB)
408 ml += br.getBits(mlB)
409 ll += br.getBits(llB)
410 } else {
411 mo += br.getBits(moB)
412 br.fill()
413 // matchlength+literal length, max 32 bits
414 ml += br.getBits(mlB)
415 ll += br.getBits(llB)
416
417 }
418 mo = s.adjustOffset(mo, ll, moB)
419 return
420}
421
422func (s *sequenceDecs) adjustOffset(offset, litLen int, offsetB uint8) int {
423 if offsetB > 1 {
424 s.prevOffset[2] = s.prevOffset[1]
425 s.prevOffset[1] = s.prevOffset[0]
426 s.prevOffset[0] = offset
427 return offset
428 }
429
430 if litLen == 0 {
431 // There is an exception though, when current sequence's literals_length = 0.
432 // In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2,
433 // an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte.
434 offset++
435 }
436
437 if offset == 0 {
438 return s.prevOffset[0]
439 }
440 var temp int
441 if offset == 3 {
442 temp = s.prevOffset[0] - 1
443 } else {
444 temp = s.prevOffset[offset]
445 }
446
447 if temp == 0 {
448 // 0 is not valid; input is corrupted; force offset to 1
449 println("temp was 0")
450 temp = 1
451 }
452
453 if offset != 1 {
454 s.prevOffset[2] = s.prevOffset[1]
455 }
456 s.prevOffset[1] = s.prevOffset[0]
457 s.prevOffset[0] = temp
458 return temp
459}
460
461// mergeHistory will merge history.
462func (s *sequenceDecs) mergeHistory(hist *sequenceDecs) (*sequenceDecs, error) {
463 for i := uint(0); i < 3; i++ {
464 var sNew, sHist *sequenceDec
465 switch i {
466 default:
467 // same as "case 0":
468 sNew = &s.litLengths
469 sHist = &hist.litLengths
470 case 1:
471 sNew = &s.offsets
472 sHist = &hist.offsets
473 case 2:
474 sNew = &s.matchLengths
475 sHist = &hist.matchLengths
476 }
477 if sNew.repeat {
478 if sHist.fse == nil {
479 return nil, fmt.Errorf("sequence stream %d, repeat requested, but no history", i)
480 }
481 continue
482 }
483 if sNew.fse == nil {
484 return nil, fmt.Errorf("sequence stream %d, no fse found", i)
485 }
486 if sHist.fse != nil && !sHist.fse.preDefined {
487 fseDecoderPool.Put(sHist.fse)
488 }
489 sHist.fse = sNew.fse
490 }
491 return hist, nil
492}