Dinesh Belwalkar | e63f7f9 | 2019-11-22 23:11:16 +0000 | [diff] [blame] | 1 | package huff0 |
| 2 | |
| 3 | import ( |
| 4 | "errors" |
| 5 | "fmt" |
| 6 | "io" |
| 7 | |
| 8 | "github.com/klauspost/compress/fse" |
| 9 | ) |
| 10 | |
// dTable groups the lookup tables used when decoding.
type dTable struct {
	// single is the single-symbol lookup table. ReadTable fills it and
	// Decompress1X/Decompress4X index it with the next bits of the stream.
	single []dEntrySingle
	// double is reserved for double-symbol decoding.
	// NOTE(review): nothing in the code visible here reads or writes it.
	double []dEntryDouble
}
| 15 | |
// dEntrySingle is one entry of the single-symbol decoding table.
type dEntrySingle struct {
	// byte is the decoded symbol emitted when this entry is hit.
	byte uint8
	// nBits is the number of bits the decoder adds to the bit reader's
	// bitsRead counter after emitting the symbol.
	nBits uint8
}
| 21 | |
// dEntryDouble is one entry of the double-symbol decoding table.
// NOTE(review): no code visible here populates or reads this type, so the
// field descriptions below are assumptions to be confirmed.
type dEntryDouble struct {
	// seq presumably packs the decoded symbol sequence (up to two bytes).
	seq uint16
	// nBits is presumably the number of bits consumed by the entry.
	nBits uint8
	// len is presumably the number of symbols stored in seq.
	len uint8
}
| 28 | |
| 29 | // ReadTable will read a table from the input. |
| 30 | // The size of the input may be larger than the table definition. |
| 31 | // Any content remaining after the table definition will be returned. |
| 32 | // If no Scratch is provided a new one is allocated. |
| 33 | // The returned Scratch can be used for decoding input using this table. |
| 34 | func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) { |
| 35 | s, err = s.prepare(in) |
| 36 | if err != nil { |
| 37 | return s, nil, err |
| 38 | } |
| 39 | if len(in) <= 1 { |
| 40 | return s, nil, errors.New("input too small for table") |
| 41 | } |
| 42 | iSize := in[0] |
| 43 | in = in[1:] |
| 44 | if iSize >= 128 { |
| 45 | // Uncompressed |
| 46 | oSize := iSize - 127 |
| 47 | iSize = (oSize + 1) / 2 |
| 48 | if int(iSize) > len(in) { |
| 49 | return s, nil, errors.New("input too small for table") |
| 50 | } |
| 51 | for n := uint8(0); n < oSize; n += 2 { |
| 52 | v := in[n/2] |
| 53 | s.huffWeight[n] = v >> 4 |
| 54 | s.huffWeight[n+1] = v & 15 |
| 55 | } |
| 56 | s.symbolLen = uint16(oSize) |
| 57 | in = in[iSize:] |
| 58 | } else { |
| 59 | if len(in) <= int(iSize) { |
| 60 | return s, nil, errors.New("input too small for table") |
| 61 | } |
| 62 | // FSE compressed weights |
| 63 | s.fse.DecompressLimit = 255 |
| 64 | hw := s.huffWeight[:] |
| 65 | s.fse.Out = hw |
| 66 | b, err := fse.Decompress(in[:iSize], s.fse) |
| 67 | s.fse.Out = nil |
| 68 | if err != nil { |
| 69 | return s, nil, err |
| 70 | } |
| 71 | if len(b) > 255 { |
| 72 | return s, nil, errors.New("corrupt input: output table too large") |
| 73 | } |
| 74 | s.symbolLen = uint16(len(b)) |
| 75 | in = in[iSize:] |
| 76 | } |
| 77 | |
| 78 | // collect weight stats |
| 79 | var rankStats [tableLogMax + 1]uint32 |
| 80 | weightTotal := uint32(0) |
| 81 | for _, v := range s.huffWeight[:s.symbolLen] { |
| 82 | if v > tableLogMax { |
| 83 | return s, nil, errors.New("corrupt input: weight too large") |
| 84 | } |
| 85 | rankStats[v]++ |
| 86 | weightTotal += (1 << (v & 15)) >> 1 |
| 87 | } |
| 88 | if weightTotal == 0 { |
| 89 | return s, nil, errors.New("corrupt input: weights zero") |
| 90 | } |
| 91 | |
| 92 | // get last non-null symbol weight (implied, total must be 2^n) |
| 93 | { |
| 94 | tableLog := highBit32(weightTotal) + 1 |
| 95 | if tableLog > tableLogMax { |
| 96 | return s, nil, errors.New("corrupt input: tableLog too big") |
| 97 | } |
| 98 | s.actualTableLog = uint8(tableLog) |
| 99 | // determine last weight |
| 100 | { |
| 101 | total := uint32(1) << tableLog |
| 102 | rest := total - weightTotal |
| 103 | verif := uint32(1) << highBit32(rest) |
| 104 | lastWeight := highBit32(rest) + 1 |
| 105 | if verif != rest { |
| 106 | // last value must be a clean power of 2 |
| 107 | return s, nil, errors.New("corrupt input: last value not power of two") |
| 108 | } |
| 109 | s.huffWeight[s.symbolLen] = uint8(lastWeight) |
| 110 | s.symbolLen++ |
| 111 | rankStats[lastWeight]++ |
| 112 | } |
| 113 | } |
| 114 | |
| 115 | if (rankStats[1] < 2) || (rankStats[1]&1 != 0) { |
| 116 | // by construction : at least 2 elts of rank 1, must be even |
| 117 | return s, nil, errors.New("corrupt input: min elt size, even check failed ") |
| 118 | } |
| 119 | |
| 120 | // TODO: Choose between single/double symbol decoding |
| 121 | |
| 122 | // Calculate starting value for each rank |
| 123 | { |
| 124 | var nextRankStart uint32 |
| 125 | for n := uint8(1); n < s.actualTableLog+1; n++ { |
| 126 | current := nextRankStart |
| 127 | nextRankStart += rankStats[n] << (n - 1) |
| 128 | rankStats[n] = current |
| 129 | } |
| 130 | } |
| 131 | |
| 132 | // fill DTable (always full size) |
| 133 | tSize := 1 << tableLogMax |
| 134 | if len(s.dt.single) != tSize { |
| 135 | s.dt.single = make([]dEntrySingle, tSize) |
| 136 | } |
| 137 | |
| 138 | for n, w := range s.huffWeight[:s.symbolLen] { |
| 139 | length := (uint32(1) << w) >> 1 |
| 140 | d := dEntrySingle{ |
| 141 | byte: uint8(n), |
| 142 | nBits: s.actualTableLog + 1 - w, |
| 143 | } |
| 144 | for u := rankStats[w]; u < rankStats[w]+length; u++ { |
| 145 | s.dt.single[u] = d |
| 146 | } |
| 147 | rankStats[w] += length |
| 148 | } |
| 149 | return s, in, nil |
| 150 | } |
| 151 | |
| 152 | // Decompress1X will decompress a 1X encoded stream. |
| 153 | // The length of the supplied input must match the end of a block exactly. |
| 154 | // Before this is called, the table must be initialized with ReadTable unless |
| 155 | // the encoder re-used the table. |
| 156 | func (s *Scratch) Decompress1X(in []byte) (out []byte, err error) { |
| 157 | if len(s.dt.single) == 0 { |
| 158 | return nil, errors.New("no table loaded") |
| 159 | } |
| 160 | var br bitReader |
| 161 | err = br.init(in) |
| 162 | if err != nil { |
| 163 | return nil, err |
| 164 | } |
| 165 | s.Out = s.Out[:0] |
| 166 | |
| 167 | decode := func() byte { |
| 168 | val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */ |
| 169 | v := s.dt.single[val] |
| 170 | br.bitsRead += v.nBits |
| 171 | return v.byte |
| 172 | } |
| 173 | hasDec := func(v dEntrySingle) byte { |
| 174 | br.bitsRead += v.nBits |
| 175 | return v.byte |
| 176 | } |
| 177 | |
| 178 | // Avoid bounds check by always having full sized table. |
| 179 | const tlSize = 1 << tableLogMax |
| 180 | const tlMask = tlSize - 1 |
| 181 | dt := s.dt.single[:tlSize] |
| 182 | |
| 183 | // Use temp table to avoid bound checks/append penalty. |
| 184 | var tmp = s.huffWeight[:256] |
| 185 | var off uint8 |
| 186 | |
| 187 | for br.off >= 8 { |
| 188 | br.fillFast() |
| 189 | tmp[off+0] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask]) |
| 190 | tmp[off+1] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask]) |
| 191 | br.fillFast() |
| 192 | tmp[off+2] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask]) |
| 193 | tmp[off+3] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask]) |
| 194 | off += 4 |
| 195 | if off == 0 { |
| 196 | if len(s.Out)+256 > s.MaxDecodedSize { |
| 197 | br.close() |
| 198 | return nil, ErrMaxDecodedSizeExceeded |
| 199 | } |
| 200 | s.Out = append(s.Out, tmp...) |
| 201 | } |
| 202 | } |
| 203 | |
| 204 | if len(s.Out)+int(off) > s.MaxDecodedSize { |
| 205 | br.close() |
| 206 | return nil, ErrMaxDecodedSizeExceeded |
| 207 | } |
| 208 | s.Out = append(s.Out, tmp[:off]...) |
| 209 | |
| 210 | for !br.finished() { |
| 211 | br.fill() |
| 212 | if len(s.Out) >= s.MaxDecodedSize { |
| 213 | br.close() |
| 214 | return nil, ErrMaxDecodedSizeExceeded |
| 215 | } |
| 216 | s.Out = append(s.Out, decode()) |
| 217 | } |
| 218 | return s.Out, br.close() |
| 219 | } |
| 220 | |
| 221 | // Decompress4X will decompress a 4X encoded stream. |
| 222 | // Before this is called, the table must be initialized with ReadTable unless |
| 223 | // the encoder re-used the table. |
| 224 | // The length of the supplied input must match the end of a block exactly. |
| 225 | // The destination size of the uncompressed data must be known and provided. |
| 226 | func (s *Scratch) Decompress4X(in []byte, dstSize int) (out []byte, err error) { |
| 227 | if len(s.dt.single) == 0 { |
| 228 | return nil, errors.New("no table loaded") |
| 229 | } |
| 230 | if len(in) < 6+(4*1) { |
| 231 | return nil, errors.New("input too small") |
| 232 | } |
| 233 | if dstSize > s.MaxDecodedSize { |
| 234 | return nil, ErrMaxDecodedSizeExceeded |
| 235 | } |
| 236 | // TODO: We do not detect when we overrun a buffer, except if the last one does. |
| 237 | |
| 238 | var br [4]bitReader |
| 239 | start := 6 |
| 240 | for i := 0; i < 3; i++ { |
| 241 | length := int(in[i*2]) | (int(in[i*2+1]) << 8) |
| 242 | if start+length >= len(in) { |
| 243 | return nil, errors.New("truncated input (or invalid offset)") |
| 244 | } |
| 245 | err = br[i].init(in[start : start+length]) |
| 246 | if err != nil { |
| 247 | return nil, err |
| 248 | } |
| 249 | start += length |
| 250 | } |
| 251 | err = br[3].init(in[start:]) |
| 252 | if err != nil { |
| 253 | return nil, err |
| 254 | } |
| 255 | |
| 256 | // Prepare output |
| 257 | if cap(s.Out) < dstSize { |
| 258 | s.Out = make([]byte, 0, dstSize) |
| 259 | } |
| 260 | s.Out = s.Out[:dstSize] |
| 261 | // destination, offset to match first output |
| 262 | dstOut := s.Out |
| 263 | dstEvery := (dstSize + 3) / 4 |
| 264 | |
| 265 | const tlSize = 1 << tableLogMax |
| 266 | const tlMask = tlSize - 1 |
| 267 | single := s.dt.single[:tlSize] |
| 268 | |
| 269 | decode := func(br *bitReader) byte { |
| 270 | val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */ |
| 271 | v := single[val&tlMask] |
| 272 | br.bitsRead += v.nBits |
| 273 | return v.byte |
| 274 | } |
| 275 | |
| 276 | // Use temp table to avoid bound checks/append penalty. |
| 277 | var tmp = s.huffWeight[:256] |
| 278 | var off uint8 |
| 279 | |
| 280 | // Decode 2 values from each decoder/loop. |
| 281 | const bufoff = 256 / 4 |
| 282 | bigloop: |
| 283 | for { |
| 284 | for i := range br { |
| 285 | if br[i].off < 4 { |
| 286 | break bigloop |
| 287 | } |
| 288 | br[i].fillFast() |
| 289 | } |
| 290 | tmp[off] = decode(&br[0]) |
| 291 | tmp[off+bufoff] = decode(&br[1]) |
| 292 | tmp[off+bufoff*2] = decode(&br[2]) |
| 293 | tmp[off+bufoff*3] = decode(&br[3]) |
| 294 | tmp[off+1] = decode(&br[0]) |
| 295 | tmp[off+1+bufoff] = decode(&br[1]) |
| 296 | tmp[off+1+bufoff*2] = decode(&br[2]) |
| 297 | tmp[off+1+bufoff*3] = decode(&br[3]) |
| 298 | off += 2 |
| 299 | if off == bufoff { |
| 300 | if bufoff > dstEvery { |
| 301 | return nil, errors.New("corruption detected: stream overrun 1") |
| 302 | } |
| 303 | copy(dstOut, tmp[:bufoff]) |
| 304 | copy(dstOut[dstEvery:], tmp[bufoff:bufoff*2]) |
| 305 | copy(dstOut[dstEvery*2:], tmp[bufoff*2:bufoff*3]) |
| 306 | copy(dstOut[dstEvery*3:], tmp[bufoff*3:bufoff*4]) |
| 307 | off = 0 |
| 308 | dstOut = dstOut[bufoff:] |
| 309 | // There must at least be 3 buffers left. |
| 310 | if len(dstOut) < dstEvery*3 { |
| 311 | return nil, errors.New("corruption detected: stream overrun 2") |
| 312 | } |
| 313 | } |
| 314 | } |
| 315 | if off > 0 { |
| 316 | ioff := int(off) |
| 317 | if len(dstOut) < dstEvery*3+ioff { |
| 318 | return nil, errors.New("corruption detected: stream overrun 3") |
| 319 | } |
| 320 | copy(dstOut, tmp[:off]) |
| 321 | copy(dstOut[dstEvery:dstEvery+ioff], tmp[bufoff:bufoff*2]) |
| 322 | copy(dstOut[dstEvery*2:dstEvery*2+ioff], tmp[bufoff*2:bufoff*3]) |
| 323 | copy(dstOut[dstEvery*3:dstEvery*3+ioff], tmp[bufoff*3:bufoff*4]) |
| 324 | dstOut = dstOut[off:] |
| 325 | } |
| 326 | |
| 327 | for i := range br { |
| 328 | offset := dstEvery * i |
| 329 | br := &br[i] |
| 330 | for !br.finished() { |
| 331 | br.fill() |
| 332 | if offset >= len(dstOut) { |
| 333 | return nil, errors.New("corruption detected: stream overrun 4") |
| 334 | } |
| 335 | dstOut[offset] = decode(br) |
| 336 | offset++ |
| 337 | } |
| 338 | err = br.close() |
| 339 | if err != nil { |
| 340 | return nil, err |
| 341 | } |
| 342 | } |
| 343 | |
| 344 | return s.Out, nil |
| 345 | } |
| 346 | |
| 347 | // matches will compare a decoding table to a coding table. |
| 348 | // Errors are written to the writer. |
| 349 | // Nothing will be written if table is ok. |
| 350 | func (s *Scratch) matches(ct cTable, w io.Writer) { |
| 351 | if s == nil || len(s.dt.single) == 0 { |
| 352 | return |
| 353 | } |
| 354 | dt := s.dt.single[:1<<s.actualTableLog] |
| 355 | tablelog := s.actualTableLog |
| 356 | ok := 0 |
| 357 | broken := 0 |
| 358 | for sym, enc := range ct { |
| 359 | errs := 0 |
| 360 | broken++ |
| 361 | if enc.nBits == 0 { |
| 362 | for _, dec := range dt { |
| 363 | if dec.byte == byte(sym) { |
| 364 | fmt.Fprintf(w, "symbol %x has decoder, but no encoder\n", sym) |
| 365 | errs++ |
| 366 | break |
| 367 | } |
| 368 | } |
| 369 | if errs == 0 { |
| 370 | broken-- |
| 371 | } |
| 372 | continue |
| 373 | } |
| 374 | // Unused bits in input |
| 375 | ub := tablelog - enc.nBits |
| 376 | top := enc.val << ub |
| 377 | // decoder looks at top bits. |
| 378 | dec := dt[top] |
| 379 | if dec.nBits != enc.nBits { |
| 380 | fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", sym, enc.nBits, dec.nBits) |
| 381 | errs++ |
| 382 | } |
| 383 | if dec.byte != uint8(sym) { |
| 384 | fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", sym, sym, dec.byte) |
| 385 | errs++ |
| 386 | } |
| 387 | if errs > 0 { |
| 388 | fmt.Fprintf(w, "%d errros in base, stopping\n", errs) |
| 389 | continue |
| 390 | } |
| 391 | // Ensure that all combinations are covered. |
| 392 | for i := uint16(0); i < (1 << ub); i++ { |
| 393 | vval := top | i |
| 394 | dec := dt[vval] |
| 395 | if dec.nBits != enc.nBits { |
| 396 | fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", vval, enc.nBits, dec.nBits) |
| 397 | errs++ |
| 398 | } |
| 399 | if dec.byte != uint8(sym) { |
| 400 | fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", vval, sym, dec.byte) |
| 401 | errs++ |
| 402 | } |
| 403 | if errs > 20 { |
| 404 | fmt.Fprintf(w, "%d errros, stopping\n", errs) |
| 405 | break |
| 406 | } |
| 407 | } |
| 408 | if errs == 0 { |
| 409 | ok++ |
| 410 | broken-- |
| 411 | } |
| 412 | } |
| 413 | if broken > 0 { |
| 414 | fmt.Fprintf(w, "%d broken, %d ok\n", broken, ok) |
| 415 | } |
| 416 | } |