// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.

package zstd

import (
	"crypto/rand"
	"fmt"
	"io"
	rdebug "runtime/debug"
	"sync"

	"github.com/klauspost/compress/zstd/internal/xxhash"
)

// Encoder provides encoding to Zstandard.
// An Encoder can be used either for compressing a stream via the
// io.WriteCloser interface it supports, or for multiple independent
// tasks via the EncodeAll function.
// For smaller encodes, the EncodeAll function is encouraged.
// Use NewWriter to create a new instance.
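//
// A minimal streaming sketch (out is an assumed io.Writer, data an assumed
// []byte; error handling is abbreviated):
//
//	enc, err := NewWriter(out)
//	if err != nil {
//		return err
//	}
//	if _, err = enc.Write(data); err != nil {
//		enc.Close()
//		return err
//	}
//	return enc.Close()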
type Encoder struct {
	o        encoderOptions
	encoders chan encoder
	state    encoderState
	init     sync.Once
}

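// encoder is the interface implemented by the block encoders for the
// different compression levels. It compresses one block at a time and
// carries the stream CRC and history state between blocks.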
type encoder interface {
	Encode(blk *blockEnc, src []byte)
	EncodeNoHist(blk *blockEnc, src []byte)
	Block() *blockEnc
	CRC() *xxhash.Digest
	AppendCRC([]byte) []byte
	WindowSize(size int) int32
	UseBlock(*blockEnc)
	Reset()
}

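// encoderState holds the state of a single streaming encode: the destination
// writer, the input buffers being filled, encoded and recycled, and the
// synchronization for the encode and write goroutines.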
type encoderState struct {
	w             io.Writer
	filling       []byte
	current       []byte
	previous      []byte
	encoder       encoder
	writing       *blockEnc
	err           error
	writeErr      error
	nWritten      int64
	headerWritten bool
	eofWritten    bool

	// This waitgroup indicates an encode is running.
	wg sync.WaitGroup
	// This waitgroup indicates we have a block encoding/writing.
	wWg sync.WaitGroup
}

// NewWriter will create a new Zstandard encoder.
// If the encoder will only be used for encoding blocks, a nil writer can be used.
func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
	initPredefined()
	var e Encoder
	e.o.setDefault()
	for _, o := range opts {
		err := o(&e.o)
		if err != nil {
			return nil, err
		}
	}
	if w != nil {
		e.Reset(w)
	} else {
		e.init.Do(func() {
			e.initialize()
		})
	}
	return &e, nil
}

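// initialize creates the pool of block encoders handed out to EncodeAll.
// It is only executed once, guarded by e.init.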
func (e *Encoder) initialize() {
	e.encoders = make(chan encoder, e.o.concurrent)
	for i := 0; i < e.o.concurrent; i++ {
		e.encoders <- e.o.encoder()
	}
}

// Reset will re-initialize the writer, and new writes will encode to the supplied writer
// as a new, independent stream.
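//
// A sketch of reusing one Encoder for several independent outputs
// (w1 and w2 are assumed io.Writers):
//
//	enc.Reset(w1)
//	// ... Write and Close ...
//	enc.Reset(w2)
//	// ... Write and Close ...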
func (e *Encoder) Reset(w io.Writer) {
	e.init.Do(func() {
		e.initialize()
	})
	s := &e.state
	s.wg.Wait()
	s.wWg.Wait()
	if cap(s.filling) == 0 {
		s.filling = make([]byte, 0, e.o.blockSize)
	}
	if cap(s.current) == 0 {
		s.current = make([]byte, 0, e.o.blockSize)
	}
	if cap(s.previous) == 0 {
		s.previous = make([]byte, 0, e.o.blockSize)
	}
	if s.encoder == nil {
		s.encoder = e.o.encoder()
	}
	if s.writing == nil {
		s.writing = &blockEnc{}
		s.writing.init()
	}
	s.writing.initNewEncode()
	s.filling = s.filling[:0]
	s.current = s.current[:0]
	s.previous = s.previous[:0]
	s.encoder.Reset()
	s.headerWritten = false
	s.eofWritten = false
	s.w = w
	s.err = nil
	s.nWritten = 0
	s.writeErr = nil
}

// Write data to the encoder.
// Input data will be buffered, and as the buffer fills up,
// content will be compressed and written to the output.
// When done writing, use Close to flush the remaining output
// and write the CRC if requested.
func (e *Encoder) Write(p []byte) (n int, err error) {
	s := &e.state
	for len(p) > 0 {
		if len(p)+len(s.filling) < e.o.blockSize {
			if e.o.crc {
				_, _ = s.encoder.CRC().Write(p)
			}
			s.filling = append(s.filling, p...)
			return n + len(p), nil
		}
		add := p
		if len(p)+len(s.filling) > e.o.blockSize {
			add = add[:e.o.blockSize-len(s.filling)]
		}
		if e.o.crc {
			_, _ = s.encoder.CRC().Write(add)
		}
		s.filling = append(s.filling, add...)
		p = p[len(add):]
		n += len(add)
		if len(s.filling) < e.o.blockSize {
			return n, nil
		}
		err := e.nextBlock(false)
		if err != nil {
			return n, err
		}
		if debug && len(s.filling) > 0 {
			panic(len(s.filling))
		}
	}
	return n, nil
}

// nextBlock will synchronize and start compressing input in e.state.filling.
// If an error has occurred during encoding, it will be returned.
func (e *Encoder) nextBlock(final bool) error {
	s := &e.state
	// Wait for current block.
	s.wg.Wait()
	if s.err != nil {
		return s.err
	}
	if len(s.filling) > e.o.blockSize {
		return fmt.Errorf("block > maxStoreBlockSize")
	}
	if !s.headerWritten {
		var tmp [maxHeaderSize]byte
		fh := frameHeader{
			ContentSize:   0,
			WindowSize:    uint32(s.encoder.WindowSize(0)),
			SingleSegment: false,
			Checksum:      e.o.crc,
			DictID:        0,
		}
		dst, err := fh.appendTo(tmp[:0])
		if err != nil {
			return err
		}
		s.headerWritten = true
		s.wWg.Wait()
		var n2 int
		n2, s.err = s.w.Write(dst)
		if s.err != nil {
			return s.err
		}
		s.nWritten += int64(n2)
	}
	if s.eofWritten {
		// Ensure we only write it once.
		final = false
	}

	if len(s.filling) == 0 {
		// Final block, but no data.
		if final {
			enc := s.encoder
			blk := enc.Block()
			blk.reset(nil)
			blk.last = true
			blk.encodeRaw(nil)
			s.wWg.Wait()
			_, s.err = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
			s.eofWritten = true
		}
		return s.err
	}

	// Move blocks forward.
	s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
	s.wg.Add(1)
	go func(src []byte) {
		if debug {
			println("Adding block,", len(src), "bytes, final:", final)
		}
		defer func() {
			if r := recover(); r != nil {
				s.err = fmt.Errorf("panic while encoding: %v", r)
				rdebug.PrintStack()
			}
			s.wg.Done()
		}()
		enc := s.encoder
		blk := enc.Block()
		enc.Encode(blk, src)
		blk.last = final
		if final {
			s.eofWritten = true
		}
		// Wait for pending writes.
		s.wWg.Wait()
		if s.writeErr != nil {
			s.err = s.writeErr
			return
		}
		// Transfer encoders from previous write block.
		blk.swapEncoders(s.writing)
		// Transfer recent offsets to next.
		enc.UseBlock(s.writing)
		s.writing = blk
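		// Write the output on a separate goroutine so the next block can
		// start encoding while this one is being written out.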
		s.wWg.Add(1)
		go func() {
			defer func() {
				if r := recover(); r != nil {
					s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
					rdebug.PrintStack()
				}
				s.wWg.Done()
			}()
			err := errIncompressible
			// If we got the exact same number of literals as input,
			// assume the literals cannot be compressed.
			if len(src) != len(blk.literals) || len(src) != e.o.blockSize {
				err = blk.encode(e.o.noEntropy)
			}
			switch err {
			case errIncompressible:
				if debug {
					println("Storing incompressible block as raw")
				}
				blk.encodeRaw(src)
				// In fast mode, we do not transfer offsets, so we don't have
				// to deal with restoring them here.
			case nil:
			default:
				s.writeErr = err
				return
			}
			_, s.writeErr = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
		}()
	}(s.current)
	return nil
}

// ReadFrom reads data from r until EOF or error.
// The return value n is the number of bytes read.
// Any error except io.EOF encountered during the read is also returned.
//
// The io.Copy function uses ReaderFrom if available.
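//
// A minimal sketch (enc is an assumed *Encoder, src an assumed io.Reader);
// io.Copy picks up this ReadFrom automatically:
//
//	if _, err := io.Copy(enc, src); err != nil {
//		return err
//	}
//	return enc.Close()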
func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
	if debug {
		println("Using ReadFrom")
	}
	// Maybe handle stuff queued?
	e.state.filling = e.state.filling[:e.o.blockSize]
	src := e.state.filling
	for {
		n2, err := r.Read(src)
		_, _ = e.state.encoder.CRC().Write(src[:n2])
		// src is now the unfilled part...
		src = src[n2:]
		n += int64(n2)
		switch err {
		case io.EOF:
			e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
			if debug {
				println("ReadFrom: got EOF final block:", len(e.state.filling))
			}
			return n, e.nextBlock(true)
		default:
			if debug {
				println("ReadFrom: got error:", err)
			}
			e.state.err = err
			return n, err
		case nil:
		}
		if len(src) > 0 {
			if debug {
				println("ReadFrom: got space left in source:", len(src))
			}
			continue
		}
		err = e.nextBlock(false)
		if err != nil {
			return n, err
		}
		e.state.filling = e.state.filling[:e.o.blockSize]
		src = e.state.filling
	}
}

// Flush will send the currently buffered data to the output
// and block until everything has been written.
// This should only be used on rare occasions where pushing the currently queued data is critical.
func (e *Encoder) Flush() error {
	s := &e.state
	if len(s.filling) > 0 {
		err := e.nextBlock(false)
		if err != nil {
			return err
		}
	}
	s.wg.Wait()
	s.wWg.Wait()
	if s.err != nil {
		return s.err
	}
	return s.writeErr
}

// Close will flush the final output and close the stream.
// The function will block until everything has been written.
// The Encoder can still be re-used after calling this.
func (e *Encoder) Close() error {
	s := &e.state
	if s.encoder == nil {
		return nil
	}
	err := e.nextBlock(true)
	if err != nil {
		return err
	}
	s.wg.Wait()
	s.wWg.Wait()

	if s.err != nil {
		return s.err
	}
	if s.writeErr != nil {
		return s.writeErr
	}

	// Write CRC
	if e.o.crc && s.err == nil {
		// heap alloc.
		var tmp [4]byte
		_, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
		s.nWritten += 4
	}

	// Add padding with content from crypto/rand.Reader
	if s.err == nil && e.o.pad > 0 {
		add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
		frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
		if err != nil {
			return err
		}
		_, s.err = s.w.Write(frame)
	}
	return s.err
}

// EncodeAll will encode all input in src and append it to dst.
// This function can be called concurrently, but each call will only run on a single goroutine.
// If empty input is given, nothing is returned, unless WithZeroFrames is specified.
// Encoded blocks can be concatenated and the result will be the combined input stream.
// Data compressed with EncodeAll can be decoded with the Decoder,
// using either a stream or DecodeAll.
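//
// A minimal sketch (enc is an assumed *Encoder, e.g. created with
// NewWriter(nil); input and buf are assumed []byte values):
//
//	compressed := enc.EncodeAll(input, nil)
//	// Or append to an existing buffer, reusing its capacity:
//	buf = enc.EncodeAll(input, buf[:0])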
func (e *Encoder) EncodeAll(src, dst []byte) []byte {
	if len(src) == 0 {
		if e.o.fullZero {
			// Add frame header.
			fh := frameHeader{
				ContentSize:   0,
				WindowSize:    MinWindowSize,
				SingleSegment: true,
				// Adding a checksum would be a waste of space.
				Checksum: false,
				DictID:   0,
			}
			dst, _ = fh.appendTo(dst)

			// Write raw block as last one only.
			var blk blockHeader
			blk.setSize(0)
			blk.setType(blockTypeRaw)
			blk.setLast(true)
			dst = blk.appendTo(dst)
		}
		return dst
	}
	e.init.Do(func() {
		e.o.setDefault()
		e.initialize()
	})
	enc := <-e.encoders
	defer func() {
		// Release encoder reference to last block.
		enc.Reset()
		e.encoders <- enc
	}()
	enc.Reset()
	blk := enc.Block()
	// Use single segments when above minimum window and below 1MB.
	single := len(src) < 1<<20 && len(src) > MinWindowSize
	if e.o.single != nil {
		single = *e.o.single
	}
	fh := frameHeader{
		ContentSize:   uint64(len(src)),
		WindowSize:    uint32(enc.WindowSize(len(src))),
		SingleSegment: single,
		Checksum:      e.o.crc,
		DictID:        0,
	}

	// If less than 1MB, allocate a buffer up front.
	if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 {
		dst = make([]byte, 0, len(src))
	}
	dst, err := fh.appendTo(dst)
	if err != nil {
		panic(err)
	}

	if len(src) <= e.o.blockSize && len(src) <= maxBlockSize {
		// Slightly faster with no history and everything in one block.
		if e.o.crc {
			_, _ = enc.CRC().Write(src)
		}
		blk.reset(nil)
		blk.last = true
		enc.EncodeNoHist(blk, src)

		// If we got the exact same number of literals as input,
		// assume the literals cannot be compressed.
		err := errIncompressible
		oldout := blk.output
		if len(blk.literals) != len(src) || len(src) != e.o.blockSize {
			// Output directly to dst
			blk.output = dst
			err = blk.encode(e.o.noEntropy)
		}

		switch err {
		case errIncompressible:
			if debug {
				println("Storing incompressible block as raw")
			}
			dst = blk.encodeRawTo(dst, src)
		case nil:
			dst = blk.output
		default:
			panic(err)
		}
		blk.output = oldout
	} else {
		for len(src) > 0 {
			todo := src
			if len(todo) > e.o.blockSize {
				todo = todo[:e.o.blockSize]
			}
			src = src[len(todo):]
			if e.o.crc {
				_, _ = enc.CRC().Write(todo)
			}
			blk.reset(nil)
			blk.pushOffsets()
			enc.Encode(blk, todo)
			if len(src) == 0 {
				blk.last = true
			}
			err := errIncompressible
			// If we got the exact same number of literals as input,
			// assume the literals cannot be compressed.
			if len(blk.literals) != len(todo) || len(todo) != e.o.blockSize {
				err = blk.encode(e.o.noEntropy)
			}

			switch err {
			case errIncompressible:
				if debug {
					println("Storing incompressible block as raw")
				}
				dst = blk.encodeRawTo(dst, todo)
				blk.popOffsets()
			case nil:
				dst = append(dst, blk.output...)
			default:
				panic(err)
			}
		}
	}
	if e.o.crc {
		dst = enc.AppendCRC(dst)
	}
	// Add padding with content from crypto/rand.Reader
	if e.o.pad > 0 {
		add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
		dst, err = skippableFrame(dst, add, rand.Reader)
		if err != nil {
			panic(err)
		}
	}
	return dst
}