blob: 15a45f7b5012ab8d9c3e5949cf925541ab6fda8e [file] [log] [blame]
// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8 "errors"
9 "fmt"
10 "io"
11)
12
// seq describes a single sequence: a run of literals followed by a match.
//
// As shown by String, matchLen is stored without the minimum match length
// added, and offset uses 1-3 for repeat offsets, values above 3 for new
// offsets (stored with 3 added), and 0 as an invalid marker.
type seq struct {
	// litLen is the number of literal bytes in the sequence.
	litLen uint32
	// matchLen is the match length, excluding zstdMinMatch.
	matchLen uint32
	// offset is the encoded match offset (see the type comment).
	offset uint32

	// The codes below are cached for the encoder,
	// so each one only needs to be looked up once.

	// llCode is the literal length code.
	llCode uint8
	// mlCode is the match length code.
	mlCode uint8
	// ofCode is the offset code.
	ofCode uint8
}
22
23func (s seq) String() string {
24 if s.offset <= 3 {
25 if s.offset == 0 {
26 return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset: INVALID (0)")
27 }
28 return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset:", s.offset, " (repeat)")
29 }
30 return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset:", s.offset-3, " (new)")
31}
32
// seqCompMode identifies how the table for a sequence stream is provided.
type seqCompMode uint8

