blob: 97ae66a4ac7cc34098a4af80b26bdb421d1f6767 [file] [log] [blame]
Scott Bakered4efab2020-01-13 19:12:25 -08001package huff0
2
3import (
4 "errors"
5 "fmt"
6 "io"
7
8 "github.com/klauspost/compress/fse"
9)
10
// Decoding table types.
type (
	// dTable bundles the lookup tables used when decoding.
	dTable struct {
		// single is the one-symbol-per-lookup table.
		single []dEntrySingle
		// double is the two-symbols-per-lookup table.
		// NOTE(review): nothing in this file populates it — confirm usage.
		double []dEntryDouble
	}

	// dEntrySingle is one slot of the single-symbol decoding table.
	// The low byte of entry holds the number of bits the code consumes
	// and the high byte holds the decoded symbol.
	dEntrySingle struct {
		entry uint16
	}

	// dEntryDouble is one slot of the double-symbol decoding table.
	// NOTE(review): presumably seq packs the decoded symbols, nBits the bits
	// consumed and len the number of symbols — confirm against the code that
	// fills it.
	dEntryDouble struct {
		seq   uint16
		nBits uint8
		len   uint8
	}
)
27
28// ReadTable will read a table from the input.
29// The size of the input may be larger than the table definition.
30// Any content remaining after the table definition will be returned.
31// If no Scratch is provided a new one is allocated.
32// The returned Scratch can be used for decoding input using this table.
33func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) {
34 s, err = s.prepare(in)
35 if err != nil {
36 return s, nil, err
37 }
38 if len(in) <= 1 {
39 return s, nil, errors.New("input too small for table")
40 }
41 iSize := in[0]
42 in = in[1:]
43 if iSize >= 128 {
44 // Uncompressed
45 oSize := iSize - 127
46 iSize = (oSize + 1) / 2
47 if int(iSize) > len(in) {
48 return s, nil, errors.New("input too small for table")
49 }
50 for n := uint8(0); n < oSize; n += 2 {
51 v := in[n/2]
52 s.huffWeight[n] = v >> 4
53 s.huffWeight[n+1] = v & 15
54 }
55 s.symbolLen = uint16(oSize)
56 in = in[iSize:]
57 } else {
58 if len(in) <= int(iSize) {
59 return s, nil, errors.New("input too small for table")
60 }
61 // FSE compressed weights
62 s.fse.DecompressLimit = 255
63 hw := s.huffWeight[:]
64 s.fse.Out = hw
65 b, err := fse.Decompress(in[:iSize], s.fse)
66 s.fse.Out = nil
67 if err != nil {
68 return s, nil, err
69 }
70 if len(b) > 255 {
71 return s, nil, errors.New("corrupt input: output table too large")
72 }
73 s.symbolLen = uint16(len(b))
74 in = in[iSize:]
75 }
76
77 // collect weight stats
78 var rankStats [16]uint32
79 weightTotal := uint32(0)
80 for _, v := range s.huffWeight[:s.symbolLen] {
81 if v > tableLogMax {
82 return s, nil, errors.New("corrupt input: weight too large")
83 }
84 v2 := v & 15
85 rankStats[v2]++
86 weightTotal += (1 << v2) >> 1
87 }
88 if weightTotal == 0 {
89 return s, nil, errors.New("corrupt input: weights zero")
90 }
91
92 // get last non-null symbol weight (implied, total must be 2^n)
93 {
94 tableLog := highBit32(weightTotal) + 1
95 if tableLog > tableLogMax {
96 return s, nil, errors.New("corrupt input: tableLog too big")
97 }
98 s.actualTableLog = uint8(tableLog)
99 // determine last weight
100 {
101 total := uint32(1) << tableLog
102 rest := total - weightTotal
103 verif := uint32(1) << highBit32(rest)
104 lastWeight := highBit32(rest) + 1
105 if verif != rest {
106 // last value must be a clean power of 2
107 return s, nil, errors.New("corrupt input: last value not power of two")
108 }
109 s.huffWeight[s.symbolLen] = uint8(lastWeight)
110 s.symbolLen++
111 rankStats[lastWeight]++
112 }
113 }
114
115 if (rankStats[1] < 2) || (rankStats[1]&1 != 0) {
116 // by construction : at least 2 elts of rank 1, must be even
117 return s, nil, errors.New("corrupt input: min elt size, even check failed ")
118 }
119
120 // TODO: Choose between single/double symbol decoding
121
122 // Calculate starting value for each rank
123 {
124 var nextRankStart uint32
125 for n := uint8(1); n < s.actualTableLog+1; n++ {
126 current := nextRankStart
127 nextRankStart += rankStats[n] << (n - 1)
128 rankStats[n] = current
129 }
130 }
131
132 // fill DTable (always full size)
133 tSize := 1 << tableLogMax
134 if len(s.dt.single) != tSize {
135 s.dt.single = make([]dEntrySingle, tSize)
136 }
137 for n, w := range s.huffWeight[:s.symbolLen] {
138 if w == 0 {
139 continue
140 }
141 length := (uint32(1) << w) >> 1
142 d := dEntrySingle{
143 entry: uint16(s.actualTableLog+1-w) | (uint16(n) << 8),
144 }
145 single := s.dt.single[rankStats[w] : rankStats[w]+length]
146 for i := range single {
147 single[i] = d
148 }
149 rankStats[w] += length
150 }
151 return s, in, nil
152}
153
154// Decompress1X will decompress a 1X encoded stream.
155// The length of the supplied input must match the end of a block exactly.
156// Before this is called, the table must be initialized with ReadTable unless
157// the encoder re-used the table.
158func (s *Scratch) Decompress1X(in []byte) (out []byte, err error) {
159 if len(s.dt.single) == 0 {
160 return nil, errors.New("no table loaded")
161 }
162 var br bitReader
163 err = br.init(in)
164 if err != nil {
165 return nil, err
166 }
167 s.Out = s.Out[:0]
168
169 decode := func() byte {
170 val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */
171 v := s.dt.single[val]
172 br.bitsRead += uint8(v.entry)
173 return uint8(v.entry >> 8)
174 }
175 hasDec := func(v dEntrySingle) byte {
176 br.bitsRead += uint8(v.entry)
177 return uint8(v.entry >> 8)
178 }
179
180 // Avoid bounds check by always having full sized table.
181 const tlSize = 1 << tableLogMax
182 const tlMask = tlSize - 1
183 dt := s.dt.single[:tlSize]
184
185 // Use temp table to avoid bound checks/append penalty.
186 var tmp = s.huffWeight[:256]
187 var off uint8
188
189 for br.off >= 8 {
190 br.fillFast()
191 tmp[off+0] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask])
192 tmp[off+1] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask])
193 br.fillFast()
194 tmp[off+2] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask])
195 tmp[off+3] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask])
196 off += 4
197 if off == 0 {
198 if len(s.Out)+256 > s.MaxDecodedSize {
199 br.close()
200 return nil, ErrMaxDecodedSizeExceeded
201 }
202 s.Out = append(s.Out, tmp...)
203 }
204 }
205
206 if len(s.Out)+int(off) > s.MaxDecodedSize {
207 br.close()
208 return nil, ErrMaxDecodedSizeExceeded
209 }
210 s.Out = append(s.Out, tmp[:off]...)
211
212 for !br.finished() {
213 br.fill()
214 if len(s.Out) >= s.MaxDecodedSize {
215 br.close()
216 return nil, ErrMaxDecodedSizeExceeded
217 }
218 s.Out = append(s.Out, decode())
219 }
220 return s.Out, br.close()
221}
222
223// Decompress4X will decompress a 4X encoded stream.
224// Before this is called, the table must be initialized with ReadTable unless
225// the encoder re-used the table.
226// The length of the supplied input must match the end of a block exactly.
227// The destination size of the uncompressed data must be known and provided.
228func (s *Scratch) Decompress4X(in []byte, dstSize int) (out []byte, err error) {
229 if len(s.dt.single) == 0 {
230 return nil, errors.New("no table loaded")
231 }
232 if len(in) < 6+(4*1) {
233 return nil, errors.New("input too small")
234 }
235 if dstSize > s.MaxDecodedSize {
236 return nil, ErrMaxDecodedSizeExceeded
237 }
238 // TODO: We do not detect when we overrun a buffer, except if the last one does.
239
240 var br [4]bitReader
241 start := 6
242 for i := 0; i < 3; i++ {
243 length := int(in[i*2]) | (int(in[i*2+1]) << 8)
244 if start+length >= len(in) {
245 return nil, errors.New("truncated input (or invalid offset)")
246 }
247 err = br[i].init(in[start : start+length])
248 if err != nil {
249 return nil, err
250 }
251 start += length
252 }
253 err = br[3].init(in[start:])
254 if err != nil {
255 return nil, err
256 }
257
258 // Prepare output
259 if cap(s.Out) < dstSize {
260 s.Out = make([]byte, 0, dstSize)
261 }
262 s.Out = s.Out[:dstSize]
263 // destination, offset to match first output
264 dstOut := s.Out
265 dstEvery := (dstSize + 3) / 4
266
267 const tlSize = 1 << tableLogMax
268 const tlMask = tlSize - 1
269 single := s.dt.single[:tlSize]
270
271 decode := func(br *bitReader) byte {
272 val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */
273 v := single[val&tlMask]
274 br.bitsRead += uint8(v.entry)
275 return uint8(v.entry >> 8)
276 }
277
278 // Use temp table to avoid bound checks/append penalty.
279 var tmp = s.huffWeight[:256]
280 var off uint8
281 var decoded int
282
283 // Decode 2 values from each decoder/loop.
284 const bufoff = 256 / 4
285bigloop:
286 for {
287 for i := range br {
288 br := &br[i]
289 if br.off < 4 {
290 break bigloop
291 }
292 br.fillFast()
293 }
294
295 {
296 const stream = 0
297 val := br[stream].peekBitsFast(s.actualTableLog)
298 v := single[val&tlMask]
299 br[stream].bitsRead += uint8(v.entry)
300
301 val2 := br[stream].peekBitsFast(s.actualTableLog)
302 v2 := single[val2&tlMask]
303 tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
304 tmp[off+bufoff*stream] = uint8(v.entry >> 8)
305 br[stream].bitsRead += uint8(v2.entry)
306 }
307
308 {
309 const stream = 1
310 val := br[stream].peekBitsFast(s.actualTableLog)
311 v := single[val&tlMask]
312 br[stream].bitsRead += uint8(v.entry)
313
314 val2 := br[stream].peekBitsFast(s.actualTableLog)
315 v2 := single[val2&tlMask]
316 tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
317 tmp[off+bufoff*stream] = uint8(v.entry >> 8)
318 br[stream].bitsRead += uint8(v2.entry)
319 }
320
321 {
322 const stream = 2
323 val := br[stream].peekBitsFast(s.actualTableLog)
324 v := single[val&tlMask]
325 br[stream].bitsRead += uint8(v.entry)
326
327 val2 := br[stream].peekBitsFast(s.actualTableLog)
328 v2 := single[val2&tlMask]
329 tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
330 tmp[off+bufoff*stream] = uint8(v.entry >> 8)
331 br[stream].bitsRead += uint8(v2.entry)
332 }
333
334 {
335 const stream = 3
336 val := br[stream].peekBitsFast(s.actualTableLog)
337 v := single[val&tlMask]
338 br[stream].bitsRead += uint8(v.entry)
339
340 val2 := br[stream].peekBitsFast(s.actualTableLog)
341 v2 := single[val2&tlMask]
342 tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
343 tmp[off+bufoff*stream] = uint8(v.entry >> 8)
344 br[stream].bitsRead += uint8(v2.entry)
345 }
346
347 off += 2
348
349 if off == bufoff {
350 if bufoff > dstEvery {
351 return nil, errors.New("corruption detected: stream overrun 1")
352 }
353 copy(dstOut, tmp[:bufoff])
354 copy(dstOut[dstEvery:], tmp[bufoff:bufoff*2])
355 copy(dstOut[dstEvery*2:], tmp[bufoff*2:bufoff*3])
356 copy(dstOut[dstEvery*3:], tmp[bufoff*3:bufoff*4])
357 off = 0
358 dstOut = dstOut[bufoff:]
359 decoded += 256
360 // There must at least be 3 buffers left.
361 if len(dstOut) < dstEvery*3 {
362 return nil, errors.New("corruption detected: stream overrun 2")
363 }
364 }
365 }
366 if off > 0 {
367 ioff := int(off)
368 if len(dstOut) < dstEvery*3+ioff {
369 return nil, errors.New("corruption detected: stream overrun 3")
370 }
371 copy(dstOut, tmp[:off])
372 copy(dstOut[dstEvery:dstEvery+ioff], tmp[bufoff:bufoff*2])
373 copy(dstOut[dstEvery*2:dstEvery*2+ioff], tmp[bufoff*2:bufoff*3])
374 copy(dstOut[dstEvery*3:dstEvery*3+ioff], tmp[bufoff*3:bufoff*4])
375 decoded += int(off) * 4
376 dstOut = dstOut[off:]
377 }
378
379 // Decode remaining.
380 for i := range br {
381 offset := dstEvery * i
382 br := &br[i]
383 for !br.finished() {
384 br.fill()
385 if offset >= len(dstOut) {
386 return nil, errors.New("corruption detected: stream overrun 4")
387 }
388 dstOut[offset] = decode(br)
389 offset++
390 }
391 decoded += offset - dstEvery*i
392 err = br.close()
393 if err != nil {
394 return nil, err
395 }
396 }
397 if dstSize != decoded {
398 return nil, errors.New("corruption detected: short output block")
399 }
400 return s.Out, nil
401}
402
403// matches will compare a decoding table to a coding table.
404// Errors are written to the writer.
405// Nothing will be written if table is ok.
406func (s *Scratch) matches(ct cTable, w io.Writer) {
407 if s == nil || len(s.dt.single) == 0 {
408 return
409 }
410 dt := s.dt.single[:1<<s.actualTableLog]
411 tablelog := s.actualTableLog
412 ok := 0
413 broken := 0
414 for sym, enc := range ct {
415 errs := 0
416 broken++
417 if enc.nBits == 0 {
418 for _, dec := range dt {
419 if uint8(dec.entry>>8) == byte(sym) {
420 fmt.Fprintf(w, "symbol %x has decoder, but no encoder\n", sym)
421 errs++
422 break
423 }
424 }
425 if errs == 0 {
426 broken--
427 }
428 continue
429 }
430 // Unused bits in input
431 ub := tablelog - enc.nBits
432 top := enc.val << ub
433 // decoder looks at top bits.
434 dec := dt[top]
435 if uint8(dec.entry) != enc.nBits {
436 fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", sym, enc.nBits, uint8(dec.entry))
437 errs++
438 }
439 if uint8(dec.entry>>8) != uint8(sym) {
440 fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", sym, sym, uint8(dec.entry>>8))
441 errs++
442 }
443 if errs > 0 {
444 fmt.Fprintf(w, "%d errros in base, stopping\n", errs)
445 continue
446 }
447 // Ensure that all combinations are covered.
448 for i := uint16(0); i < (1 << ub); i++ {
449 vval := top | i
450 dec := dt[vval]
451 if uint8(dec.entry) != enc.nBits {
452 fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", vval, enc.nBits, uint8(dec.entry))
453 errs++
454 }
455 if uint8(dec.entry>>8) != uint8(sym) {
456 fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", vval, sym, uint8(dec.entry>>8))
457 errs++
458 }
459 if errs > 20 {
460 fmt.Fprintf(w, "%d errros, stopping\n", errs)
461 break
462 }
463 }
464 if errs == 0 {
465 ok++
466 broken--
467 }
468 }
469 if broken > 0 {
470 fmt.Fprintf(w, "%d broken, %d ok\n", broken, ok)
471 }
472}