// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.

package zstd

import (
	"bytes"
	"fmt"

	"github.com/klauspost/compress"
)

const (
	bestLongTableBits = 22                     // Bits used in the long match table
	bestLongTableSize = 1 << bestLongTableBits // Size of the table
	bestLongLen       = 8                      // Bytes used for table hash

	// Note: Increasing the short table bits or making the hash shorter
	// can actually lead to compression degradation since it will 'steal' more from the
	// long match table and match offsets are quite big.
	// This greatly depends on the type of input.
	bestShortTableBits = 18                      // Bits used in the short match table
	bestShortTableSize = 1 << bestShortTableBits // Size of the table
	bestShortLen       = 4                       // Bytes used for table hash
)
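
// For intuition, with the constants above the long table holds 1<<22
// entries hashed from 8-byte windows and the short table 1<<18 entries
// hashed from 4-byte windows. A minimal sketch of how both tables are
// probed for a position s, mirroring the lookups in Encode below:
//
//	cv := load6432(src, s) // 8 bytes starting at s
//	candidateL := e.longTable[hashLen(cv, bestLongTableBits, bestLongLen)]
//	candidateS := e.table[hashLen(cv, bestShortTableBits, bestShortLen)]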

// match describes a single match candidate: s is where the match starts
// in the current input, offset is the position it matches, length the
// match length in bytes, rep the repeat-offset index (1-3) or negative
// for a regular match, and est the estimated encoding cost in bits.
type match struct {
	offset int32
	s      int32
	length int32
	rep    int32
	est    int32
}

// highScore is a sentinel cost for matches that should never be chosen.
const highScore = 25000

// estBits will estimate output bits from predefined tables.
func (m *match) estBits(bitsPerByte int32) {
	mlc := mlCode(uint32(m.length - zstdMinMatch))
	var ofc uint8
	if m.rep < 0 {
		ofc = ofCode(uint32(m.s-m.offset) + 3)
	} else {
		ofc = ofCode(uint32(m.rep))
	}
	// Cost of the offset and match length codes, excluding the literals.
	ofTT, mlTT := fsePredefEnc[tableOffsets].ct.symbolTT[ofc], fsePredefEnc[tableMatchLengths].ct.symbolTT[mlc]

	// Add cost of match encoding...
	m.est = int32(ofTT.outBits + mlTT.outBits)
	m.est += int32(ofTT.deltaNbBits>>16 + mlTT.deltaNbBits>>16)
	// Subtract savings compared to literal encoding...
	m.est -= (m.length * bitsPerByte) >> 10
	if m.est > 0 {
		// Unlikely gain. Discard the match.
		m.length = 0
		m.est = highScore
	}
}
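
// A worked example of the estimate above, with illustrative numbers: if
// the offset and match length codes cost 24 bits combined and the match
// is 16 bytes long with bitsPerByte = 2048 (2 bits/byte, scaled by 1024),
// the literal cost saved is (16*2048)>>10 = 32 bits, so
// est = 24 - 32 = -8: encoding the match saves roughly 8 bits.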

// bestFastEncoder uses 2 tables, one for short matches (4 bytes) and one for long matches.
// The long match table contains the previous entry with the same hash,
// effectively making it a "chain" of length 2.
// When we find a long match we choose between the two values and select the longest.
// When we find a short match, after checking the long, we check if we can find a long at n+1
// and that it is longer (lazy matching).
type bestFastEncoder struct {
	fastBase
	table         [bestShortTableSize]prevEntry
	longTable     [bestLongTableSize]prevEntry
	dictTable     []prevEntry
	dictLongTable []prevEntry
}
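
// A minimal sketch of how the length-2 chain is maintained, mirroring the
// table updates in Encode below: inserting a position pushes the previous
// head of the bucket into prev, so each bucket remembers two candidates.
//
//	old := e.longTable[h]
//	e.longTable[h] = prevEntry{offset: s + e.cur, prev: old.offset}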

// Encode improves compression compared to the other encoders, at the expense of speed.
func (e *bestFastEncoder) Encode(blk *blockEnc, src []byte) {
	const (
		// Input margin is the number of bytes we read (8)
		// and the maximum we will read ahead (4).
		inputMargin            = 8 + 4
		minNonLiteralBlockSize = 16
	)

	// Protect against e.cur wraparound.
	for e.cur >= bufferReset {
		if len(e.hist) == 0 {
			for i := range e.table[:] {
				e.table[i] = prevEntry{}
			}
			for i := range e.longTable[:] {
				e.longTable[i] = prevEntry{}
			}
			e.cur = e.maxMatchOff
			break
		}
		// Shift down everything in the table that isn't already too far away.
		minOff := e.cur + int32(len(e.hist)) - e.maxMatchOff
		for i := range e.table[:] {
			v := e.table[i].offset
			v2 := e.table[i].prev
			if v < minOff {
				v = 0
				v2 = 0
			} else {
				v = v - e.cur + e.maxMatchOff
				if v2 < minOff {
					v2 = 0
				} else {
					v2 = v2 - e.cur + e.maxMatchOff
				}
			}
			e.table[i] = prevEntry{
				offset: v,
				prev:   v2,
			}
		}
		for i := range e.longTable[:] {
			v := e.longTable[i].offset
			v2 := e.longTable[i].prev
			if v < minOff {
				v = 0
				v2 = 0
			} else {
				v = v - e.cur + e.maxMatchOff
				if v2 < minOff {
					v2 = 0
				} else {
					v2 = v2 - e.cur + e.maxMatchOff
				}
			}
			e.longTable[i] = prevEntry{
				offset: v,
				prev:   v2,
			}
		}
		e.cur = e.maxMatchOff
		break
	}
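
	// To illustrate the rebasing above with made-up numbers: with
	// e.cur = 1000, len(e.hist) = 500 and e.maxMatchOff = 600, minOff is
	// 900; an entry at offset 950 is still reachable and becomes
	// 950 - 1000 + 600 = 550, while anything below 900 is zeroed.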

	s := e.addBlock(src)
	blk.size = len(src)
	if len(src) < minNonLiteralBlockSize {
		blk.extraLits = len(src)
		blk.literals = blk.literals[:len(src)]
		copy(blk.literals, src)
		return
	}

	// Use this to estimate literal cost.
	// Scaled by 10 bits.
	bitsPerByte := int32((compress.ShannonEntropyBits(src) * 1024) / len(src))
	// Huffman can never go < 1 bit/byte
	if bitsPerByte < 1024 {
		bitsPerByte = 1024
	}
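
	// For example, a 1000-byte input whose Shannon entropy is 6000 bits
	// yields bitsPerByte = 6000*1024/1000 = 6144, i.e. 6 bits/byte in the
	// 10-bit fixed-point scale; the floor above clamps it to at least
	// 1024 (1 bit/byte).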

	// Override src with the complete history from here on.
	src = e.hist
	sLimit := int32(len(src)) - inputMargin
	const kSearchStrength = 10

	// nextEmit is where in src the next emitLiteral should start from.
	nextEmit := s
	cv := load6432(src, s)

	// Relative offsets
	offset1 := int32(blk.recentOffsets[0])
	offset2 := int32(blk.recentOffsets[1])
	offset3 := int32(blk.recentOffsets[2])

	addLiterals := func(s *seq, until int32) {
		if until == nextEmit {
			return
		}
		blk.literals = append(blk.literals, src[nextEmit:until]...)
		s.litLen = uint32(until - nextEmit)
	}
	_ = addLiterals

	if debugEncoder {
		println("recent offsets:", blk.recentOffsets)
	}

encodeLoop:
	for {
		// We allow the encoder to optionally turn off repeat offsets across blocks
		canRepeat := len(blk.sequences) > 2

		if debugAsserts && canRepeat && offset1 == 0 {
			panic("offset0 was 0")
		}

		bestOf := func(a, b match) match {
			if a.est+(a.s-b.s)*bitsPerByte>>10 < b.est+(b.s-a.s)*bitsPerByte>>10 {
				return a
			}
			return b
		}
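
		// The (a.s-b.s)*bitsPerByte>>10 terms charge each candidate for the
		// extra literals needed to reach its later start. For example, if a
		// starts 2 bytes after b and bitsPerByte is 8192 (8 bits/byte), a is
		// handicapped by 2*8192>>10 = 16 bits when the estimates are compared.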
		const goodEnough = 100

		nextHashL := hashLen(cv, bestLongTableBits, bestLongLen)
		nextHashS := hashLen(cv, bestShortTableBits, bestShortLen)
		candidateL := e.longTable[nextHashL]
		candidateS := e.table[nextHashS]

		matchAt := func(offset int32, s int32, first uint32, rep int32) match {
			if s-offset >= e.maxMatchOff || load3232(src, offset) != first {
				return match{s: s, est: highScore}
			}
			if debugAsserts {
				if !bytes.Equal(src[s:s+4], src[offset:offset+4]) {
					panic(fmt.Sprintf("first match mismatch: %v != %v, first: %08x", src[s:s+4], src[offset:offset+4], first))
				}
			}
			m := match{offset: offset, s: s, length: 4 + e.matchlen(s+4, offset+4, src), rep: rep}
			m.estBits(bitsPerByte)
			return m
		}
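
		// Note on the rep argument: values 1-3 make estBits cost the match as
		// a zstd repeat-offset code, while a negative rep (used immediately
		// below) costs it as a regular offset of s-offset.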

		best := bestOf(matchAt(candidateL.offset-e.cur, s, uint32(cv), -1), matchAt(candidateL.prev-e.cur, s, uint32(cv), -1))
		best = bestOf(best, matchAt(candidateS.offset-e.cur, s, uint32(cv), -1))
		best = bestOf(best, matchAt(candidateS.prev-e.cur, s, uint32(cv), -1))

		if canRepeat && best.length < goodEnough {
			// Try the three recent offsets at s+1 (cv>>8 holds those 4 bytes).
			cv32 := uint32(cv >> 8)
			spp := s + 1
			best = bestOf(best, matchAt(spp-offset1, spp, cv32, 1))
			best = bestOf(best, matchAt(spp-offset2, spp, cv32, 2))
			best = bestOf(best, matchAt(spp-offset3, spp, cv32, 3))
			if best.length > 0 {
				// And again at s+3.
				cv32 = uint32(cv >> 24)
				spp += 2
				best = bestOf(best, matchAt(spp-offset1, spp, cv32, 1))
				best = bestOf(best, matchAt(spp-offset2, spp, cv32, 2))
				best = bestOf(best, matchAt(spp-offset3, spp, cv32, 3))
			}
		}
		// Store the candidates for this position before searching further.
		e.longTable[nextHashL] = prevEntry{offset: s + e.cur, prev: candidateL.offset}
		e.table[nextHashS] = prevEntry{offset: s + e.cur, prev: candidateS.offset}

		// Look far ahead, unless we have a really long match already...
		if best.length < goodEnough {
			// No match found, move forward on input, no need to check forward...
			if best.length < 4 {
				s += 1 + (s-nextEmit)>>(kSearchStrength-1)
				if s >= sLimit {
					break encodeLoop
				}
				cv = load6432(src, s)
				continue
			}

			s++
			candidateS = e.table[hashLen(cv>>8, bestShortTableBits, bestShortLen)]
			cv = load6432(src, s)
			cv2 := load6432(src, s+1)
			candidateL = e.longTable[hashLen(cv, bestLongTableBits, bestLongLen)]
			candidateL2 := e.longTable[hashLen(cv2, bestLongTableBits, bestLongLen)]

			// Short at s+1
			best = bestOf(best, matchAt(candidateS.offset-e.cur, s, uint32(cv), -1))
			// Long at s+1, s+2
			best = bestOf(best, matchAt(candidateL.offset-e.cur, s, uint32(cv), -1))
			best = bestOf(best, matchAt(candidateL.prev-e.cur, s, uint32(cv), -1))
			best = bestOf(best, matchAt(candidateL2.offset-e.cur, s+1, uint32(cv2), -1))
			best = bestOf(best, matchAt(candidateL2.prev-e.cur, s+1, uint32(cv2), -1))
			if false {
				// Short at s+3.
				// Too often worse...
				best = bestOf(best, matchAt(e.table[hashLen(cv2>>8, bestShortTableBits, bestShortLen)].offset-e.cur, s+2, uint32(cv2>>8), -1))
			}
			// See if we can find a better match by checking where the current best ends.
			// Use that offset to see if we can find a better full match.
			if sAt := best.s + best.length; sAt < sLimit {
				nextHashL := hashLen(load6432(src, sAt), bestLongTableBits, bestLongLen)
				candidateEnd := e.longTable[nextHashL]
				if pos := candidateEnd.offset - e.cur - best.length; pos >= 0 {
					bestEnd := bestOf(best, matchAt(pos, best.s, load3232(src, best.s), -1))
					if pos := candidateEnd.prev - e.cur - best.length; pos >= 0 {
						bestEnd = bestOf(bestEnd, matchAt(pos, best.s, load3232(src, best.s), -1))
					}
					best = bestEnd
				}
			}
		}

		if debugAsserts {
			if !bytes.Equal(src[best.s:best.s+best.length], src[best.offset:best.offset+best.length]) {
				panic(fmt.Sprintf("match mismatch: %v != %v", src[best.s:best.s+best.length], src[best.offset:best.offset+best.length]))
			}
		}

		// We have a match, we can store the forward value
		if best.rep > 0 {
			s = best.s
			var seq seq
			seq.matchLen = uint32(best.length - zstdMinMatch)

			// We might be able to match backwards.
			// Extend as long as we can.
			start := best.s
			// We end the search early, so we don't risk 0 literals
			// and have to do special offset treatment.
			startLimit := nextEmit + 1

			tMin := s - e.maxMatchOff
			if tMin < 0 {
				tMin = 0
			}
			repIndex := best.offset
			for repIndex > tMin && start > startLimit && src[repIndex-1] == src[start-1] && seq.matchLen < maxMatchLength-zstdMinMatch-1 {
				repIndex--
				start--
				seq.matchLen++
			}
			addLiterals(&seq, start)

			// Offset values 1-3 encode the repeat index.
			seq.offset = uint32(best.rep)
			if debugSequences {
				println("repeat sequence", seq, "next s:", s)
			}
			blk.sequences = append(blk.sequences, seq)

			// Index match start+1 (long) -> s - 1
			index0 := s
			s = best.s + best.length

			nextEmit = s
			if s >= sLimit {
				if debugEncoder {
					println("repeat ended", s, best.length)
				}
				break encodeLoop
			}
			// Index the positions we skipped over in both tables.
			off := index0 + e.cur
			for index0 < s-1 {
				cv0 := load6432(src, index0)
				h0 := hashLen(cv0, bestLongTableBits, bestLongLen)
				h1 := hashLen(cv0, bestShortTableBits, bestShortLen)
				e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
				e.table[h1] = prevEntry{offset: off, prev: e.table[h1].offset}
				off++
				index0++
			}
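
			// The switch below rotates the recent-offset history, mirroring
			// zstd's repcode update rules. With made-up values, if
			// (offset1, offset2, offset3) = (100, 200, 300), a rep==3 match
			// leaves them as (300, 100, 200); rep==1 leaves them unchanged.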
			switch best.rep {
			case 2:
				offset1, offset2 = offset2, offset1
			case 3:
				offset1, offset2, offset3 = offset3, offset1, offset2
			}
			cv = load6432(src, s)
			continue
		}

		// A match of at least 4 bytes has been found. Update recent offsets.
		s = best.s
		t := best.offset
		offset1, offset2, offset3 = s-t, offset1, offset2

		if debugAsserts && s <= t {
			panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
		}

		if debugAsserts && int(offset1) > len(src) {
			panic("invalid offset")
		}

		// Extend the match as long as possible.
		l := best.length

		// Extend backwards
		tMin := s - e.maxMatchOff
		if tMin < 0 {
			tMin = 0
		}
		for t > tMin && s > nextEmit && src[t-1] == src[s-1] && l < maxMatchLength {
			s--
			t--
			l++
		}

		// Write our sequence
		var seq seq
		seq.litLen = uint32(s - nextEmit)
		seq.matchLen = uint32(l - zstdMinMatch)
		if seq.litLen > 0 {
			blk.literals = append(blk.literals, src[nextEmit:s]...)
		}
		// Offset values 1-3 are reserved for repeats, so regular offsets are stored +3.
		seq.offset = uint32(s-t) + 3
		s += l
		if debugSequences {
			println("sequence", seq, "next s:", s)
		}
		blk.sequences = append(blk.sequences, seq)
		nextEmit = s
		if s >= sLimit {
			break encodeLoop
		}

		// Index match start+1 (long) -> s - 1
		index0 := s - l + 1
		// Index every entry in the skipped span in both tables.
		for index0 < s-1 {
			cv0 := load6432(src, index0)
			h0 := hashLen(cv0, bestLongTableBits, bestLongLen)
			h1 := hashLen(cv0, bestShortTableBits, bestShortLen)
			off := index0 + e.cur
			e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
			e.table[h1] = prevEntry{offset: off, prev: e.table[h1].offset}
			index0++
		}

		cv = load6432(src, s)
		if !canRepeat {
			continue
		}

		// Check offset 2
		for {
			o2 := s - offset2
			if load3232(src, o2) != uint32(cv) {
				// Do regular search
				break
			}

			// Store this, since we have it.
			nextHashS := hashLen(cv, bestShortTableBits, bestShortLen)
			nextHashL := hashLen(cv, bestLongTableBits, bestLongLen)

			// We have at least a 4-byte match.
			// No need to check backwards; we come straight from a match.
			l := 4 + e.matchlen(s+4, o2+4, src)

			e.longTable[nextHashL] = prevEntry{offset: s + e.cur, prev: e.longTable[nextHashL].offset}
			e.table[nextHashS] = prevEntry{offset: s + e.cur, prev: e.table[nextHashS].offset}
			seq.matchLen = uint32(l) - zstdMinMatch
			seq.litLen = 0

			// Since litLen is always 0, this is offset 1.
			seq.offset = 1
			s += l
			nextEmit = s
			if debugSequences {
				println("sequence", seq, "next s:", s)
			}
			blk.sequences = append(blk.sequences, seq)

			// Swap offset 1 and 2.
			offset1, offset2 = offset2, offset1
			if s >= sLimit {
				// Finished
				break encodeLoop
			}
			cv = load6432(src, s)
		}
	}

	if int(nextEmit) < len(src) {
		blk.literals = append(blk.literals, src[nextEmit:]...)
		blk.extraLits = len(src) - int(nextEmit)
	}
	blk.recentOffsets[0] = uint32(offset1)
	blk.recentOffsets[1] = uint32(offset2)
	blk.recentOffsets[2] = uint32(offset3)
	if debugEncoder {
		println("returning, recent offsets:", blk.recentOffsets, "extra literals:", blk.extraLits)
	}
}
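
// A minimal usage sketch under the package's internal conventions; the
// driver below is hypothetical and not part of this file:
//
//	var e bestFastEncoder
//	e.Reset(nil, false) // no dictionary
//	e.Encode(blk, data) // blk is a prepared *blockEnc; sequences and literals are appended to it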

// EncodeNoHist will encode a block with no history and no following blocks.
// The most notable difference is that src will not be copied for history and
// we do not need to check for max match length.
func (e *bestFastEncoder) EncodeNoHist(blk *blockEnc, src []byte) {
	e.ensureHist(len(src))
	e.Encode(blk, src)
}

// Reset will reset the encoder and set a dictionary, if d is not nil.
func (e *bestFastEncoder) Reset(d *dict, singleBlock bool) {
	e.resetBase(d, singleBlock)
	if d == nil {
		return
	}
	// Init or copy the dict short table.
	if len(e.dictTable) != len(e.table) || d.id != e.lastDictID {
		if len(e.dictTable) != len(e.table) {
			e.dictTable = make([]prevEntry, len(e.table))
		}
		end := int32(len(d.content)) - 8 + e.maxMatchOff
		for i := e.maxMatchOff; i < end; i += 4 {
			const hashLog = bestShortTableBits

			cv := load6432(d.content, i-e.maxMatchOff)
			nextHash := hashLen(cv, hashLog, bestShortLen)      // 0 -> 4
			nextHash1 := hashLen(cv>>8, hashLog, bestShortLen)  // 1 -> 5
			nextHash2 := hashLen(cv>>16, hashLog, bestShortLen) // 2 -> 6
			nextHash3 := hashLen(cv>>24, hashLog, bestShortLen) // 3 -> 7
			e.dictTable[nextHash] = prevEntry{
				prev:   e.dictTable[nextHash].offset,
				offset: i,
			}
			e.dictTable[nextHash1] = prevEntry{
				prev:   e.dictTable[nextHash1].offset,
				offset: i + 1,
			}
			e.dictTable[nextHash2] = prevEntry{
				prev:   e.dictTable[nextHash2].offset,
				offset: i + 2,
			}
			e.dictTable[nextHash3] = prevEntry{
				prev:   e.dictTable[nextHash3].offset,
				offset: i + 3,
			}
		}
		e.lastDictID = d.id
	}

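	// The loop below maintains cv as a rolling 8-byte window over
	// d.content: shifting right by 8 and or-ing the next byte into the
	// top keeps cv equal to load6432(d.content, i-e.maxMatchOff) without
	// re-reading 8 bytes per position. A sketch of one step:
	//
	//	cv = cv>>8 | uint64(d.content[off])<<56 // drop the oldest byte, add byte off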
	// Init or copy the dict long table.
	if len(e.dictLongTable) != len(e.longTable) || d.id != e.lastDictID {
		if len(e.dictLongTable) != len(e.longTable) {
			e.dictLongTable = make([]prevEntry, len(e.longTable))
		}
		if len(d.content) >= 8 {
			cv := load6432(d.content, 0)
			h := hashLen(cv, bestLongTableBits, bestLongLen)
			e.dictLongTable[h] = prevEntry{
				offset: e.maxMatchOff,
				prev:   e.dictLongTable[h].offset,
			}

			end := int32(len(d.content)) - 8 + e.maxMatchOff
			off := 8 // First byte to read
			for i := e.maxMatchOff + 1; i < end; i++ {
				cv = cv>>8 | (uint64(d.content[off]) << 56)
				h := hashLen(cv, bestLongTableBits, bestLongLen)
				e.dictLongTable[h] = prevEntry{
					offset: i,
					prev:   e.dictLongTable[h].offset,
				}
				off++
			}
		}
		e.lastDictID = d.id
	}
	// Reset the long table to its dictionary state.
	copy(e.longTable[:], e.dictLongTable)

	e.cur = e.maxMatchOff
	// Reset the short table to its dictionary state.
	copy(e.table[:], e.dictTable)
}