// The compression modes, numbered consecutively from 0.
// NOTE(review): the names match the Symbol_Compression_Modes of the
// Zstandard format; confirm against the block header parser.
const (
	// compModePredefined selects the predefined table.
	compModePredefined seqCompMode = 0
	// compModeRLE selects a single repeated symbol.
	compModeRLE seqCompMode = 1
	// compModeFSE selects an FSE table described in the stream.
	compModeFSE seqCompMode = 2
	// compModeRepeat selects the table used previously.
	compModeRepeat seqCompMode = 3
)
41
42type sequenceDec struct {
43 // decoder keeps track of the current state and updates it from the bitstream.
44 fse *fseDecoder
45 state fseState
46 repeat bool
47}
48
49// init the state of the decoder with input from stream.
50func (s *sequenceDec) init(br *bitReader) error {
51 if s.fse == nil {
52 return errors.New("sequence decoder not defined")
53 }
54 s.state.init(br, s.fse.actualTableLog, s.fse.dt[:1<<s.fse.actualTableLog])
55 return nil
56}
57
58// sequenceDecs contains all 3 sequence decoders and their state.
59type sequenceDecs struct {
60 litLengths sequenceDec
61 offsets sequenceDec
62 matchLengths sequenceDec
63 prevOffset [3]int
64 hist []byte
65 literals []byte
66 out []byte
67 maxBits uint8
68}
69
70// initialize all 3 decoders from the stream input.
71func (s *sequenceDecs) initialize(br *bitReader, hist *history, literals, out []byte) error {
72 if err := s.litLengths.init(br); err != nil {
73 return errors.New("litLengths:" + err.Error())
74 }
75 if err := s.offsets.init(br); err != nil {
76 return errors.New("offsets:" + err.Error())
77 }
78 if err := s.matchLengths.init(br); err != nil {
79 return errors.New("matchLengths:" + err.Error())
80 }
81 s.literals = literals
82 s.hist = hist.b
83 s.prevOffset = hist.recentOffsets
84 s.maxBits = s.litLengths.fse.maxBits + s.offsets.fse.maxBits + s.matchLengths.fse.maxBits
85 s.out = out
86 return nil
87}
88
89// decode sequences from the stream with the provided history.
90func (s *sequenceDecs) decode(seqs int, br *bitReader, hist []byte) error {
91 startSize := len(s.out)
92 // Grab full sizes tables, to avoid bounds checks.
93 llTable, mlTable, ofTable := s.litLengths.fse.dt[:maxTablesize], s.matchLengths.fse.dt[:maxTablesize], s.offsets.fse.dt[:maxTablesize]
94 llState, mlState, ofState := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state
95
96 for i := seqs - 1; i >= 0; i-- {
97 if br.overread() {
98 printf("reading sequence %d, exceeded available data\n", seqs-i)
99 return io.ErrUnexpectedEOF
100 }
101 var litLen, matchOff, matchLen int
102 if br.off > 4+((maxOffsetBits+16+16)>>3) {
103 litLen, matchOff, matchLen = s.nextFast(br, llState, mlState, ofState)
104 br.fillFast()
105 } else {
106 litLen, matchOff, matchLen = s.next(br, llState, mlState, ofState)
107 br.fill()
108 }
109
110 if debugSequences {
111 println("Seq", seqs-i-1, "Litlen:", litLen, "matchOff:", matchOff, "(abs) matchLen:", matchLen)
112 }
113
114 if litLen > len(s.literals) {
115 return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", litLen, len(s.literals))
116 }
117 size := litLen + matchLen + len(s.out)
118 if size-startSize > maxBlockSize {
119 return fmt.Errorf("output (%d) bigger than max block size", size)
120 }
121 if size > cap(s.out) {
122 // Not enough size, will be extremely rarely triggered,
123 // but could be if destination slice is too small for sync operations.
124 // We add maxBlockSize to the capacity.
125 s.out = append(s.out, make([]byte, maxBlockSize)...)
126 s.out = s.out[:len(s.out)-maxBlockSize]
127 }
128 if matchLen > maxMatchLen {
129 return fmt.Errorf("match len (%d) bigger than max allowed length", matchLen)
130 }
131 if matchOff > len(s.out)+len(hist)+litLen {
132 return fmt.Errorf("match offset (%d) bigger than current history (%d)", matchOff, len(s.out)+len(hist)+litLen)
133 }
134 if matchOff == 0 && matchLen > 0 {
135 return fmt.Errorf("zero matchoff and matchlen > 0")
136 }
137
138 s.out = append(s.out, s.literals[:litLen]...)
139 s.literals = s.literals[litLen:]
140 out := s.out
141
142 // Copy from history.
143 // TODO: Blocks without history could be made to ignore this completely.
144 if v := matchOff - len(s.out); v > 0 {
145 // v is the start position in history from end.
146 start := len(s.hist) - v
147 if matchLen > v {
148 // Some goes into current block.
149 // Copy remainder of history
150 out = append(out, s.hist[start:]...)
151 matchOff -= v
152 matchLen -= v
153 } else {
154 out = append(out, s.hist[start:start+matchLen]...)
155 matchLen = 0
156 }
157 }
158 // We must be in current buffer now
159 if matchLen > 0 {
160 start := len(s.out) - matchOff
161 if matchLen <= len(s.out)-start {
162 // No overlap
163 out = append(out, s.out[start:start+matchLen]...)
164 } else {
165 // Overlapping copy
166 // Extend destination slice and copy one byte at the time.
167 out = out[:len(out)+matchLen]
168 src := out[start : start+matchLen]
169 // Destination is the space we just added.
170 dst := out[len(out)-matchLen:]
171 dst = dst[:len(src)]
172 for i := range src {
173 dst[i] = src[i]
174 }
175 }
176 }
177 s.out = out
178 if i == 0 {
179 // This is the last sequence, so we shouldn't update state.
180 break
181 }
182
183 // Manually inlined, ~ 5-20% faster
184 // Update all 3 states at once. Approx 20% faster.
185 nBits := llState.nbBits() + mlState.nbBits() + ofState.nbBits()
186 if nBits == 0 {
187 llState = llTable[llState.newState()&maxTableMask]
188 mlState = mlTable[mlState.newState()&maxTableMask]
189 ofState = ofTable[ofState.newState()&maxTableMask]
190 } else {
191 bits := br.getBitsFast(nBits)
192 lowBits := uint16(bits >> ((ofState.nbBits() + mlState.nbBits()) & 31))
193 llState = llTable[(llState.newState()+lowBits)&maxTableMask]
194
195 lowBits = uint16(bits >> (ofState.nbBits() & 31))
196 lowBits &= bitMask[mlState.nbBits()&15]
197 mlState = mlTable[(mlState.newState()+lowBits)&maxTableMask]
198
199 lowBits = uint16(bits) & bitMask[ofState.nbBits()&15]
200 ofState = ofTable[(ofState.newState()+lowBits)&maxTableMask]
201 }
202 }
203
204 // Add final literals
205 s.out = append(s.out, s.literals...)
206 return nil
207}
208
209// update states, at least 27 bits must be available.
210func (s *sequenceDecs) update(br *bitReader) {
211 // Max 8 bits
212 s.litLengths.state.next(br)
213 // Max 9 bits
214 s.matchLengths.state.next(br)
215 // Max 8 bits
216 s.offsets.state.next(br)
217}
218
// bitMask contains masks for keeping the lowest n bits of a value;
// bitMask[n] has the n lowest bits set, for n in 0-15.
var bitMask [16]uint16

