blob: 4871dd03affc6481fc0e031d6ca6c577b802d1f4 [file] [log] [blame]
// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8 "crypto/rand"
9 "fmt"
10 "io"
11 rdebug "runtime/debug"
12 "sync"
13
14 "github.com/klauspost/compress/zstd/internal/xxhash"
15)
16
17// Encoder provides encoding to Zstandard.
18// An Encoder can be used for either compressing a stream via the
19// io.WriteCloser interface supported by the Encoder or as multiple independent
20// tasks via the EncodeAll function.
21// Smaller encodes are encouraged to use the EncodeAll function.
22// Use NewWriter to create a new instance.
23type Encoder struct {
24 o encoderOptions
25 encoders chan encoder
26 state encoderState
27 init sync.Once
28}
29
30type encoder interface {
31 Encode(blk *blockEnc, src []byte)
32 EncodeNoHist(blk *blockEnc, src []byte)
33 Block() *blockEnc
34 CRC() *xxhash.Digest
35 AppendCRC([]byte) []byte
36 WindowSize(size int) int32
37 UseBlock(*blockEnc)
38 Reset(d *dict, singleBlock bool)
39}
40
41type encoderState struct {
42 w io.Writer
43 filling []byte
44 current []byte
45 previous []byte
46 encoder encoder
47 writing *blockEnc
48 err error
49 writeErr error
50 nWritten int64
51 headerWritten bool
52 eofWritten bool
53 fullFrameWritten bool
54
55 // This waitgroup indicates an encode is running.
56 wg sync.WaitGroup
57 // This waitgroup indicates we have a block encoding/writing.
58 wWg sync.WaitGroup
59}
60
61// NewWriter will create a new Zstandard encoder.
62// If the encoder will be used for encoding blocks a nil writer can be used.
63func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
64 initPredefined()
65 var e Encoder
66 e.o.setDefault()
67 for _, o := range opts {
68 err := o(&e.o)
69 if err != nil {
70 return nil, err
71 }
72 }
73 if w != nil {
74 e.Reset(w)
75 }
76 return &e, nil
77}
78
79func (e *Encoder) initialize() {
80 if e.o.concurrent == 0 {
81 e.o.setDefault()
82 }
83 e.encoders = make(chan encoder, e.o.concurrent)
84 for i := 0; i < e.o.concurrent; i++ {
85 enc := e.o.encoder()
86 e.encoders <- enc
87 }
88}
89
90// Reset will re-initialize the writer and new writes will encode to the supplied writer
91// as a new, independent stream.
92func (e *Encoder) Reset(w io.Writer) {
93 s := &e.state
94 s.wg.Wait()
95 s.wWg.Wait()
96 if cap(s.filling) == 0 {
97 s.filling = make([]byte, 0, e.o.blockSize)
98 }
99 if cap(s.current) == 0 {
100 s.current = make([]byte, 0, e.o.blockSize)
101 }
102 if cap(s.previous) == 0 {
103 s.previous = make([]byte, 0, e.o.blockSize)
104 }
105 if s.encoder == nil {
106 s.encoder = e.o.encoder()
107 }
108 if s.writing == nil {
109 s.writing = &blockEnc{lowMem: e.o.lowMem}
110 s.writing.init()
111 }
112 s.writing.initNewEncode()
113 s.filling = s.filling[:0]
114 s.current = s.current[:0]
115 s.previous = s.previous[:0]
116 s.encoder.Reset(e.o.dict, false)
117 s.headerWritten = false
118 s.eofWritten = false
119 s.fullFrameWritten = false
120 s.w = w
121 s.err = nil
122 s.nWritten = 0
123 s.writeErr = nil
124}
125
126// Write data to the encoder.
127// Input data will be buffered and as the buffer fills up
128// content will be compressed and written to the output.
129// When done writing, use Close to flush the remaining output
130// and write CRC if requested.
131func (e *Encoder) Write(p []byte) (n int, err error) {
132 s := &e.state
133 for len(p) > 0 {
134 if len(p)+len(s.filling) < e.o.blockSize {
135 if e.o.crc {
136 _, _ = s.encoder.CRC().Write(p)
137 }
138 s.filling = append(s.filling, p...)
139 return n + len(p), nil
140 }
141 add := p
142 if len(p)+len(s.filling) > e.o.blockSize {
143 add = add[:e.o.blockSize-len(s.filling)]
144 }
145 if e.o.crc {
146 _, _ = s.encoder.CRC().Write(add)
147 }
148 s.filling = append(s.filling, add...)
149 p = p[len(add):]
150 n += len(add)
151 if len(s.filling) < e.o.blockSize {
152 return n, nil
153 }
154 err := e.nextBlock(false)
155 if err != nil {
156 return n, err
157 }
158 if debugAsserts && len(s.filling) > 0 {
159 panic(len(s.filling))
160 }
161 }
162 return n, nil
163}
164
165// nextBlock will synchronize and start compressing input in e.state.filling.
166// If an error has occurred during encoding it will be returned.
167func (e *Encoder) nextBlock(final bool) error {
168 s := &e.state
169 // Wait for current block.
170 s.wg.Wait()
171 if s.err != nil {
172 return s.err
173 }
174 if len(s.filling) > e.o.blockSize {
175 return fmt.Errorf("block > maxStoreBlockSize")
176 }
177 if !s.headerWritten {
178 // If we have a single block encode, do a sync compression.
179 if final && len(s.filling) == 0 && !e.o.fullZero {
180 s.headerWritten = true
181 s.fullFrameWritten = true
182 s.eofWritten = true
183 return nil
184 }
185 if final && len(s.filling) > 0 {
186 s.current = e.EncodeAll(s.filling, s.current[:0])
187 var n2 int
188 n2, s.err = s.w.Write(s.current)
189 if s.err != nil {
190 return s.err
191 }
192 s.nWritten += int64(n2)
193 s.current = s.current[:0]
194 s.filling = s.filling[:0]
195 s.headerWritten = true
196 s.fullFrameWritten = true
197 s.eofWritten = true
198 return nil
199 }
200
201 var tmp [maxHeaderSize]byte
202 fh := frameHeader{
203 ContentSize: 0,
204 WindowSize: uint32(s.encoder.WindowSize(0)),
205 SingleSegment: false,
206 Checksum: e.o.crc,
207 DictID: e.o.dict.ID(),
208 }
209
210 dst, err := fh.appendTo(tmp[:0])
211 if err != nil {
212 return err
213 }
214 s.headerWritten = true
215 s.wWg.Wait()
216 var n2 int
217 n2, s.err = s.w.Write(dst)
218 if s.err != nil {
219 return s.err
220 }
221 s.nWritten += int64(n2)
222 }
223 if s.eofWritten {
224 // Ensure we only write it once.
225 final = false
226 }
227
228 if len(s.filling) == 0 {
229 // Final block, but no data.
230 if final {
231 enc := s.encoder
232 blk := enc.Block()
233 blk.reset(nil)
234 blk.last = true
235 blk.encodeRaw(nil)
236 s.wWg.Wait()
237 _, s.err = s.w.Write(blk.output)
238 s.nWritten += int64(len(blk.output))
239 s.eofWritten = true
240 }
241 return s.err
242 }
243
244 // Move blocks forward.
245 s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
246 s.wg.Add(1)
247 go func(src []byte) {
248 if debug {
249 println("Adding block,", len(src), "bytes, final:", final)
250 }
251 defer func() {
252 if r := recover(); r != nil {
253 s.err = fmt.Errorf("panic while encoding: %v", r)
254 rdebug.PrintStack()
255 }
256 s.wg.Done()
257 }()
258 enc := s.encoder
259 blk := enc.Block()
260 enc.Encode(blk, src)
261 blk.last = final
262 if final {
263 s.eofWritten = true
264 }
265 // Wait for pending writes.
266 s.wWg.Wait()
267 if s.writeErr != nil {
268 s.err = s.writeErr
269 return
270 }
271 // Transfer encoders from previous write block.
272 blk.swapEncoders(s.writing)
273 // Transfer recent offsets to next.
274 enc.UseBlock(s.writing)
275 s.writing = blk
276 s.wWg.Add(1)
277 go func() {
278 defer func() {
279 if r := recover(); r != nil {
280 s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
281 rdebug.PrintStack()
282 }
283 s.wWg.Done()
284 }()
285 err := errIncompressible
286 // If we got the exact same number of literals as input,
287 // assume the literals cannot be compressed.
288 if len(src) != len(blk.literals) || len(src) != e.o.blockSize {
289 err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
290 }
291 switch err {
292 case errIncompressible:
293 if debug {
294 println("Storing incompressible block as raw")
295 }
296 blk.encodeRaw(src)
297 // In fast mode, we do not transfer offsets, so we don't have to deal with changing the.
298 case nil:
299 default:
300 s.writeErr = err
301 return
302 }
303 _, s.writeErr = s.w.Write(blk.output)
304 s.nWritten += int64(len(blk.output))
305 }()
306 }(s.current)
307 return nil
308}
309
310// ReadFrom reads data from r until EOF or error.
311// The return value n is the number of bytes read.
312// Any error except io.EOF encountered during the read is also returned.
313//
314// The Copy function uses ReaderFrom if available.
315func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
316 if debug {
317 println("Using ReadFrom")
318 }
319
320 // Flush any current writes.
321 if len(e.state.filling) > 0 {
322 if err := e.nextBlock(false); err != nil {
323 return 0, err
324 }
325 }
326 e.state.filling = e.state.filling[:e.o.blockSize]
327 src := e.state.filling
328 for {
329 n2, err := r.Read(src)
330 if e.o.crc {
331 _, _ = e.state.encoder.CRC().Write(src[:n2])
332 }
333 // src is now the unfilled part...
334 src = src[n2:]
335 n += int64(n2)
336 switch err {
337 case io.EOF:
338 e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
339 if debug {
340 println("ReadFrom: got EOF final block:", len(e.state.filling))
341 }
342 return n, nil
343 case nil:
344 default:
345 if debug {
346 println("ReadFrom: got error:", err)
347 }
348 e.state.err = err
349 return n, err
350 }
351 if len(src) > 0 {
352 if debug {
353 println("ReadFrom: got space left in source:", len(src))
354 }
355 continue
356 }
357 err = e.nextBlock(false)
358 if err != nil {
359 return n, err
360 }
361 e.state.filling = e.state.filling[:e.o.blockSize]
362 src = e.state.filling
363 }
364}
365
366// Flush will send the currently written data to output
367// and block until everything has been written.
368// This should only be used on rare occasions where pushing the currently queued data is critical.
369func (e *Encoder) Flush() error {
370 s := &e.state
371 if len(s.filling) > 0 {
372 err := e.nextBlock(false)
373 if err != nil {
374 return err
375 }
376 }
377 s.wg.Wait()
378 s.wWg.Wait()
379 if s.err != nil {
380 return s.err
381 }
382 return s.writeErr
383}
384
385// Close will flush the final output and close the stream.
386// The function will block until everything has been written.
387// The Encoder can still be re-used after calling this.
388func (e *Encoder) Close() error {
389 s := &e.state
390 if s.encoder == nil {
391 return nil
392 }
393 err := e.nextBlock(true)
394 if err != nil {
395 return err
396 }
397 if e.state.fullFrameWritten {
398 return s.err
399 }
400 s.wg.Wait()
401 s.wWg.Wait()
402
403 if s.err != nil {
404 return s.err
405 }
406 if s.writeErr != nil {
407 return s.writeErr
408 }
409
410 // Write CRC
411 if e.o.crc && s.err == nil {
412 // heap alloc.
413 var tmp [4]byte
414 _, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
415 s.nWritten += 4
416 }
417
418 // Add padding with content from crypto/rand.Reader
419 if s.err == nil && e.o.pad > 0 {
420 add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
421 frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
422 if err != nil {
423 return err
424 }
425 _, s.err = s.w.Write(frame)
426 }
427 return s.err
428}
429
430// EncodeAll will encode all input in src and append it to dst.
431// This function can be called concurrently, but each call will only run on a single goroutine.
432// If empty input is given, nothing is returned, unless WithZeroFrames is specified.
433// Encoded blocks can be concatenated and the result will be the combined input stream.
434// Data compressed with EncodeAll can be decoded with the Decoder,
435// using either a stream or DecodeAll.
436func (e *Encoder) EncodeAll(src, dst []byte) []byte {
437 if len(src) == 0 {
438 if e.o.fullZero {
439 // Add frame header.
440 fh := frameHeader{
441 ContentSize: 0,
442 WindowSize: MinWindowSize,
443 SingleSegment: true,
444 // Adding a checksum would be a waste of space.
445 Checksum: false,
446 DictID: 0,
447 }
448 dst, _ = fh.appendTo(dst)
449
450 // Write raw block as last one only.
451 var blk blockHeader
452 blk.setSize(0)
453 blk.setType(blockTypeRaw)
454 blk.setLast(true)
455 dst = blk.appendTo(dst)
456 }
457 return dst
458 }
459 e.init.Do(e.initialize)
460 enc := <-e.encoders
461 defer func() {
462 // Release encoder reference to last block.
463 // If a non-single block is needed the encoder will reset again.
464 e.encoders <- enc
465 }()
466 // Use single segments when above minimum window and below 1MB.
467 single := len(src) < 1<<20 && len(src) > MinWindowSize
468 if e.o.single != nil {
469 single = *e.o.single
470 }
471 fh := frameHeader{
472 ContentSize: uint64(len(src)),
473 WindowSize: uint32(enc.WindowSize(len(src))),
474 SingleSegment: single,
475 Checksum: e.o.crc,
476 DictID: e.o.dict.ID(),
477 }
478
479 // If less than 1MB, allocate a buffer up front.
480 if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 && !e.o.lowMem {
481 dst = make([]byte, 0, len(src))
482 }
483 dst, err := fh.appendTo(dst)
484 if err != nil {
485 panic(err)
486 }
487
488 // If we can do everything in one block, prefer that.
489 if len(src) <= maxCompressedBlockSize {
490 enc.Reset(e.o.dict, true)
491 // Slightly faster with no history and everything in one block.
492 if e.o.crc {
493 _, _ = enc.CRC().Write(src)
494 }
495 blk := enc.Block()
496 blk.last = true
497 if e.o.dict == nil {
498 enc.EncodeNoHist(blk, src)
499 } else {
500 enc.Encode(blk, src)
501 }
502
503 // If we got the exact same number of literals as input,
504 // assume the literals cannot be compressed.
505 err := errIncompressible
506 oldout := blk.output
507 if len(blk.literals) != len(src) || len(src) != e.o.blockSize {
508 // Output directly to dst
509 blk.output = dst
510 err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
511 }
512
513 switch err {
514 case errIncompressible:
515 if debug {
516 println("Storing incompressible block as raw")
517 }
518 dst = blk.encodeRawTo(dst, src)
519 case nil:
520 dst = blk.output
521 default:
522 panic(err)
523 }
524 blk.output = oldout
525 } else {
526 enc.Reset(e.o.dict, false)
527 blk := enc.Block()
528 for len(src) > 0 {
529 todo := src
530 if len(todo) > e.o.blockSize {
531 todo = todo[:e.o.blockSize]
532 }
533 src = src[len(todo):]
534 if e.o.crc {
535 _, _ = enc.CRC().Write(todo)
536 }
537 blk.pushOffsets()
538 enc.Encode(blk, todo)
539 if len(src) == 0 {
540 blk.last = true
541 }
542 err := errIncompressible
543 // If we got the exact same number of literals as input,
544 // assume the literals cannot be compressed.
545 if len(blk.literals) != len(todo) || len(todo) != e.o.blockSize {
546 err = blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy)
547 }
548
549 switch err {
550 case errIncompressible:
551 if debug {
552 println("Storing incompressible block as raw")
553 }
554 dst = blk.encodeRawTo(dst, todo)
555 blk.popOffsets()
556 case nil:
557 dst = append(dst, blk.output...)
558 default:
559 panic(err)
560 }
561 blk.reset(nil)
562 }
563 }
564 if e.o.crc {
565 dst = enc.AppendCRC(dst)
566 }
567 // Add padding with content from crypto/rand.Reader
568 if e.o.pad > 0 {
569 add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
570 dst, err = skippableFrame(dst, add, rand.Reader)
571 if err != nil {
572 panic(err)
573 }
574 }
575 return dst
576}