blob: 43b4815b3792da8f6fc46f366bd2ac0589548229 [file] [log] [blame]
Dinesh Belwalkare63f7f92019-11-22 23:11:16 +00001package huff0
2
3import (
4 "errors"
5 "fmt"
6 "io"
7
8 "github.com/klauspost/compress/fse"
9)
10
// dTable holds the Huffman decoding tables kept on a Scratch.
type dTable struct {
	// single is the single-symbol lookup table. ReadTable allocates it with
	// 1<<tableLogMax entries and the decoders index it with the bits peeked
	// from the input stream.
	single []dEntrySingle
	// double is reserved for double-symbol decoding.
	// NOTE(review): nothing in the visible code fills or reads it — confirm.
	double []dEntryDouble
}
15
// dEntrySingle is one slot of the single-symbol decoding table.
//
// byte is the symbol emitted when the slot is hit, and nBits is the amount
// the decoder advances its bit position after emitting it.
type dEntrySingle struct {
	byte, nBits uint8
}
21
// dEntryDouble is one slot of the double-symbol decoding table.
//
// NOTE(review): the visible code never populates this type; presumably seq
// packs up to two decoded symbols, nBits is the bit count to consume and len
// is the number of symbols held in seq — confirm before relying on it.
type dEntryDouble struct {
	seq        uint16
	nBits, len uint8
}
28
29// ReadTable will read a table from the input.
30// The size of the input may be larger than the table definition.
31// Any content remaining after the table definition will be returned.
32// If no Scratch is provided a new one is allocated.
33// The returned Scratch can be used for decoding input using this table.
34func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) {
35 s, err = s.prepare(in)
36 if err != nil {
37 return s, nil, err
38 }
39 if len(in) <= 1 {
40 return s, nil, errors.New("input too small for table")
41 }
42 iSize := in[0]
43 in = in[1:]
44 if iSize >= 128 {
45 // Uncompressed
46 oSize := iSize - 127
47 iSize = (oSize + 1) / 2
48 if int(iSize) > len(in) {
49 return s, nil, errors.New("input too small for table")
50 }
51 for n := uint8(0); n < oSize; n += 2 {
52 v := in[n/2]
53 s.huffWeight[n] = v >> 4
54 s.huffWeight[n+1] = v & 15
55 }
56 s.symbolLen = uint16(oSize)
57 in = in[iSize:]
58 } else {
59 if len(in) <= int(iSize) {
60 return s, nil, errors.New("input too small for table")
61 }
62 // FSE compressed weights
63 s.fse.DecompressLimit = 255
64 hw := s.huffWeight[:]
65 s.fse.Out = hw
66 b, err := fse.Decompress(in[:iSize], s.fse)
67 s.fse.Out = nil
68 if err != nil {
69 return s, nil, err
70 }
71 if len(b) > 255 {
72 return s, nil, errors.New("corrupt input: output table too large")
73 }
74 s.symbolLen = uint16(len(b))
75 in = in[iSize:]
76 }
77
78 // collect weight stats
79 var rankStats [tableLogMax + 1]uint32
80 weightTotal := uint32(0)
81 for _, v := range s.huffWeight[:s.symbolLen] {
82 if v > tableLogMax {
83 return s, nil, errors.New("corrupt input: weight too large")
84 }
85 rankStats[v]++
86 weightTotal += (1 << (v & 15)) >> 1
87 }
88 if weightTotal == 0 {
89 return s, nil, errors.New("corrupt input: weights zero")
90 }
91
92 // get last non-null symbol weight (implied, total must be 2^n)
93 {
94 tableLog := highBit32(weightTotal) + 1
95 if tableLog > tableLogMax {
96 return s, nil, errors.New("corrupt input: tableLog too big")
97 }
98 s.actualTableLog = uint8(tableLog)
99 // determine last weight
100 {
101 total := uint32(1) << tableLog
102 rest := total - weightTotal
103 verif := uint32(1) << highBit32(rest)
104 lastWeight := highBit32(rest) + 1
105 if verif != rest {
106 // last value must be a clean power of 2
107 return s, nil, errors.New("corrupt input: last value not power of two")
108 }
109 s.huffWeight[s.symbolLen] = uint8(lastWeight)
110 s.symbolLen++
111 rankStats[lastWeight]++
112 }
113 }
114
115 if (rankStats[1] < 2) || (rankStats[1]&1 != 0) {
116 // by construction : at least 2 elts of rank 1, must be even
117 return s, nil, errors.New("corrupt input: min elt size, even check failed ")
118 }
119
120 // TODO: Choose between single/double symbol decoding
121
122 // Calculate starting value for each rank
123 {
124 var nextRankStart uint32
125 for n := uint8(1); n < s.actualTableLog+1; n++ {
126 current := nextRankStart
127 nextRankStart += rankStats[n] << (n - 1)
128 rankStats[n] = current
129 }
130 }
131
132 // fill DTable (always full size)
133 tSize := 1 << tableLogMax
134 if len(s.dt.single) != tSize {
135 s.dt.single = make([]dEntrySingle, tSize)
136 }
137
138 for n, w := range s.huffWeight[:s.symbolLen] {
139 length := (uint32(1) << w) >> 1
140 d := dEntrySingle{
141 byte: uint8(n),
142 nBits: s.actualTableLog + 1 - w,
143 }
144 for u := rankStats[w]; u < rankStats[w]+length; u++ {
145 s.dt.single[u] = d
146 }
147 rankStats[w] += length
148 }
149 return s, in, nil
150}
151
152// Decompress1X will decompress a 1X encoded stream.
153// The length of the supplied input must match the end of a block exactly.
154// Before this is called, the table must be initialized with ReadTable unless
155// the encoder re-used the table.
156func (s *Scratch) Decompress1X(in []byte) (out []byte, err error) {
157 if len(s.dt.single) == 0 {
158 return nil, errors.New("no table loaded")
159 }
160 var br bitReader
161 err = br.init(in)
162 if err != nil {
163 return nil, err
164 }
165 s.Out = s.Out[:0]
166
167 decode := func() byte {
168 val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */
169 v := s.dt.single[val]
170 br.bitsRead += v.nBits
171 return v.byte
172 }
173 hasDec := func(v dEntrySingle) byte {
174 br.bitsRead += v.nBits
175 return v.byte
176 }
177
178 // Avoid bounds check by always having full sized table.
179 const tlSize = 1 << tableLogMax
180 const tlMask = tlSize - 1
181 dt := s.dt.single[:tlSize]
182
183 // Use temp table to avoid bound checks/append penalty.
184 var tmp = s.huffWeight[:256]
185 var off uint8
186
187 for br.off >= 8 {
188 br.fillFast()
189 tmp[off+0] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask])
190 tmp[off+1] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask])
191 br.fillFast()
192 tmp[off+2] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask])
193 tmp[off+3] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask])
194 off += 4
195 if off == 0 {
196 if len(s.Out)+256 > s.MaxDecodedSize {
197 br.close()
198 return nil, ErrMaxDecodedSizeExceeded
199 }
200 s.Out = append(s.Out, tmp...)
201 }
202 }
203
204 if len(s.Out)+int(off) > s.MaxDecodedSize {
205 br.close()
206 return nil, ErrMaxDecodedSizeExceeded
207 }
208 s.Out = append(s.Out, tmp[:off]...)
209
210 for !br.finished() {
211 br.fill()
212 if len(s.Out) >= s.MaxDecodedSize {
213 br.close()
214 return nil, ErrMaxDecodedSizeExceeded
215 }
216 s.Out = append(s.Out, decode())
217 }
218 return s.Out, br.close()
219}
220
221// Decompress4X will decompress a 4X encoded stream.
222// Before this is called, the table must be initialized with ReadTable unless
223// the encoder re-used the table.
224// The length of the supplied input must match the end of a block exactly.
225// The destination size of the uncompressed data must be known and provided.
226func (s *Scratch) Decompress4X(in []byte, dstSize int) (out []byte, err error) {
227 if len(s.dt.single) == 0 {
228 return nil, errors.New("no table loaded")
229 }
230 if len(in) < 6+(4*1) {
231 return nil, errors.New("input too small")
232 }
233 if dstSize > s.MaxDecodedSize {
234 return nil, ErrMaxDecodedSizeExceeded
235 }
236 // TODO: We do not detect when we overrun a buffer, except if the last one does.
237
238 var br [4]bitReader
239 start := 6
240 for i := 0; i < 3; i++ {
241 length := int(in[i*2]) | (int(in[i*2+1]) << 8)
242 if start+length >= len(in) {
243 return nil, errors.New("truncated input (or invalid offset)")
244 }
245 err = br[i].init(in[start : start+length])
246 if err != nil {
247 return nil, err
248 }
249 start += length
250 }
251 err = br[3].init(in[start:])
252 if err != nil {
253 return nil, err
254 }
255
256 // Prepare output
257 if cap(s.Out) < dstSize {
258 s.Out = make([]byte, 0, dstSize)
259 }
260 s.Out = s.Out[:dstSize]
261 // destination, offset to match first output
262 dstOut := s.Out
263 dstEvery := (dstSize + 3) / 4
264
265 const tlSize = 1 << tableLogMax
266 const tlMask = tlSize - 1
267 single := s.dt.single[:tlSize]
268
269 decode := func(br *bitReader) byte {
270 val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */
271 v := single[val&tlMask]
272 br.bitsRead += v.nBits
273 return v.byte
274 }
275
276 // Use temp table to avoid bound checks/append penalty.
277 var tmp = s.huffWeight[:256]
278 var off uint8
279
280 // Decode 2 values from each decoder/loop.
281 const bufoff = 256 / 4
282bigloop:
283 for {
284 for i := range br {
285 if br[i].off < 4 {
286 break bigloop
287 }
288 br[i].fillFast()
289 }
290 tmp[off] = decode(&br[0])
291 tmp[off+bufoff] = decode(&br[1])
292 tmp[off+bufoff*2] = decode(&br[2])
293 tmp[off+bufoff*3] = decode(&br[3])
294 tmp[off+1] = decode(&br[0])
295 tmp[off+1+bufoff] = decode(&br[1])
296 tmp[off+1+bufoff*2] = decode(&br[2])
297 tmp[off+1+bufoff*3] = decode(&br[3])
298 off += 2
299 if off == bufoff {
300 if bufoff > dstEvery {
301 return nil, errors.New("corruption detected: stream overrun 1")
302 }
303 copy(dstOut, tmp[:bufoff])
304 copy(dstOut[dstEvery:], tmp[bufoff:bufoff*2])
305 copy(dstOut[dstEvery*2:], tmp[bufoff*2:bufoff*3])
306 copy(dstOut[dstEvery*3:], tmp[bufoff*3:bufoff*4])
307 off = 0
308 dstOut = dstOut[bufoff:]
309 // There must at least be 3 buffers left.
310 if len(dstOut) < dstEvery*3 {
311 return nil, errors.New("corruption detected: stream overrun 2")
312 }
313 }
314 }
315 if off > 0 {
316 ioff := int(off)
317 if len(dstOut) < dstEvery*3+ioff {
318 return nil, errors.New("corruption detected: stream overrun 3")
319 }
320 copy(dstOut, tmp[:off])
321 copy(dstOut[dstEvery:dstEvery+ioff], tmp[bufoff:bufoff*2])
322 copy(dstOut[dstEvery*2:dstEvery*2+ioff], tmp[bufoff*2:bufoff*3])
323 copy(dstOut[dstEvery*3:dstEvery*3+ioff], tmp[bufoff*3:bufoff*4])
324 dstOut = dstOut[off:]
325 }
326
327 for i := range br {
328 offset := dstEvery * i
329 br := &br[i]
330 for !br.finished() {
331 br.fill()
332 if offset >= len(dstOut) {
333 return nil, errors.New("corruption detected: stream overrun 4")
334 }
335 dstOut[offset] = decode(br)
336 offset++
337 }
338 err = br.close()
339 if err != nil {
340 return nil, err
341 }
342 }
343
344 return s.Out, nil
345}
346
347// matches will compare a decoding table to a coding table.
348// Errors are written to the writer.
349// Nothing will be written if table is ok.
350func (s *Scratch) matches(ct cTable, w io.Writer) {
351 if s == nil || len(s.dt.single) == 0 {
352 return
353 }
354 dt := s.dt.single[:1<<s.actualTableLog]
355 tablelog := s.actualTableLog
356 ok := 0
357 broken := 0
358 for sym, enc := range ct {
359 errs := 0
360 broken++
361 if enc.nBits == 0 {
362 for _, dec := range dt {
363 if dec.byte == byte(sym) {
364 fmt.Fprintf(w, "symbol %x has decoder, but no encoder\n", sym)
365 errs++
366 break
367 }
368 }
369 if errs == 0 {
370 broken--
371 }
372 continue
373 }
374 // Unused bits in input
375 ub := tablelog - enc.nBits
376 top := enc.val << ub
377 // decoder looks at top bits.
378 dec := dt[top]
379 if dec.nBits != enc.nBits {
380 fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", sym, enc.nBits, dec.nBits)
381 errs++
382 }
383 if dec.byte != uint8(sym) {
384 fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", sym, sym, dec.byte)
385 errs++
386 }
387 if errs > 0 {
388 fmt.Fprintf(w, "%d errros in base, stopping\n", errs)
389 continue
390 }
391 // Ensure that all combinations are covered.
392 for i := uint16(0); i < (1 << ub); i++ {
393 vval := top | i
394 dec := dt[vval]
395 if dec.nBits != enc.nBits {
396 fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", vval, enc.nBits, dec.nBits)
397 errs++
398 }
399 if dec.byte != uint8(sym) {
400 fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", vval, sym, dec.byte)
401 errs++
402 }
403 if errs > 20 {
404 fmt.Fprintf(w, "%d errros, stopping\n", errs)
405 break
406 }
407 }
408 if errs == 0 {
409 ok++
410 broken--
411 }
412 }
413 if broken > 0 {
414 fmt.Fprintf(w, "%d broken, %d ok\n", broken, ok)
415 }
416}