// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.

package zstd

import (
	"crypto/rand"
	"fmt"
	"io"
	rdebug "runtime/debug"
	"sync"

	"github.com/klauspost/compress/zstd/internal/xxhash"
)

// Encoder provides encoding to Zstandard.
// An Encoder can be used either to compress a stream via the
// io.WriteCloser interface it supports, or for multiple independent
// tasks via the EncodeAll function.
// For smaller encodes, the EncodeAll function is encouraged.
// Use NewWriter to create a new instance.
type Encoder struct {
	o        encoderOptions
	encoders chan encoder
	state    encoderState
	init     sync.Once
}
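
// Editorial sketch (not part of the original file): the doc comment above
// describes two ways to use an Encoder. A caller might pick between them
// roughly like this; the helper name and the size threshold are assumptions
// made purely for illustration, and e is assumed to come from NewWriter.
func exampleCompressEither(e *Encoder, w io.Writer, payload []byte) error {
	if len(payload) < 1<<20 {
		// Small input: EncodeAll compresses the whole buffer as one frame.
		_, err := w.Write(e.EncodeAll(payload, nil))
		return err
	}
	// Larger input: stream through the io.WriteCloser side of the Encoder.
	e.Reset(w)
	if _, err := e.Write(payload); err != nil {
		return err
	}
	return e.Close()
}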

type encoder interface {
	Encode(blk *blockEnc, src []byte)
	EncodeNoHist(blk *blockEnc, src []byte)
	Block() *blockEnc
	CRC() *xxhash.Digest
	AppendCRC([]byte) []byte
	WindowSize(size int64) int32
	UseBlock(*blockEnc)
	Reset(d *dict, singleBlock bool)
}

type encoderState struct {
	w                io.Writer
	filling          []byte
	current          []byte
	previous         []byte
	encoder          encoder
	writing          *blockEnc
	err              error
	writeErr         error
	nWritten         int64
	nInput           int64
	frameContentSize int64
	headerWritten    bool
	eofWritten       bool
	fullFrameWritten bool

	// This waitgroup indicates an encode is running.
	wg sync.WaitGroup
	// This waitgroup indicates we have a block encoding/writing.
	wWg sync.WaitGroup
}
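
// Editorial note (not part of the original file): the two waitgroups above
// form a small pipeline in nextBlock. A goroutine tracked by wg runs match
// finding (enc.Encode) for the block that was just filled, then hands the
// result to a goroutine tracked by wWg that entropy-codes and writes it.
// That way the next block can be filled and matched while the previous one
// is still being written. This is an interpretation of the code below, not
// an authoritative description.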

// NewWriter will create a new Zstandard encoder.
// If the encoder will only be used for encoding blocks, a nil writer can be used.
func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
	initPredefined()
	var e Encoder
	e.o.setDefault()
	for _, o := range opts {
		err := o(&e.o)
		if err != nil {
			return nil, err
		}
	}
	if w != nil {
		e.Reset(w)
	}
	return &e, nil
}
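
// Editorial sketch (not part of the original file): constructing a streaming
// encoder with options. The option constructors (WithEncoderCRC,
// WithEncoderConcurrency) are assumed to come from this package's options
// file; they are not defined in this file.
func exampleNewStreamingWriter(dst io.Writer) (*Encoder, error) {
	// Pass nil instead of dst if only EncodeAll will be used.
	return NewWriter(dst, WithEncoderCRC(true), WithEncoderConcurrency(2))
}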

func (e *Encoder) initialize() {
	if e.o.concurrent == 0 {
		e.o.setDefault()
	}
	e.encoders = make(chan encoder, e.o.concurrent)
	for i := 0; i < e.o.concurrent; i++ {
		enc := e.o.encoder()
		e.encoders <- enc
	}
}

// Reset will re-initialize the writer and new writes will encode to the supplied writer
// as a new, independent stream.
func (e *Encoder) Reset(w io.Writer) {
	s := &e.state
	s.wg.Wait()
	s.wWg.Wait()
	if cap(s.filling) == 0 {
		s.filling = make([]byte, 0, e.o.blockSize)
	}
	if cap(s.current) == 0 {
		s.current = make([]byte, 0, e.o.blockSize)
	}
	if cap(s.previous) == 0 {
		s.previous = make([]byte, 0, e.o.blockSize)
	}
	if s.encoder == nil {
		s.encoder = e.o.encoder()
	}
	if s.writing == nil {
		s.writing = &blockEnc{lowMem: e.o.lowMem}
		s.writing.init()
	}
	s.writing.initNewEncode()
	s.filling = s.filling[:0]
	s.current = s.current[:0]
	s.previous = s.previous[:0]
	s.encoder.Reset(e.o.dict, false)
	s.headerWritten = false
	s.eofWritten = false
	s.fullFrameWritten = false
	s.w = w
	s.err = nil
	s.nWritten = 0
	s.nInput = 0
	s.writeErr = nil
	s.frameContentSize = 0
}

// ResetContentSize will reset and set a content size for the next stream.
// If the number of bytes written does not match the given size, an error
// will be returned when calling Close().
// The content size is cleared when Reset is called.
// Sizes <= 0 result in no content size being set.
func (e *Encoder) ResetContentSize(w io.Writer, size int64) {
	e.Reset(w)
	if size >= 0 {
		e.state.frameContentSize = size
	}
}
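
// Editorial sketch (not part of the original file): declaring the content size
// of a frame up front. Close will report an error if the number of bytes
// written does not match the declared size.
func exampleWriteWithContentSize(e *Encoder, w io.Writer, payload []byte) error {
	e.ResetContentSize(w, int64(len(payload)))
	if _, err := e.Write(payload); err != nil {
		return err
	}
	// Close verifies that exactly len(payload) bytes were written.
	return e.Close()
}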

// Write data to the encoder.
// Input data will be buffered and, as the buffer fills up,
// content will be compressed and written to the output.
// When done writing, use Close to flush the remaining output
// and write CRC if requested.
func (e *Encoder) Write(p []byte) (n int, err error) {
	s := &e.state
	for len(p) > 0 {
		if len(p)+len(s.filling) < e.o.blockSize {
			if e.o.crc {
				_, _ = s.encoder.CRC().Write(p)
			}
			s.filling = append(s.filling, p...)
			return n + len(p), nil
		}
		add := p
		if len(p)+len(s.filling) > e.o.blockSize {
			add = add[:e.o.blockSize-len(s.filling)]
		}
		if e.o.crc {
			_, _ = s.encoder.CRC().Write(add)
		}
		s.filling = append(s.filling, add...)
		p = p[len(add):]
		n += len(add)
		if len(s.filling) < e.o.blockSize {
			return n, nil
		}
		err := e.nextBlock(false)
		if err != nil {
			return n, err
		}
		if debugAsserts && len(s.filling) > 0 {
			panic(len(s.filling))
		}
	}
	return n, nil
}

// nextBlock will synchronize and start compressing input in e.state.filling.
// If an error has occurred during encoding it will be returned.
func (e *Encoder) nextBlock(final bool) error {
	s := &e.state
	// Wait for current block.
	s.wg.Wait()
	if s.err != nil {
		return s.err
	}
	if len(s.filling) > e.o.blockSize {
		return fmt.Errorf("block > maxStoreBlockSize")
	}
	if !s.headerWritten {
		// If we have a single block encode, do a sync compression.
		if final && len(s.filling) == 0 && !e.o.fullZero {
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}
		if final && len(s.filling) > 0 {
			s.current = e.EncodeAll(s.filling, s.current[:0])
			var n2 int
			n2, s.err = s.w.Write(s.current)
			if s.err != nil {
				return s.err
			}
			s.nWritten += int64(n2)
			s.nInput += int64(len(s.filling))
			s.current = s.current[:0]
			s.filling = s.filling[:0]
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}

		var tmp [maxHeaderSize]byte
		fh := frameHeader{
			ContentSize:   uint64(s.frameContentSize),
			WindowSize:    uint32(s.encoder.WindowSize(s.frameContentSize)),
			SingleSegment: false,
			Checksum:      e.o.crc,
			DictID:        e.o.dict.ID(),
		}

		dst, err := fh.appendTo(tmp[:0])
		if err != nil {
			return err
		}
		s.headerWritten = true
		s.wWg.Wait()
		var n2 int
		n2, s.err = s.w.Write(dst)
		if s.err != nil {
			return s.err
		}
		s.nWritten += int64(n2)
	}
	if s.eofWritten {
		// Ensure we only write it once.
		final = false
	}

	if len(s.filling) == 0 {
		// Final block, but no data.
		if final {
			enc := s.encoder
			blk := enc.Block()
			blk.reset(nil)
			blk.last = true
			blk.encodeRaw(nil)
			s.wWg.Wait()
			_, s.err = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
			s.eofWritten = true
		}
		return s.err
	}

	// Move blocks forward.
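	// Editorial note (not part of the original file): s.filling takes over the
	// storage of s.previous, s.current becomes the data to encode now, and the
	// old current is parked in s.previous so its backing array is not reused
	// while the goroutines below may still be reading it.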
	s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
	s.nInput += int64(len(s.current))
	s.wg.Add(1)
	go func(src []byte) {
		if debugEncoder {
			println("Adding block,", len(src), "bytes, final:", final)
		}
		defer func() {
			if r := recover(); r != nil {
				s.err = fmt.Errorf("panic while encoding: %v", r)
				rdebug.PrintStack()
			}
			s.wg.Done()
		}()
		enc := s.encoder
		blk := enc.Block()
		enc.Encode(blk, src)
		blk.last = final
		if final {
			s.eofWritten = true
		}
		// Wait for pending writes.
		s.wWg.Wait()
		if s.writeErr != nil {
			s.err = s.writeErr
			return
		}
		// Transfer encoders from previous write block.
		blk.swapEncoders(s.writing)
		// Transfer recent offsets to next.
		enc.UseBlock(s.writing)
		s.writing = blk
		s.wWg.Add(1)
		go func() {
			defer func() {
				if r := recover(); r != nil {
					s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
					rdebug.PrintStack()
				}
				s.wWg.Done()
			}()
			err := errIncompressible
			// If we got the exact same number of literals as input,
			// assume the literals cannot be compressed.
			if len(src) != len(blk.literals) || len(src) != e.o.blockSize {
				err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
			}
			switch err {
			case errIncompressible:
				if debugEncoder {
					println("Storing incompressible block as raw")
				}
				blk.encodeRaw(src)
				// In fast mode, we do not transfer offsets, so we don't have to deal with changing the offsets.
			case nil:
			default:
				s.writeErr = err
				return
			}
			_, s.writeErr = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
		}()
	}(s.current)
	return nil
}

// ReadFrom reads data from r until EOF or error.
// The return value n is the number of bytes read.
// Any error except io.EOF encountered during the read is also returned.
//
// The Copy function uses ReaderFrom if available.
func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
	if debugEncoder {
		println("Using ReadFrom")
	}

	// Flush any current writes.
	if len(e.state.filling) > 0 {
		if err := e.nextBlock(false); err != nil {
			return 0, err
		}
	}
	e.state.filling = e.state.filling[:e.o.blockSize]
	src := e.state.filling
	for {
		n2, err := r.Read(src)
		if e.o.crc {
			_, _ = e.state.encoder.CRC().Write(src[:n2])
		}
		// src is now the unfilled part...
		src = src[n2:]
		n += int64(n2)
		switch err {
		case io.EOF:
			e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
			if debugEncoder {
				println("ReadFrom: got EOF final block:", len(e.state.filling))
			}
			return n, nil
		case nil:
		default:
			if debugEncoder {
				println("ReadFrom: got error:", err)
			}
			e.state.err = err
			return n, err
		}
		if len(src) > 0 {
			if debugEncoder {
				println("ReadFrom: got space left in source:", len(src))
			}
			continue
		}
		err = e.nextBlock(false)
		if err != nil {
			return n, err
		}
		e.state.filling = e.state.filling[:e.o.blockSize]
		src = e.state.filling
	}
}
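
// Editorial sketch (not part of the original file): because Encoder implements
// io.ReaderFrom (above), io.Copy will call ReadFrom directly when an Encoder
// is the destination, avoiding an intermediate copy buffer.
func exampleCompressFromReader(e *Encoder, dst io.Writer, src io.Reader) error {
	e.Reset(dst)
	if _, err := io.Copy(e, src); err != nil {
		return err
	}
	return e.Close()
}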

// Flush will send the currently written data to output
// and block until everything has been written.
// This should only be used on rare occasions where pushing the currently queued data is critical.
func (e *Encoder) Flush() error {
	s := &e.state
	if len(s.filling) > 0 {
		err := e.nextBlock(false)
		if err != nil {
			return err
		}
	}
	s.wg.Wait()
	s.wWg.Wait()
	if s.err != nil {
		return s.err
	}
	return s.writeErr
}
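
// Editorial sketch (not part of the original file): flushing after each record
// so a consumer on the other end of the stream can make progress before the
// stream is closed. Frequent flushing costs compression ratio, so it is only
// worthwhile when latency matters.
func exampleWriteAndFlush(e *Encoder, record []byte) error {
	if _, err := e.Write(record); err != nil {
		return err
	}
	return e.Flush()
}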

// Close will flush the final output and close the stream.
// The function will block until everything has been written.
// The Encoder can still be re-used after calling this.
func (e *Encoder) Close() error {
	s := &e.state
	if s.encoder == nil {
		return nil
	}
	err := e.nextBlock(true)
	if err != nil {
		return err
	}
	if s.frameContentSize > 0 {
		if s.nInput != s.frameContentSize {
			return fmt.Errorf("frame content size %d given, but %d bytes were written", s.frameContentSize, s.nInput)
		}
	}
	if e.state.fullFrameWritten {
		return s.err
	}
	s.wg.Wait()
	s.wWg.Wait()

	if s.err != nil {
		return s.err
	}
	if s.writeErr != nil {
		return s.writeErr
	}

	// Write CRC
	if e.o.crc && s.err == nil {
		// heap alloc.
		var tmp [4]byte
		_, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
		s.nWritten += 4
	}

	// Add padding with content from crypto/rand.Reader
	if s.err == nil && e.o.pad > 0 {
		add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
		frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
		if err != nil {
			return err
		}
		_, s.err = s.w.Write(frame)
	}
	return s.err
}

// EncodeAll will encode all input in src and append it to dst.
// This function can be called concurrently, but each call will only run on a single goroutine.
// If empty input is given, nothing is returned, unless WithZeroFrames is specified.
// Encoded blocks can be concatenated and the result will be the combined input stream.
// Data compressed with EncodeAll can be decoded with the Decoder,
// using either a stream or DecodeAll.
func (e *Encoder) EncodeAll(src, dst []byte) []byte {
	if len(src) == 0 {
		if e.o.fullZero {
			// Add frame header.
			fh := frameHeader{
				ContentSize:   0,
				WindowSize:    MinWindowSize,
				SingleSegment: true,
				// Adding a checksum would be a waste of space.
				Checksum: false,
				DictID:   0,
			}
			dst, _ = fh.appendTo(dst)

			// Write raw block as last one only.
			var blk blockHeader
			blk.setSize(0)
			blk.setType(blockTypeRaw)
			blk.setLast(true)
			dst = blk.appendTo(dst)
		}
		return dst
	}
	e.init.Do(e.initialize)
	enc := <-e.encoders
	defer func() {
		// Release encoder reference to last block.
		// If a non-single block is needed the encoder will reset again.
		e.encoders <- enc
	}()
	// Use single segments when above minimum window and below 1MB.
	single := len(src) < 1<<20 && len(src) > MinWindowSize
	if e.o.single != nil {
		single = *e.o.single
	}
	fh := frameHeader{
		ContentSize:   uint64(len(src)),
		WindowSize:    uint32(enc.WindowSize(int64(len(src)))),
		SingleSegment: single,
		Checksum:      e.o.crc,
		DictID:        e.o.dict.ID(),
	}

	// If less than 1MB, allocate a buffer up front.
	if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 && !e.o.lowMem {
		dst = make([]byte, 0, len(src))
	}
	dst, err := fh.appendTo(dst)
	if err != nil {
		panic(err)
	}

	// If we can do everything in one block, prefer that.
	if len(src) <= maxCompressedBlockSize {
		enc.Reset(e.o.dict, true)
		// Slightly faster with no history and everything in one block.
		if e.o.crc {
			_, _ = enc.CRC().Write(src)
		}
		blk := enc.Block()
		blk.last = true
		if e.o.dict == nil {
			enc.EncodeNoHist(blk, src)
		} else {
			enc.Encode(blk, src)
		}

		// If we got the exact same number of literals as input,
		// assume the literals cannot be compressed.
		err := errIncompressible
		oldout := blk.output
		if len(blk.literals) != len(src) || len(src) != e.o.blockSize {
			// Output directly to dst
			blk.output = dst
			err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
		}

		switch err {
		case errIncompressible:
			if debugEncoder {
				println("Storing incompressible block as raw")
			}
			dst = blk.encodeRawTo(dst, src)
		case nil:
			dst = blk.output
		default:
			panic(err)
		}
		blk.output = oldout
	} else {
		enc.Reset(e.o.dict, false)
		blk := enc.Block()
		for len(src) > 0 {
			todo := src
			if len(todo) > e.o.blockSize {
				todo = todo[:e.o.blockSize]
			}
			src = src[len(todo):]
			if e.o.crc {
				_, _ = enc.CRC().Write(todo)
			}
			blk.pushOffsets()
			enc.Encode(blk, todo)
			if len(src) == 0 {
				blk.last = true
			}
			err := errIncompressible
			// If we got the exact same number of literals as input,
			// assume the literals cannot be compressed.
			if len(blk.literals) != len(todo) || len(todo) != e.o.blockSize {
				err = blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy)
			}

			switch err {
			case errIncompressible:
				if debugEncoder {
					println("Storing incompressible block as raw")
				}
				dst = blk.encodeRawTo(dst, todo)
				blk.popOffsets()
			case nil:
				dst = append(dst, blk.output...)
			default:
				panic(err)
			}
			blk.reset(nil)
		}
	}
	if e.o.crc {
		dst = enc.AppendCRC(dst)
	}
	// Add padding with content from crypto/rand.Reader
	if e.o.pad > 0 {
		add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
		dst, err = skippableFrame(dst, add, rand.Reader)
		if err != nil {
			panic(err)
		}
	}
	return dst
}
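
// Editorial sketch (not part of the original file): amortizing allocations by
// handing EncodeAll a reused destination buffer, which the dst parameter above
// allows. The copy is needed because buf is overwritten on the next call.
func exampleEncodeAllReuse(e *Encoder, payloads [][]byte) [][]byte {
	out := make([][]byte, 0, len(payloads))
	var buf []byte
	for _, p := range payloads {
		buf = e.EncodeAll(p, buf[:0])
		out = append(out, append([]byte(nil), buf...))
	}
	return out
}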