Pragya Arya | 324337e | 2020-02-20 14:35:08 +0530 | [diff] [blame] | 1 | package huff0 |
| 2 | |
| 3 | import ( |
| 4 | "errors" |
| 5 | "fmt" |
| 6 | "io" |
| 7 | |
| 8 | "github.com/klauspost/compress/fse" |
| 9 | ) |
| 10 | |
type (
	// dTable holds the decoding tables built from a Huffman table description.
	// Only the single-symbol table is populated by the visible code; the
	// double-symbol table is declared but not filled here.
	dTable struct {
		single []dEntrySingle
		double []dEntryDouble
	}

	// dEntrySingle is one slot of the single-symbol decoding table.
	// The low byte of entry is the number of bits to consume and the high
	// byte is the decoded symbol.
	dEntrySingle struct {
		entry uint16
	}

	// dEntryDouble is one slot of a double-symbol decoding table.
	// NOTE(review): not populated in this file; field meanings (decoded
	// symbol sequence, bits consumed, symbols produced) are inferred from
	// names only — confirm against the code that fills it.
	dEntryDouble struct {
		seq        uint16
		nBits, len uint8
	}
)
| 27 | |
| 28 | // ReadTable will read a table from the input. |
| 29 | // The size of the input may be larger than the table definition. |
| 30 | // Any content remaining after the table definition will be returned. |
| 31 | // If no Scratch is provided a new one is allocated. |
| 32 | // The returned Scratch can be used for decoding input using this table. |
| 33 | func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) { |
| 34 | s, err = s.prepare(in) |
| 35 | if err != nil { |
| 36 | return s, nil, err |
| 37 | } |
| 38 | if len(in) <= 1 { |
| 39 | return s, nil, errors.New("input too small for table") |
| 40 | } |
| 41 | iSize := in[0] |
| 42 | in = in[1:] |
| 43 | if iSize >= 128 { |
| 44 | // Uncompressed |
| 45 | oSize := iSize - 127 |
| 46 | iSize = (oSize + 1) / 2 |
| 47 | if int(iSize) > len(in) { |
| 48 | return s, nil, errors.New("input too small for table") |
| 49 | } |
| 50 | for n := uint8(0); n < oSize; n += 2 { |
| 51 | v := in[n/2] |
| 52 | s.huffWeight[n] = v >> 4 |
| 53 | s.huffWeight[n+1] = v & 15 |
| 54 | } |
| 55 | s.symbolLen = uint16(oSize) |
| 56 | in = in[iSize:] |
| 57 | } else { |
| 58 | if len(in) <= int(iSize) { |
| 59 | return s, nil, errors.New("input too small for table") |
| 60 | } |
| 61 | // FSE compressed weights |
| 62 | s.fse.DecompressLimit = 255 |
| 63 | hw := s.huffWeight[:] |
| 64 | s.fse.Out = hw |
| 65 | b, err := fse.Decompress(in[:iSize], s.fse) |
| 66 | s.fse.Out = nil |
| 67 | if err != nil { |
| 68 | return s, nil, err |
| 69 | } |
| 70 | if len(b) > 255 { |
| 71 | return s, nil, errors.New("corrupt input: output table too large") |
| 72 | } |
| 73 | s.symbolLen = uint16(len(b)) |
| 74 | in = in[iSize:] |
| 75 | } |
| 76 | |
| 77 | // collect weight stats |
| 78 | var rankStats [16]uint32 |
| 79 | weightTotal := uint32(0) |
| 80 | for _, v := range s.huffWeight[:s.symbolLen] { |
| 81 | if v > tableLogMax { |
| 82 | return s, nil, errors.New("corrupt input: weight too large") |
| 83 | } |
| 84 | v2 := v & 15 |
| 85 | rankStats[v2]++ |
| 86 | weightTotal += (1 << v2) >> 1 |
| 87 | } |
| 88 | if weightTotal == 0 { |
| 89 | return s, nil, errors.New("corrupt input: weights zero") |
| 90 | } |
| 91 | |
| 92 | // get last non-null symbol weight (implied, total must be 2^n) |
| 93 | { |
| 94 | tableLog := highBit32(weightTotal) + 1 |
| 95 | if tableLog > tableLogMax { |
| 96 | return s, nil, errors.New("corrupt input: tableLog too big") |
| 97 | } |
| 98 | s.actualTableLog = uint8(tableLog) |
| 99 | // determine last weight |
| 100 | { |
| 101 | total := uint32(1) << tableLog |
| 102 | rest := total - weightTotal |
| 103 | verif := uint32(1) << highBit32(rest) |
| 104 | lastWeight := highBit32(rest) + 1 |
| 105 | if verif != rest { |
| 106 | // last value must be a clean power of 2 |
| 107 | return s, nil, errors.New("corrupt input: last value not power of two") |
| 108 | } |
| 109 | s.huffWeight[s.symbolLen] = uint8(lastWeight) |
| 110 | s.symbolLen++ |
| 111 | rankStats[lastWeight]++ |
| 112 | } |
| 113 | } |
| 114 | |
| 115 | if (rankStats[1] < 2) || (rankStats[1]&1 != 0) { |
| 116 | // by construction : at least 2 elts of rank 1, must be even |
| 117 | return s, nil, errors.New("corrupt input: min elt size, even check failed ") |
| 118 | } |
| 119 | |
| 120 | // TODO: Choose between single/double symbol decoding |
| 121 | |
| 122 | // Calculate starting value for each rank |
| 123 | { |
| 124 | var nextRankStart uint32 |
| 125 | for n := uint8(1); n < s.actualTableLog+1; n++ { |
| 126 | current := nextRankStart |
| 127 | nextRankStart += rankStats[n] << (n - 1) |
| 128 | rankStats[n] = current |
| 129 | } |
| 130 | } |
| 131 | |
| 132 | // fill DTable (always full size) |
| 133 | tSize := 1 << tableLogMax |
| 134 | if len(s.dt.single) != tSize { |
| 135 | s.dt.single = make([]dEntrySingle, tSize) |
| 136 | } |
| 137 | for n, w := range s.huffWeight[:s.symbolLen] { |
| 138 | if w == 0 { |
| 139 | continue |
| 140 | } |
| 141 | length := (uint32(1) << w) >> 1 |
| 142 | d := dEntrySingle{ |
| 143 | entry: uint16(s.actualTableLog+1-w) | (uint16(n) << 8), |
| 144 | } |
| 145 | single := s.dt.single[rankStats[w] : rankStats[w]+length] |
| 146 | for i := range single { |
| 147 | single[i] = d |
| 148 | } |
| 149 | rankStats[w] += length |
| 150 | } |
| 151 | return s, in, nil |
| 152 | } |
| 153 | |
| 154 | // Decompress1X will decompress a 1X encoded stream. |
| 155 | // The length of the supplied input must match the end of a block exactly. |
| 156 | // Before this is called, the table must be initialized with ReadTable unless |
| 157 | // the encoder re-used the table. |
| 158 | func (s *Scratch) Decompress1X(in []byte) (out []byte, err error) { |
| 159 | if len(s.dt.single) == 0 { |
| 160 | return nil, errors.New("no table loaded") |
| 161 | } |
| 162 | var br bitReader |
| 163 | err = br.init(in) |
| 164 | if err != nil { |
| 165 | return nil, err |
| 166 | } |
| 167 | s.Out = s.Out[:0] |
| 168 | |
| 169 | decode := func() byte { |
| 170 | val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */ |
| 171 | v := s.dt.single[val] |
| 172 | br.bitsRead += uint8(v.entry) |
| 173 | return uint8(v.entry >> 8) |
| 174 | } |
| 175 | hasDec := func(v dEntrySingle) byte { |
| 176 | br.bitsRead += uint8(v.entry) |
| 177 | return uint8(v.entry >> 8) |
| 178 | } |
| 179 | |
| 180 | // Avoid bounds check by always having full sized table. |
| 181 | const tlSize = 1 << tableLogMax |
| 182 | const tlMask = tlSize - 1 |
| 183 | dt := s.dt.single[:tlSize] |
| 184 | |
| 185 | // Use temp table to avoid bound checks/append penalty. |
| 186 | var tmp = s.huffWeight[:256] |
| 187 | var off uint8 |
| 188 | |
| 189 | for br.off >= 8 { |
| 190 | br.fillFast() |
| 191 | tmp[off+0] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask]) |
| 192 | tmp[off+1] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask]) |
| 193 | br.fillFast() |
| 194 | tmp[off+2] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask]) |
| 195 | tmp[off+3] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask]) |
| 196 | off += 4 |
| 197 | if off == 0 { |
| 198 | if len(s.Out)+256 > s.MaxDecodedSize { |
| 199 | br.close() |
| 200 | return nil, ErrMaxDecodedSizeExceeded |
| 201 | } |
| 202 | s.Out = append(s.Out, tmp...) |
| 203 | } |
| 204 | } |
| 205 | |
| 206 | if len(s.Out)+int(off) > s.MaxDecodedSize { |
| 207 | br.close() |
| 208 | return nil, ErrMaxDecodedSizeExceeded |
| 209 | } |
| 210 | s.Out = append(s.Out, tmp[:off]...) |
| 211 | |
| 212 | for !br.finished() { |
| 213 | br.fill() |
| 214 | if len(s.Out) >= s.MaxDecodedSize { |
| 215 | br.close() |
| 216 | return nil, ErrMaxDecodedSizeExceeded |
| 217 | } |
| 218 | s.Out = append(s.Out, decode()) |
| 219 | } |
| 220 | return s.Out, br.close() |
| 221 | } |
| 222 | |
| 223 | // Decompress4X will decompress a 4X encoded stream. |
| 224 | // Before this is called, the table must be initialized with ReadTable unless |
| 225 | // the encoder re-used the table. |
| 226 | // The length of the supplied input must match the end of a block exactly. |
| 227 | // The destination size of the uncompressed data must be known and provided. |
| 228 | func (s *Scratch) Decompress4X(in []byte, dstSize int) (out []byte, err error) { |
| 229 | if len(s.dt.single) == 0 { |
| 230 | return nil, errors.New("no table loaded") |
| 231 | } |
| 232 | if len(in) < 6+(4*1) { |
| 233 | return nil, errors.New("input too small") |
| 234 | } |
| 235 | if dstSize > s.MaxDecodedSize { |
| 236 | return nil, ErrMaxDecodedSizeExceeded |
| 237 | } |
| 238 | // TODO: We do not detect when we overrun a buffer, except if the last one does. |
| 239 | |
| 240 | var br [4]bitReader |
| 241 | start := 6 |
| 242 | for i := 0; i < 3; i++ { |
| 243 | length := int(in[i*2]) | (int(in[i*2+1]) << 8) |
| 244 | if start+length >= len(in) { |
| 245 | return nil, errors.New("truncated input (or invalid offset)") |
| 246 | } |
| 247 | err = br[i].init(in[start : start+length]) |
| 248 | if err != nil { |
| 249 | return nil, err |
| 250 | } |
| 251 | start += length |
| 252 | } |
| 253 | err = br[3].init(in[start:]) |
| 254 | if err != nil { |
| 255 | return nil, err |
| 256 | } |
| 257 | |
| 258 | // Prepare output |
| 259 | if cap(s.Out) < dstSize { |
| 260 | s.Out = make([]byte, 0, dstSize) |
| 261 | } |
| 262 | s.Out = s.Out[:dstSize] |
| 263 | // destination, offset to match first output |
| 264 | dstOut := s.Out |
| 265 | dstEvery := (dstSize + 3) / 4 |
| 266 | |
| 267 | const tlSize = 1 << tableLogMax |
| 268 | const tlMask = tlSize - 1 |
| 269 | single := s.dt.single[:tlSize] |
| 270 | |
| 271 | decode := func(br *bitReader) byte { |
| 272 | val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */ |
| 273 | v := single[val&tlMask] |
| 274 | br.bitsRead += uint8(v.entry) |
| 275 | return uint8(v.entry >> 8) |
| 276 | } |
| 277 | |
| 278 | // Use temp table to avoid bound checks/append penalty. |
| 279 | var tmp = s.huffWeight[:256] |
| 280 | var off uint8 |
| 281 | var decoded int |
| 282 | |
| 283 | // Decode 2 values from each decoder/loop. |
| 284 | const bufoff = 256 / 4 |
| 285 | bigloop: |
| 286 | for { |
| 287 | for i := range br { |
| 288 | br := &br[i] |
| 289 | if br.off < 4 { |
| 290 | break bigloop |
| 291 | } |
| 292 | br.fillFast() |
| 293 | } |
| 294 | |
| 295 | { |
| 296 | const stream = 0 |
| 297 | val := br[stream].peekBitsFast(s.actualTableLog) |
| 298 | v := single[val&tlMask] |
| 299 | br[stream].bitsRead += uint8(v.entry) |
| 300 | |
| 301 | val2 := br[stream].peekBitsFast(s.actualTableLog) |
| 302 | v2 := single[val2&tlMask] |
| 303 | tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8) |
| 304 | tmp[off+bufoff*stream] = uint8(v.entry >> 8) |
| 305 | br[stream].bitsRead += uint8(v2.entry) |
| 306 | } |
| 307 | |
| 308 | { |
| 309 | const stream = 1 |
| 310 | val := br[stream].peekBitsFast(s.actualTableLog) |
| 311 | v := single[val&tlMask] |
| 312 | br[stream].bitsRead += uint8(v.entry) |
| 313 | |
| 314 | val2 := br[stream].peekBitsFast(s.actualTableLog) |
| 315 | v2 := single[val2&tlMask] |
| 316 | tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8) |
| 317 | tmp[off+bufoff*stream] = uint8(v.entry >> 8) |
| 318 | br[stream].bitsRead += uint8(v2.entry) |
| 319 | } |
| 320 | |
| 321 | { |
| 322 | const stream = 2 |
| 323 | val := br[stream].peekBitsFast(s.actualTableLog) |
| 324 | v := single[val&tlMask] |
| 325 | br[stream].bitsRead += uint8(v.entry) |
| 326 | |
| 327 | val2 := br[stream].peekBitsFast(s.actualTableLog) |
| 328 | v2 := single[val2&tlMask] |
| 329 | tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8) |
| 330 | tmp[off+bufoff*stream] = uint8(v.entry >> 8) |
| 331 | br[stream].bitsRead += uint8(v2.entry) |
| 332 | } |
| 333 | |
| 334 | { |
| 335 | const stream = 3 |
| 336 | val := br[stream].peekBitsFast(s.actualTableLog) |
| 337 | v := single[val&tlMask] |
| 338 | br[stream].bitsRead += uint8(v.entry) |
| 339 | |
| 340 | val2 := br[stream].peekBitsFast(s.actualTableLog) |
| 341 | v2 := single[val2&tlMask] |
| 342 | tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8) |
| 343 | tmp[off+bufoff*stream] = uint8(v.entry >> 8) |
| 344 | br[stream].bitsRead += uint8(v2.entry) |
| 345 | } |
| 346 | |
| 347 | off += 2 |
| 348 | |
| 349 | if off == bufoff { |
| 350 | if bufoff > dstEvery { |
| 351 | return nil, errors.New("corruption detected: stream overrun 1") |
| 352 | } |
| 353 | copy(dstOut, tmp[:bufoff]) |
| 354 | copy(dstOut[dstEvery:], tmp[bufoff:bufoff*2]) |
| 355 | copy(dstOut[dstEvery*2:], tmp[bufoff*2:bufoff*3]) |
| 356 | copy(dstOut[dstEvery*3:], tmp[bufoff*3:bufoff*4]) |
| 357 | off = 0 |
| 358 | dstOut = dstOut[bufoff:] |
| 359 | decoded += 256 |
| 360 | // There must at least be 3 buffers left. |
| 361 | if len(dstOut) < dstEvery*3 { |
| 362 | return nil, errors.New("corruption detected: stream overrun 2") |
| 363 | } |
| 364 | } |
| 365 | } |
| 366 | if off > 0 { |
| 367 | ioff := int(off) |
| 368 | if len(dstOut) < dstEvery*3+ioff { |
| 369 | return nil, errors.New("corruption detected: stream overrun 3") |
| 370 | } |
| 371 | copy(dstOut, tmp[:off]) |
| 372 | copy(dstOut[dstEvery:dstEvery+ioff], tmp[bufoff:bufoff*2]) |
| 373 | copy(dstOut[dstEvery*2:dstEvery*2+ioff], tmp[bufoff*2:bufoff*3]) |
| 374 | copy(dstOut[dstEvery*3:dstEvery*3+ioff], tmp[bufoff*3:bufoff*4]) |
| 375 | decoded += int(off) * 4 |
| 376 | dstOut = dstOut[off:] |
| 377 | } |
| 378 | |
| 379 | // Decode remaining. |
| 380 | for i := range br { |
| 381 | offset := dstEvery * i |
| 382 | br := &br[i] |
| 383 | for !br.finished() { |
| 384 | br.fill() |
| 385 | if offset >= len(dstOut) { |
| 386 | return nil, errors.New("corruption detected: stream overrun 4") |
| 387 | } |
| 388 | dstOut[offset] = decode(br) |
| 389 | offset++ |
| 390 | } |
| 391 | decoded += offset - dstEvery*i |
| 392 | err = br.close() |
| 393 | if err != nil { |
| 394 | return nil, err |
| 395 | } |
| 396 | } |
| 397 | if dstSize != decoded { |
| 398 | return nil, errors.New("corruption detected: short output block") |
| 399 | } |
| 400 | return s.Out, nil |
| 401 | } |
| 402 | |
| 403 | // matches will compare a decoding table to a coding table. |
| 404 | // Errors are written to the writer. |
| 405 | // Nothing will be written if table is ok. |
| 406 | func (s *Scratch) matches(ct cTable, w io.Writer) { |
| 407 | if s == nil || len(s.dt.single) == 0 { |
| 408 | return |
| 409 | } |
| 410 | dt := s.dt.single[:1<<s.actualTableLog] |
| 411 | tablelog := s.actualTableLog |
| 412 | ok := 0 |
| 413 | broken := 0 |
| 414 | for sym, enc := range ct { |
| 415 | errs := 0 |
| 416 | broken++ |
| 417 | if enc.nBits == 0 { |
| 418 | for _, dec := range dt { |
| 419 | if uint8(dec.entry>>8) == byte(sym) { |
| 420 | fmt.Fprintf(w, "symbol %x has decoder, but no encoder\n", sym) |
| 421 | errs++ |
| 422 | break |
| 423 | } |
| 424 | } |
| 425 | if errs == 0 { |
| 426 | broken-- |
| 427 | } |
| 428 | continue |
| 429 | } |
| 430 | // Unused bits in input |
| 431 | ub := tablelog - enc.nBits |
| 432 | top := enc.val << ub |
| 433 | // decoder looks at top bits. |
| 434 | dec := dt[top] |
| 435 | if uint8(dec.entry) != enc.nBits { |
| 436 | fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", sym, enc.nBits, uint8(dec.entry)) |
| 437 | errs++ |
| 438 | } |
| 439 | if uint8(dec.entry>>8) != uint8(sym) { |
| 440 | fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", sym, sym, uint8(dec.entry>>8)) |
| 441 | errs++ |
| 442 | } |
| 443 | if errs > 0 { |
| 444 | fmt.Fprintf(w, "%d errros in base, stopping\n", errs) |
| 445 | continue |
| 446 | } |
| 447 | // Ensure that all combinations are covered. |
| 448 | for i := uint16(0); i < (1 << ub); i++ { |
| 449 | vval := top | i |
| 450 | dec := dt[vval] |
| 451 | if uint8(dec.entry) != enc.nBits { |
| 452 | fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", vval, enc.nBits, uint8(dec.entry)) |
| 453 | errs++ |
| 454 | } |
| 455 | if uint8(dec.entry>>8) != uint8(sym) { |
| 456 | fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", vval, sym, uint8(dec.entry>>8)) |
| 457 | errs++ |
| 458 | } |
| 459 | if errs > 20 { |
| 460 | fmt.Fprintf(w, "%d errros, stopping\n", errs) |
| 461 | break |
| 462 | } |
| 463 | } |
| 464 | if errs == 0 { |
| 465 | ok++ |
| 466 | broken-- |
| 467 | } |
| 468 | } |
| 469 | if broken > 0 { |
| 470 | fmt.Fprintf(w, "%d broken, %d ok\n", broken, ok) |
| 471 | } |
| 472 | } |