// init fills bitMask.
func init() {
	for n := 0; n < len(bitMask); n++ {
		bitMask[n] = uint16(1)<<uint(n) - 1
	}
}
226
227// update states, at least 27 bits must be available.
228func (s *sequenceDecs) updateAlt(br *bitReader) {
229 // Update all 3 states at once. Approx 20% faster.
230 a, b, c := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state
231
232 nBits := a.nbBits() + b.nbBits() + c.nbBits()
233 if nBits == 0 {
234 s.litLengths.state.state = s.litLengths.state.dt[a.newState()]
235 s.matchLengths.state.state = s.matchLengths.state.dt[b.newState()]
236 s.offsets.state.state = s.offsets.state.dt[c.newState()]
237 return
238 }
239 bits := br.getBitsFast(nBits)
240 lowBits := uint16(bits >> ((c.nbBits() + b.nbBits()) & 31))
241 s.litLengths.state.state = s.litLengths.state.dt[a.newState()+lowBits]
242
243 lowBits = uint16(bits >> (c.nbBits() & 31))
244 lowBits &= bitMask[b.nbBits()&15]
245 s.matchLengths.state.state = s.matchLengths.state.dt[b.newState()+lowBits]
246
247 lowBits = uint16(bits) & bitMask[c.nbBits()&15]
248 s.offsets.state.state = s.offsets.state.dt[c.newState()+lowBits]
249}
250
251// nextFast will return new states when there are at least 4 unused bytes left on the stream when done.
252func (s *sequenceDecs) nextFast(br *bitReader, llState, mlState, ofState decSymbol) (ll, mo, ml int) {
253 // Final will not read from stream.
254 ll, llB := llState.final()
255 ml, mlB := mlState.final()
256 mo, moB := ofState.final()
257
258 // extra bits are stored in reverse order.
259 br.fillFast()
260 mo += br.getBits(moB)
261 if s.maxBits > 32 {
262 br.fillFast()
263 }
264 ml += br.getBits(mlB)
265 ll += br.getBits(llB)
266
267 if moB > 1 {
268 s.prevOffset[2] = s.prevOffset[1]
269 s.prevOffset[1] = s.prevOffset[0]
270 s.prevOffset[0] = mo
271 return
272 }
273 // mo = s.adjustOffset(mo, ll, moB)
274 // Inlined for rather big speedup
275 if ll == 0 {
276 // There is an exception though, when current sequence's literals_length = 0.
277 // In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2,
278 // an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte.
279 mo++
280 }
281
282 if mo == 0 {
283 mo = s.prevOffset[0]
284 return
285 }
286 var temp int
287 if mo == 3 {
288 temp = s.prevOffset[0] - 1
289 } else {
290 temp = s.prevOffset[mo]
291 }
292
293 if temp == 0 {
294 // 0 is not valid; input is corrupted; force offset to 1
295 println("temp was 0")
296 temp = 1
297 }
298
299 if mo != 1 {
300 s.prevOffset[2] = s.prevOffset[1]
301 }
302 s.prevOffset[1] = s.prevOffset[0]
303 s.prevOffset[0] = temp
304 mo = temp
305 return
306}
307
308func (s *sequenceDecs) next(br *bitReader, llState, mlState, ofState decSymbol) (ll, mo, ml int) {
309 // Final will not read from stream.
310 ll, llB := llState.final()
311 ml, mlB := mlState.final()
312 mo, moB := ofState.final()
313
314 // extra bits are stored in reverse order.
315 br.fill()
316 if s.maxBits <= 32 {
317 mo += br.getBits(moB)
318 ml += br.getBits(mlB)
319 ll += br.getBits(llB)
320 } else {
321 mo += br.getBits(moB)
322 br.fill()
323 // matchlength+literal length, max 32 bits
324 ml += br.getBits(mlB)
325 ll += br.getBits(llB)
326
327 }
328 mo = s.adjustOffset(mo, ll, moB)
329 return
330}
331
332func (s *sequenceDecs) adjustOffset(offset, litLen int, offsetB uint8) int {
333 if offsetB > 1 {
334 s.prevOffset[2] = s.prevOffset[1]
335 s.prevOffset[1] = s.prevOffset[0]
336 s.prevOffset[0] = offset
337 return offset
338 }
339
340 if litLen == 0 {
341 // There is an exception though, when current sequence's literals_length = 0.
342 // In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2,
343 // an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte.
344 offset++
345 }
346
347 if offset == 0 {
348 return s.prevOffset[0]
349 }
350 var temp int
351 if offset == 3 {
352 temp = s.prevOffset[0] - 1
353 } else {
354 temp = s.prevOffset[offset]
355 }
356
357 if temp == 0 {
358 // 0 is not valid; input is corrupted; force offset to 1
359 println("temp was 0")
360 temp = 1
361 }
362
363 if offset != 1 {
364 s.prevOffset[2] = s.prevOffset[1]
365 }
366 s.prevOffset[1] = s.prevOffset[0]
367 s.prevOffset[0] = temp
368 return temp
369}
370
371// mergeHistory will merge history.
372func (s *sequenceDecs) mergeHistory(hist *sequenceDecs) (*sequenceDecs, error) {
373 for i := uint(0); i < 3; i++ {
374 var sNew, sHist *sequenceDec
375 switch i {
376 default:
377 // same as "case 0":
378 sNew = &s.litLengths
379 sHist = &hist.litLengths
380 case 1:
381 sNew = &s.offsets
382 sHist = &hist.offsets
383 case 2:
384 sNew = &s.matchLengths
385 sHist = &hist.matchLengths
386 }
387 if sNew.repeat {
388 if sHist.fse == nil {
389 return nil, fmt.Errorf("sequence stream %d, repeat requested, but no history", i)
390 }
391 continue
392 }
393 if sNew.fse == nil {
394 return nil, fmt.Errorf("sequence stream %d, no fse found", i)
395 }
396 if sHist.fse != nil && !sHist.fse.preDefined {
397 fseDecoderPool.Put(sHist.fse)
398 }
399 sHist.fse = sNew.fse
400 }
401 return hist, nil
402}