// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.

package zstd

import (
	"fmt"
	"math/bits"
)

const (
	bestLongTableBits = 20                     // Bits used in the long match table
	bestLongTableSize = 1 << bestLongTableBits // Size of the table

	// Note: Increasing the short table bits or making the hash shorter
	// can actually lead to compression degradation since it will 'steal' more from the
	// long match table and match offsets are quite big.
	// This greatly depends on the type of input.
	bestShortTableBits = 16                      // Bits used in the short match table
	bestShortTableSize = 1 << bestShortTableBits // Size of the table
)

// bestFastEncoder uses 2 tables, one for short matches (4 bytes) and one for long matches.
// The long match table contains the previous entry with the same hash,
// effectively making it a "chain" of length 2.
// When we find a long match we choose between the two values and select the longest.
// When we find a short match, after checking the long, we check if we can find a long at n+1
// and that it is longer (lazy matching).
type bestFastEncoder struct {
	fastBase
	table         [bestShortTableSize]prevEntry
	longTable     [bestLongTableSize]prevEntry
	dictTable     []prevEntry
	dictLongTable []prevEntry
}
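// Both tables store prevEntry values, so every insert keeps the previous head
// of the bucket around, e.g.:
//
//	e.longTable[h] = prevEntry{offset: newOff, prev: e.longTable[h].offset}
//
// which is how the tables are updated throughout Encode below.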

// Encode encodes src into blk, trading speed for the best possible compression ratio.
func (e *bestFastEncoder) Encode(blk *blockEnc, src []byte) {
	const (
		// Input margin is the number of bytes we read (8)
		// and the maximum we will read ahead (4)
		inputMargin            = 8 + 4
		minNonLiteralBlockSize = 16
	)

	// Protect against e.cur wraparound.
	for e.cur >= bufferReset {
		if len(e.hist) == 0 {
			for i := range e.table[:] {
				e.table[i] = prevEntry{}
			}
			for i := range e.longTable[:] {
				e.longTable[i] = prevEntry{}
			}
			e.cur = e.maxMatchOff
			break
		}
		// Shift down everything in the table that isn't already too far away.
		minOff := e.cur + int32(len(e.hist)) - e.maxMatchOff
		for i := range e.table[:] {
			v := e.table[i].offset
			v2 := e.table[i].prev
			if v < minOff {
				v = 0
				v2 = 0
			} else {
				v = v - e.cur + e.maxMatchOff
				if v2 < minOff {
					v2 = 0
				} else {
					v2 = v2 - e.cur + e.maxMatchOff
				}
			}
			e.table[i] = prevEntry{
				offset: v,
				prev:   v2,
			}
		}
		for i := range e.longTable[:] {
			v := e.longTable[i].offset
			v2 := e.longTable[i].prev
			if v < minOff {
				v = 0
				v2 = 0
			} else {
				v = v - e.cur + e.maxMatchOff
				if v2 < minOff {
					v2 = 0
				} else {
					v2 = v2 - e.cur + e.maxMatchOff
				}
			}
			e.longTable[i] = prevEntry{
				offset: v,
				prev:   v2,
			}
		}
		e.cur = e.maxMatchOff
		break
	}

	s := e.addBlock(src)
	blk.size = len(src)
	if len(src) < minNonLiteralBlockSize {
		blk.extraLits = len(src)
		blk.literals = blk.literals[:len(src)]
		copy(blk.literals, src)
		return
	}

	// Override src
	src = e.hist
	sLimit := int32(len(src)) - inputMargin
	const kSearchStrength = 10

	// nextEmit is where in src the next emitLiteral should start from.
	nextEmit := s
	cv := load6432(src, s)

	// Relative offsets
	offset1 := int32(blk.recentOffsets[0])
	offset2 := int32(blk.recentOffsets[1])
	offset3 := int32(blk.recentOffsets[2])

	addLiterals := func(s *seq, until int32) {
		if until == nextEmit {
			return
		}
		blk.literals = append(blk.literals, src[nextEmit:until]...)
		s.litLen = uint32(until - nextEmit)
	}
	_ = addLiterals

	if debug {
		println("recent offsets:", blk.recentOffsets)
	}

encodeLoop:
	for {
		// We allow the encoder to optionally turn off repeat offsets across blocks
		canRepeat := len(blk.sequences) > 2

		if debugAsserts && canRepeat && offset1 == 0 {
			panic("offset0 was 0")
		}

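		// match records a candidate match. s and offset are absolute positions
		// in the history buffer, length is the total match length (0 if no
		// match), and rep is the repeat-offset index (1-3) that produced it,
		// or -1 for a regular table match.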
		type match struct {
			offset int32
			s      int32
			length int32
			rep    int32
		}
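		// matchAt returns a match at the given position if the first four
		// bytes equal first and the candidate is within e.maxMatchOff;
		// otherwise it returns a zero-length match.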
		matchAt := func(offset int32, s int32, first uint32, rep int32) match {
			if s-offset >= e.maxMatchOff || load3232(src, offset) != first {
				return match{offset: offset, s: s}
			}
			return match{offset: offset, s: s, length: 4 + e.matchlen(s+4, offset+4, src), rep: rep}
		}

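		// bestOf prefers the candidate with the higher score: match length
		// minus how much later it starts, with regular (rep < 0) matches
		// additionally charged bits.Len32(offset)/8 to approximate the cost
		// of encoding the offset.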
		bestOf := func(a, b match) match {
			aScore := b.s - a.s + a.length
			bScore := a.s - b.s + b.length
			if a.rep < 0 {
				aScore = aScore - int32(bits.Len32(uint32(a.offset)))/8
			}
			if b.rep < 0 {
				bScore = bScore - int32(bits.Len32(uint32(b.offset)))/8
			}
			if aScore >= bScore {
				return a
			}
			return b
		}
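		// Matches at least goodEnough bytes long are taken as-is; the
		// lookahead searches below are skipped for them.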
		const goodEnough = 100

		nextHashL := hash8(cv, bestLongTableBits)
		nextHashS := hash4x64(cv, bestShortTableBits)
		candidateL := e.longTable[nextHashL]
		candidateS := e.table[nextHashS]

		best := bestOf(matchAt(candidateL.offset-e.cur, s, uint32(cv), -1), matchAt(candidateL.prev-e.cur, s, uint32(cv), -1))
		best = bestOf(best, matchAt(candidateS.offset-e.cur, s, uint32(cv), -1))
		best = bestOf(best, matchAt(candidateS.prev-e.cur, s, uint32(cv), -1))
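		// Try the three most recent repeat offsets one byte ahead (cv>>8);
		// if a match has been found, also probe three bytes ahead (cv>>24).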
		if canRepeat && best.length < goodEnough {
			best = bestOf(best, matchAt(s-offset1+1, s+1, uint32(cv>>8), 1))
			best = bestOf(best, matchAt(s-offset2+1, s+1, uint32(cv>>8), 2))
			best = bestOf(best, matchAt(s-offset3+1, s+1, uint32(cv>>8), 3))
			if best.length > 0 {
				best = bestOf(best, matchAt(s-offset1+3, s+3, uint32(cv>>24), 1))
				best = bestOf(best, matchAt(s-offset2+3, s+3, uint32(cv>>24), 2))
				best = bestOf(best, matchAt(s-offset3+3, s+3, uint32(cv>>24), 3))
			}
		}
		// Update the tables with the current position, chaining to the previous heads.
		e.longTable[nextHashL] = prevEntry{offset: s + e.cur, prev: candidateL.offset}
		e.table[nextHashS] = prevEntry{offset: s + e.cur, prev: candidateS.offset}

		// Look far ahead, unless we have a really long match already...
		if best.length < goodEnough {
			// No match found, move forward on input, no need to check forward...
			if best.length < 4 {
				s += 1 + (s-nextEmit)>>(kSearchStrength-1)
				if s >= sLimit {
					break encodeLoop
				}
				cv = load6432(src, s)
				continue
			}

			s++
			candidateS = e.table[hash4x64(cv>>8, bestShortTableBits)]
			cv = load6432(src, s)
			cv2 := load6432(src, s+1)
			candidateL = e.longTable[hash8(cv, bestLongTableBits)]
			candidateL2 := e.longTable[hash8(cv2, bestLongTableBits)]

			best = bestOf(best, matchAt(candidateS.offset-e.cur, s, uint32(cv), -1))
			best = bestOf(best, matchAt(candidateL.offset-e.cur, s, uint32(cv), -1))
			best = bestOf(best, matchAt(candidateL.prev-e.cur, s, uint32(cv), -1))
			best = bestOf(best, matchAt(candidateL2.offset-e.cur, s+1, uint32(cv2), -1))
			best = bestOf(best, matchAt(candidateL2.prev-e.cur, s+1, uint32(cv2), -1))

			// See if we can find a better match by checking where the current best ends.
			// Use that offset to see if we can find a better full match.
			if sAt := best.s + best.length; sAt < sLimit {
				nextHashL := hash8(load6432(src, sAt), bestLongTableBits)
				candidateEnd := e.longTable[nextHashL]
				if pos := candidateEnd.offset - e.cur - best.length; pos >= 0 {
					bestEnd := bestOf(best, matchAt(pos, best.s, load3232(src, best.s), -1))
					if pos := candidateEnd.prev - e.cur - best.length; pos >= 0 {
						bestEnd = bestOf(bestEnd, matchAt(pos, best.s, load3232(src, best.s), -1))
					}
					best = bestEnd
				}
			}
		}

		// We have a match. If it came from a repeat offset, emit a repeat sequence.
		if best.rep > 0 {
			s = best.s
			var seq seq
			seq.matchLen = uint32(best.length - zstdMinMatch)

			// We might be able to match backwards.
			// Extend as long as we can.
			start := best.s
			// We end the search early, so we don't risk 0 literals
			// and have to do special offset treatment.
			startLimit := nextEmit + 1

			tMin := s - e.maxMatchOff
			if tMin < 0 {
				tMin = 0
			}
			repIndex := best.offset
			for repIndex > tMin && start > startLimit && src[repIndex-1] == src[start-1] && seq.matchLen < maxMatchLength-zstdMinMatch-1 {
				repIndex--
				start--
				seq.matchLen++
			}
			addLiterals(&seq, start)

			// rep 0
			seq.offset = uint32(best.rep)
			if debugSequences {
				println("repeat sequence", seq, "next s:", s)
			}
			blk.sequences = append(blk.sequences, seq)

			// Index match start+1 (long) -> s - 1
			index0 := s
			s = best.s + best.length

			nextEmit = s
			if s >= sLimit {
				if debug {
					println("repeat ended", s, best.length)
				}
				break encodeLoop
			}
			// Index skipped...
			off := index0 + e.cur
			for index0 < s-1 {
				cv0 := load6432(src, index0)
				h0 := hash8(cv0, bestLongTableBits)
				h1 := hash4x64(cv0, bestShortTableBits)
				e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
				e.table[h1] = prevEntry{offset: off, prev: e.table[h1].offset}
				off++
				index0++
			}
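			// Promote the repeat offset that was used to the front of the
			// recent-offset list, following the zstd repeat-offset rules.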
			switch best.rep {
			case 2:
				offset1, offset2 = offset2, offset1
			case 3:
				offset1, offset2, offset3 = offset3, offset1, offset2
			}
			cv = load6432(src, s)
			continue
		}

		// A 4-byte match has been found. Update recent offsets.
		// We'll later see if more than 4 bytes.
		s = best.s
		t := best.offset
		offset1, offset2, offset3 = s-t, offset1, offset2

		if debugAsserts && s <= t {
			panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
		}

		if debugAsserts && canRepeat && int(offset1) > len(src) {
			panic("invalid offset")
		}

		// Extend the n-byte match as long as possible.
		l := best.length

		// Extend backwards
		tMin := s - e.maxMatchOff
		if tMin < 0 {
			tMin = 0
		}
		for t > tMin && s > nextEmit && src[t-1] == src[s-1] && l < maxMatchLength {
			s--
			t--
			l++
		}

		// Write our sequence
		var seq seq
		seq.litLen = uint32(s - nextEmit)
		seq.matchLen = uint32(l - zstdMinMatch)
		if seq.litLen > 0 {
			blk.literals = append(blk.literals, src[nextEmit:s]...)
		}
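		// Offset codes 1-3 are reserved for the recent (repeat) offsets,
		// so a literal offset is stored shifted up by 3.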
		seq.offset = uint32(s-t) + 3
		s += l
		if debugSequences {
			println("sequence", seq, "next s:", s)
		}
		blk.sequences = append(blk.sequences, seq)
		nextEmit = s
		if s >= sLimit {
			break encodeLoop
		}

		// Index match start+1 (long) -> s - 1
		index0 := s - l + 1
		// every entry
		for index0 < s-1 {
			cv0 := load6432(src, index0)
			h0 := hash8(cv0, bestLongTableBits)
			h1 := hash4x64(cv0, bestShortTableBits)
			off := index0 + e.cur
			e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
			e.table[h1] = prevEntry{offset: off, prev: e.table[h1].offset}
			index0++
		}

		cv = load6432(src, s)
		if !canRepeat {
			continue
		}

		// Check offset 2: since the next sequence has no literals,
		// a match there can be stored as repeat offset 1.
		for {
			o2 := s - offset2
			if load3232(src, o2) != uint32(cv) {
				// Do regular search
				break
			}

			// Store this, since we have it.
			nextHashS := hash4x64(cv, bestShortTableBits)
			nextHashL := hash8(cv, bestLongTableBits)

			// We have at least 4 byte match.
			// No need to check backwards. We come straight from a match
			l := 4 + e.matchlen(s+4, o2+4, src)

			e.longTable[nextHashL] = prevEntry{offset: s + e.cur, prev: e.longTable[nextHashL].offset}
			e.table[nextHashS] = prevEntry{offset: s + e.cur, prev: e.table[nextHashS].offset}
			seq.matchLen = uint32(l) - zstdMinMatch
			seq.litLen = 0

			// Since litlen is always 0, this is offset 1.
			seq.offset = 1
			s += l
			nextEmit = s
			if debugSequences {
				println("sequence", seq, "next s:", s)
			}
			blk.sequences = append(blk.sequences, seq)

			// Swap offset 1 and 2.
			offset1, offset2 = offset2, offset1
			if s >= sLimit {
				// Finished
				break encodeLoop
			}
			cv = load6432(src, s)
		}
	}

	if int(nextEmit) < len(src) {
		blk.literals = append(blk.literals, src[nextEmit:]...)
		blk.extraLits = len(src) - int(nextEmit)
	}
	blk.recentOffsets[0] = uint32(offset1)
	blk.recentOffsets[1] = uint32(offset2)
	blk.recentOffsets[2] = uint32(offset3)
	if debug {
		println("returning, recent offsets:", blk.recentOffsets, "extra literals:", blk.extraLits)
	}
}

// EncodeNoHist will encode a block with no history and no following blocks.
// Most notable difference is that src will not be copied for history and
// we do not need to check for max match length.
func (e *bestFastEncoder) EncodeNoHist(blk *blockEnc, src []byte) {
	e.ensureHist(len(src))
	e.Encode(blk, src)
}

// Reset will reset and set a dictionary if not nil
func (e *bestFastEncoder) Reset(d *dict, singleBlock bool) {
	e.resetBase(d, singleBlock)
	if d == nil {
		return
	}
	// Init or copy dict table
	if len(e.dictTable) != len(e.table) || d.id != e.lastDictID {
		if len(e.dictTable) != len(e.table) {
			e.dictTable = make([]prevEntry, len(e.table))
		}
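		// Hash every position of the dictionary content into the short table;
		// each iteration covers four consecutive byte offsets from one 8-byte load.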
		end := int32(len(d.content)) - 8 + e.maxMatchOff
		for i := e.maxMatchOff; i < end; i += 4 {
			const hashLog = bestShortTableBits

			cv := load6432(d.content, i-e.maxMatchOff)
			nextHash := hash4x64(cv, hashLog)      // 0 -> 4
			nextHash1 := hash4x64(cv>>8, hashLog)  // 1 -> 5
			nextHash2 := hash4x64(cv>>16, hashLog) // 2 -> 6
			nextHash3 := hash4x64(cv>>24, hashLog) // 3 -> 7
			e.dictTable[nextHash] = prevEntry{
				prev:   e.dictTable[nextHash].offset,
				offset: i,
			}
			e.dictTable[nextHash1] = prevEntry{
				prev:   e.dictTable[nextHash1].offset,
				offset: i + 1,
			}
			e.dictTable[nextHash2] = prevEntry{
				prev:   e.dictTable[nextHash2].offset,
				offset: i + 2,
			}
			e.dictTable[nextHash3] = prevEntry{
				prev:   e.dictTable[nextHash3].offset,
				offset: i + 3,
			}
		}
		e.lastDictID = d.id
	}

	// Init or copy dict long table
	if len(e.dictLongTable) != len(e.longTable) || d.id != e.lastDictID {
		if len(e.dictLongTable) != len(e.longTable) {
			e.dictLongTable = make([]prevEntry, len(e.longTable))
		}
		if len(d.content) >= 8 {
			cv := load6432(d.content, 0)
			h := hash8(cv, bestLongTableBits)
			e.dictLongTable[h] = prevEntry{
				offset: e.maxMatchOff,
				prev:   e.dictLongTable[h].offset,
			}

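			// Slide the 8-byte window one byte at a time, updating the
			// hash input incrementally instead of reloading it.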
			end := int32(len(d.content)) - 8 + e.maxMatchOff
			off := 8 // First to read
			for i := e.maxMatchOff + 1; i < end; i++ {
				cv = cv>>8 | (uint64(d.content[off]) << 56)
				h := hash8(cv, bestLongTableBits)
				e.dictLongTable[h] = prevEntry{
					offset: i,
					prev:   e.dictLongTable[h].offset,
				}
				off++
			}
		}
		e.lastDictID = d.id
	}
	// Reset long table to initial state
	copy(e.longTable[:], e.dictLongTable)

	e.cur = e.maxMatchOff
	// Reset short table to initial state
	copy(e.table[:], e.dictTable)
}