blob: 7aaaedb23e58c639f1f9a7709279c3b77f3cfda2 [file] [log] [blame]
// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8 "crypto/rand"
9 "fmt"
10 "io"
11 rdebug "runtime/debug"
12 "sync"
13
14 "github.com/klauspost/compress/zstd/internal/xxhash"
15)
16
17// Encoder provides encoding to Zstandard.
18// An Encoder can be used for either compressing a stream via the
19// io.WriteCloser interface supported by the Encoder or as multiple independent
20// tasks via the EncodeAll function.
21// Smaller encodes are encouraged to use the EncodeAll function.
22// Use NewWriter to create a new instance.
23type Encoder struct {
24 o encoderOptions
25 encoders chan encoder
26 state encoderState
27 init sync.Once
28}
29
30type encoder interface {
31 Encode(blk *blockEnc, src []byte)
32 EncodeNoHist(blk *blockEnc, src []byte)
33 Block() *blockEnc
34 CRC() *xxhash.Digest
35 AppendCRC([]byte) []byte
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +053036 WindowSize(size int64) int32
khenaidoo7d3c5582021-08-11 18:09:44 -040037 UseBlock(*blockEnc)
38 Reset(d *dict, singleBlock bool)
39}
40
41type encoderState struct {
42 w io.Writer
43 filling []byte
44 current []byte
45 previous []byte
46 encoder encoder
47 writing *blockEnc
48 err error
49 writeErr error
50 nWritten int64
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +053051 nInput int64
52 frameContentSize int64
khenaidoo7d3c5582021-08-11 18:09:44 -040053 headerWritten bool
54 eofWritten bool
55 fullFrameWritten bool
56
57 // This waitgroup indicates an encode is running.
58 wg sync.WaitGroup
59 // This waitgroup indicates we have a block encoding/writing.
60 wWg sync.WaitGroup
61}
62
63// NewWriter will create a new Zstandard encoder.
64// If the encoder will be used for encoding blocks a nil writer can be used.
65func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
66 initPredefined()
67 var e Encoder
68 e.o.setDefault()
69 for _, o := range opts {
70 err := o(&e.o)
71 if err != nil {
72 return nil, err
73 }
74 }
75 if w != nil {
76 e.Reset(w)
77 }
78 return &e, nil
79}
80
81func (e *Encoder) initialize() {
82 if e.o.concurrent == 0 {
83 e.o.setDefault()
84 }
85 e.encoders = make(chan encoder, e.o.concurrent)
86 for i := 0; i < e.o.concurrent; i++ {
87 enc := e.o.encoder()
88 e.encoders <- enc
89 }
90}
91
92// Reset will re-initialize the writer and new writes will encode to the supplied writer
93// as a new, independent stream.
94func (e *Encoder) Reset(w io.Writer) {
95 s := &e.state
96 s.wg.Wait()
97 s.wWg.Wait()
98 if cap(s.filling) == 0 {
99 s.filling = make([]byte, 0, e.o.blockSize)
100 }
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +0530101 if e.o.concurrent > 1 {
102 if cap(s.current) == 0 {
103 s.current = make([]byte, 0, e.o.blockSize)
104 }
105 if cap(s.previous) == 0 {
106 s.previous = make([]byte, 0, e.o.blockSize)
107 }
108 s.current = s.current[:0]
109 s.previous = s.previous[:0]
110 if s.writing == nil {
111 s.writing = &blockEnc{lowMem: e.o.lowMem}
112 s.writing.init()
113 }
114 s.writing.initNewEncode()
khenaidoo7d3c5582021-08-11 18:09:44 -0400115 }
116 if s.encoder == nil {
117 s.encoder = e.o.encoder()
118 }
khenaidoo7d3c5582021-08-11 18:09:44 -0400119 s.filling = s.filling[:0]
khenaidoo7d3c5582021-08-11 18:09:44 -0400120 s.encoder.Reset(e.o.dict, false)
121 s.headerWritten = false
122 s.eofWritten = false
123 s.fullFrameWritten = false
124 s.w = w
125 s.err = nil
126 s.nWritten = 0
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +0530127 s.nInput = 0
khenaidoo7d3c5582021-08-11 18:09:44 -0400128 s.writeErr = nil
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +0530129 s.frameContentSize = 0
130}
131
132// ResetContentSize will reset and set a content size for the next stream.
133// If the bytes written does not match the size given an error will be returned
134// when calling Close().
135// This is removed when Reset is called.
136// Sizes <= 0 results in no content size set.
137func (e *Encoder) ResetContentSize(w io.Writer, size int64) {
138 e.Reset(w)
139 if size >= 0 {
140 e.state.frameContentSize = size
141 }
khenaidoo7d3c5582021-08-11 18:09:44 -0400142}
143
144// Write data to the encoder.
145// Input data will be buffered and as the buffer fills up
146// content will be compressed and written to the output.
147// When done writing, use Close to flush the remaining output
148// and write CRC if requested.
149func (e *Encoder) Write(p []byte) (n int, err error) {
150 s := &e.state
151 for len(p) > 0 {
152 if len(p)+len(s.filling) < e.o.blockSize {
153 if e.o.crc {
154 _, _ = s.encoder.CRC().Write(p)
155 }
156 s.filling = append(s.filling, p...)
157 return n + len(p), nil
158 }
159 add := p
160 if len(p)+len(s.filling) > e.o.blockSize {
161 add = add[:e.o.blockSize-len(s.filling)]
162 }
163 if e.o.crc {
164 _, _ = s.encoder.CRC().Write(add)
165 }
166 s.filling = append(s.filling, add...)
167 p = p[len(add):]
168 n += len(add)
169 if len(s.filling) < e.o.blockSize {
170 return n, nil
171 }
172 err := e.nextBlock(false)
173 if err != nil {
174 return n, err
175 }
176 if debugAsserts && len(s.filling) > 0 {
177 panic(len(s.filling))
178 }
179 }
180 return n, nil
181}
182
183// nextBlock will synchronize and start compressing input in e.state.filling.
184// If an error has occurred during encoding it will be returned.
185func (e *Encoder) nextBlock(final bool) error {
186 s := &e.state
187 // Wait for current block.
188 s.wg.Wait()
189 if s.err != nil {
190 return s.err
191 }
192 if len(s.filling) > e.o.blockSize {
193 return fmt.Errorf("block > maxStoreBlockSize")
194 }
195 if !s.headerWritten {
196 // If we have a single block encode, do a sync compression.
197 if final && len(s.filling) == 0 && !e.o.fullZero {
198 s.headerWritten = true
199 s.fullFrameWritten = true
200 s.eofWritten = true
201 return nil
202 }
203 if final && len(s.filling) > 0 {
204 s.current = e.EncodeAll(s.filling, s.current[:0])
205 var n2 int
206 n2, s.err = s.w.Write(s.current)
207 if s.err != nil {
208 return s.err
209 }
210 s.nWritten += int64(n2)
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +0530211 s.nInput += int64(len(s.filling))
khenaidoo7d3c5582021-08-11 18:09:44 -0400212 s.current = s.current[:0]
213 s.filling = s.filling[:0]
214 s.headerWritten = true
215 s.fullFrameWritten = true
216 s.eofWritten = true
217 return nil
218 }
219
220 var tmp [maxHeaderSize]byte
221 fh := frameHeader{
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +0530222 ContentSize: uint64(s.frameContentSize),
223 WindowSize: uint32(s.encoder.WindowSize(s.frameContentSize)),
khenaidoo7d3c5582021-08-11 18:09:44 -0400224 SingleSegment: false,
225 Checksum: e.o.crc,
226 DictID: e.o.dict.ID(),
227 }
228
229 dst, err := fh.appendTo(tmp[:0])
230 if err != nil {
231 return err
232 }
233 s.headerWritten = true
234 s.wWg.Wait()
235 var n2 int
236 n2, s.err = s.w.Write(dst)
237 if s.err != nil {
238 return s.err
239 }
240 s.nWritten += int64(n2)
241 }
242 if s.eofWritten {
243 // Ensure we only write it once.
244 final = false
245 }
246
247 if len(s.filling) == 0 {
248 // Final block, but no data.
249 if final {
250 enc := s.encoder
251 blk := enc.Block()
252 blk.reset(nil)
253 blk.last = true
254 blk.encodeRaw(nil)
255 s.wWg.Wait()
256 _, s.err = s.w.Write(blk.output)
257 s.nWritten += int64(len(blk.output))
258 s.eofWritten = true
259 }
260 return s.err
261 }
262
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +0530263 // SYNC:
264 if e.o.concurrent == 1 {
265 src := s.filling
266 s.nInput += int64(len(s.filling))
267 if debugEncoder {
268 println("Adding sync block,", len(src), "bytes, final:", final)
269 }
270 enc := s.encoder
271 blk := enc.Block()
272 blk.reset(nil)
273 enc.Encode(blk, src)
274 blk.last = final
275 if final {
276 s.eofWritten = true
277 }
278
279 err := errIncompressible
280 // If we got the exact same number of literals as input,
281 // assume the literals cannot be compressed.
282 if len(src) != len(blk.literals) || len(src) != e.o.blockSize {
283 err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
284 }
285 switch err {
286 case errIncompressible:
287 if debugEncoder {
288 println("Storing incompressible block as raw")
289 }
290 blk.encodeRaw(src)
291 // In fast mode, we do not transfer offsets, so we don't have to deal with changing the.
292 case nil:
293 default:
294 s.err = err
295 return err
296 }
297 _, s.err = s.w.Write(blk.output)
298 s.nWritten += int64(len(blk.output))
299 s.filling = s.filling[:0]
300 return s.err
301 }
302
khenaidoo7d3c5582021-08-11 18:09:44 -0400303 // Move blocks forward.
304 s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +0530305 s.nInput += int64(len(s.current))
khenaidoo7d3c5582021-08-11 18:09:44 -0400306 s.wg.Add(1)
307 go func(src []byte) {
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +0530308 if debugEncoder {
khenaidoo7d3c5582021-08-11 18:09:44 -0400309 println("Adding block,", len(src), "bytes, final:", final)
310 }
311 defer func() {
312 if r := recover(); r != nil {
313 s.err = fmt.Errorf("panic while encoding: %v", r)
314 rdebug.PrintStack()
315 }
316 s.wg.Done()
317 }()
318 enc := s.encoder
319 blk := enc.Block()
320 enc.Encode(blk, src)
321 blk.last = final
322 if final {
323 s.eofWritten = true
324 }
325 // Wait for pending writes.
326 s.wWg.Wait()
327 if s.writeErr != nil {
328 s.err = s.writeErr
329 return
330 }
331 // Transfer encoders from previous write block.
332 blk.swapEncoders(s.writing)
333 // Transfer recent offsets to next.
334 enc.UseBlock(s.writing)
335 s.writing = blk
336 s.wWg.Add(1)
337 go func() {
338 defer func() {
339 if r := recover(); r != nil {
340 s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
341 rdebug.PrintStack()
342 }
343 s.wWg.Done()
344 }()
345 err := errIncompressible
346 // If we got the exact same number of literals as input,
347 // assume the literals cannot be compressed.
348 if len(src) != len(blk.literals) || len(src) != e.o.blockSize {
349 err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
350 }
351 switch err {
352 case errIncompressible:
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +0530353 if debugEncoder {
khenaidoo7d3c5582021-08-11 18:09:44 -0400354 println("Storing incompressible block as raw")
355 }
356 blk.encodeRaw(src)
357 // In fast mode, we do not transfer offsets, so we don't have to deal with changing the.
358 case nil:
359 default:
360 s.writeErr = err
361 return
362 }
363 _, s.writeErr = s.w.Write(blk.output)
364 s.nWritten += int64(len(blk.output))
365 }()
366 }(s.current)
367 return nil
368}
369
370// ReadFrom reads data from r until EOF or error.
371// The return value n is the number of bytes read.
372// Any error except io.EOF encountered during the read is also returned.
373//
374// The Copy function uses ReaderFrom if available.
375func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +0530376 if debugEncoder {
khenaidoo7d3c5582021-08-11 18:09:44 -0400377 println("Using ReadFrom")
378 }
379
380 // Flush any current writes.
381 if len(e.state.filling) > 0 {
382 if err := e.nextBlock(false); err != nil {
383 return 0, err
384 }
385 }
386 e.state.filling = e.state.filling[:e.o.blockSize]
387 src := e.state.filling
388 for {
389 n2, err := r.Read(src)
390 if e.o.crc {
391 _, _ = e.state.encoder.CRC().Write(src[:n2])
392 }
393 // src is now the unfilled part...
394 src = src[n2:]
395 n += int64(n2)
396 switch err {
397 case io.EOF:
398 e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +0530399 if debugEncoder {
khenaidoo7d3c5582021-08-11 18:09:44 -0400400 println("ReadFrom: got EOF final block:", len(e.state.filling))
401 }
402 return n, nil
403 case nil:
404 default:
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +0530405 if debugEncoder {
khenaidoo7d3c5582021-08-11 18:09:44 -0400406 println("ReadFrom: got error:", err)
407 }
408 e.state.err = err
409 return n, err
410 }
411 if len(src) > 0 {
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +0530412 if debugEncoder {
khenaidoo7d3c5582021-08-11 18:09:44 -0400413 println("ReadFrom: got space left in source:", len(src))
414 }
415 continue
416 }
417 err = e.nextBlock(false)
418 if err != nil {
419 return n, err
420 }
421 e.state.filling = e.state.filling[:e.o.blockSize]
422 src = e.state.filling
423 }
424}
425
426// Flush will send the currently written data to output
427// and block until everything has been written.
428// This should only be used on rare occasions where pushing the currently queued data is critical.
429func (e *Encoder) Flush() error {
430 s := &e.state
431 if len(s.filling) > 0 {
432 err := e.nextBlock(false)
433 if err != nil {
434 return err
435 }
436 }
437 s.wg.Wait()
438 s.wWg.Wait()
439 if s.err != nil {
440 return s.err
441 }
442 return s.writeErr
443}
444
445// Close will flush the final output and close the stream.
446// The function will block until everything has been written.
447// The Encoder can still be re-used after calling this.
448func (e *Encoder) Close() error {
449 s := &e.state
450 if s.encoder == nil {
451 return nil
452 }
453 err := e.nextBlock(true)
454 if err != nil {
455 return err
456 }
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +0530457 if s.frameContentSize > 0 {
458 if s.nInput != s.frameContentSize {
459 return fmt.Errorf("frame content size %d given, but %d bytes was written", s.frameContentSize, s.nInput)
460 }
461 }
khenaidoo7d3c5582021-08-11 18:09:44 -0400462 if e.state.fullFrameWritten {
463 return s.err
464 }
465 s.wg.Wait()
466 s.wWg.Wait()
467
468 if s.err != nil {
469 return s.err
470 }
471 if s.writeErr != nil {
472 return s.writeErr
473 }
474
475 // Write CRC
476 if e.o.crc && s.err == nil {
477 // heap alloc.
478 var tmp [4]byte
479 _, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
480 s.nWritten += 4
481 }
482
483 // Add padding with content from crypto/rand.Reader
484 if s.err == nil && e.o.pad > 0 {
485 add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
486 frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
487 if err != nil {
488 return err
489 }
490 _, s.err = s.w.Write(frame)
491 }
492 return s.err
493}
494
495// EncodeAll will encode all input in src and append it to dst.
496// This function can be called concurrently, but each call will only run on a single goroutine.
497// If empty input is given, nothing is returned, unless WithZeroFrames is specified.
498// Encoded blocks can be concatenated and the result will be the combined input stream.
499// Data compressed with EncodeAll can be decoded with the Decoder,
500// using either a stream or DecodeAll.
501func (e *Encoder) EncodeAll(src, dst []byte) []byte {
502 if len(src) == 0 {
503 if e.o.fullZero {
504 // Add frame header.
505 fh := frameHeader{
506 ContentSize: 0,
507 WindowSize: MinWindowSize,
508 SingleSegment: true,
509 // Adding a checksum would be a waste of space.
510 Checksum: false,
511 DictID: 0,
512 }
513 dst, _ = fh.appendTo(dst)
514
515 // Write raw block as last one only.
516 var blk blockHeader
517 blk.setSize(0)
518 blk.setType(blockTypeRaw)
519 blk.setLast(true)
520 dst = blk.appendTo(dst)
521 }
522 return dst
523 }
524 e.init.Do(e.initialize)
525 enc := <-e.encoders
526 defer func() {
527 // Release encoder reference to last block.
528 // If a non-single block is needed the encoder will reset again.
529 e.encoders <- enc
530 }()
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +0530531 // Use single segments when above minimum window and below window size.
532 single := len(src) <= e.o.windowSize && len(src) > MinWindowSize
khenaidoo7d3c5582021-08-11 18:09:44 -0400533 if e.o.single != nil {
534 single = *e.o.single
535 }
536 fh := frameHeader{
537 ContentSize: uint64(len(src)),
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +0530538 WindowSize: uint32(enc.WindowSize(int64(len(src)))),
khenaidoo7d3c5582021-08-11 18:09:44 -0400539 SingleSegment: single,
540 Checksum: e.o.crc,
541 DictID: e.o.dict.ID(),
542 }
543
544 // If less than 1MB, allocate a buffer up front.
545 if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 && !e.o.lowMem {
546 dst = make([]byte, 0, len(src))
547 }
548 dst, err := fh.appendTo(dst)
549 if err != nil {
550 panic(err)
551 }
552
553 // If we can do everything in one block, prefer that.
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +0530554 if len(src) <= e.o.blockSize {
khenaidoo7d3c5582021-08-11 18:09:44 -0400555 enc.Reset(e.o.dict, true)
556 // Slightly faster with no history and everything in one block.
557 if e.o.crc {
558 _, _ = enc.CRC().Write(src)
559 }
560 blk := enc.Block()
561 blk.last = true
562 if e.o.dict == nil {
563 enc.EncodeNoHist(blk, src)
564 } else {
565 enc.Encode(blk, src)
566 }
567
568 // If we got the exact same number of literals as input,
569 // assume the literals cannot be compressed.
570 err := errIncompressible
571 oldout := blk.output
572 if len(blk.literals) != len(src) || len(src) != e.o.blockSize {
573 // Output directly to dst
574 blk.output = dst
575 err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
576 }
577
578 switch err {
579 case errIncompressible:
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +0530580 if debugEncoder {
khenaidoo7d3c5582021-08-11 18:09:44 -0400581 println("Storing incompressible block as raw")
582 }
583 dst = blk.encodeRawTo(dst, src)
584 case nil:
585 dst = blk.output
586 default:
587 panic(err)
588 }
589 blk.output = oldout
590 } else {
591 enc.Reset(e.o.dict, false)
592 blk := enc.Block()
593 for len(src) > 0 {
594 todo := src
595 if len(todo) > e.o.blockSize {
596 todo = todo[:e.o.blockSize]
597 }
598 src = src[len(todo):]
599 if e.o.crc {
600 _, _ = enc.CRC().Write(todo)
601 }
602 blk.pushOffsets()
603 enc.Encode(blk, todo)
604 if len(src) == 0 {
605 blk.last = true
606 }
607 err := errIncompressible
608 // If we got the exact same number of literals as input,
609 // assume the literals cannot be compressed.
610 if len(blk.literals) != len(todo) || len(todo) != e.o.blockSize {
611 err = blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy)
612 }
613
614 switch err {
615 case errIncompressible:
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +0530616 if debugEncoder {
khenaidoo7d3c5582021-08-11 18:09:44 -0400617 println("Storing incompressible block as raw")
618 }
619 dst = blk.encodeRawTo(dst, todo)
620 blk.popOffsets()
621 case nil:
622 dst = append(dst, blk.output...)
623 default:
624 panic(err)
625 }
626 blk.reset(nil)
627 }
628 }
629 if e.o.crc {
630 dst = enc.AppendCRC(dst)
631 }
632 // Add padding with content from crypto/rand.Reader
633 if e.o.pad > 0 {
634 add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
635 dst, err = skippableFrame(dst, add, rand.Reader)
636 if err != nil {
637 panic(err)
638 }
639 }
640 return dst
641}