blob: a855655c4222e0ac35588375332bba2c1622c6e3 [file] [log] [blame]
// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8 "crypto/rand"
9 "fmt"
10 "io"
11 rdebug "runtime/debug"
12 "sync"
13
14 "github.com/klauspost/compress/zstd/internal/xxhash"
15)
16
17// Encoder provides encoding to Zstandard.
18// An Encoder can be used for either compressing a stream via the
19// io.WriteCloser interface supported by the Encoder or as multiple independent
20// tasks via the EncodeAll function.
21// Smaller encodes are encouraged to use the EncodeAll function.
22// Use NewWriter to create a new instance.
23type Encoder struct {
24 o encoderOptions
25 encoders chan encoder
26 state encoderState
27 init sync.Once
28}
29
30type encoder interface {
31 Encode(blk *blockEnc, src []byte)
32 Block() *blockEnc
33 CRC() *xxhash.Digest
34 AppendCRC([]byte) []byte
35 WindowSize(size int) int32
36 UseBlock(*blockEnc)
37 Reset()
38}
39
40type encoderState struct {
41 w io.Writer
42 filling []byte
43 current []byte
44 previous []byte
45 encoder encoder
46 writing *blockEnc
47 err error
48 writeErr error
49 nWritten int64
50 headerWritten bool
51 eofWritten bool
52
53 // This waitgroup indicates an encode is running.
54 wg sync.WaitGroup
55 // This waitgroup indicates we have a block encoding/writing.
56 wWg sync.WaitGroup
57}
58
59// NewWriter will create a new Zstandard encoder.
60// If the encoder will be used for encoding blocks a nil writer can be used.
61func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
62 initPredefined()
63 var e Encoder
64 e.o.setDefault()
65 for _, o := range opts {
66 err := o(&e.o)
67 if err != nil {
68 return nil, err
69 }
70 }
71 if w != nil {
72 e.Reset(w)
73 } else {
74 e.init.Do(func() {
75 e.initialize()
76 })
77 }
78 return &e, nil
79}
80
81func (e *Encoder) initialize() {
82 e.encoders = make(chan encoder, e.o.concurrent)
83 for i := 0; i < e.o.concurrent; i++ {
84 e.encoders <- e.o.encoder()
85 }
86}
87
88// Reset will re-initialize the writer and new writes will encode to the supplied writer
89// as a new, independent stream.
90func (e *Encoder) Reset(w io.Writer) {
91 e.init.Do(func() {
92 e.initialize()
93 })
94 s := &e.state
95 s.wg.Wait()
96 s.wWg.Wait()
97 if cap(s.filling) == 0 {
98 s.filling = make([]byte, 0, e.o.blockSize)
99 }
100 if cap(s.current) == 0 {
101 s.current = make([]byte, 0, e.o.blockSize)
102 }
103 if cap(s.previous) == 0 {
104 s.previous = make([]byte, 0, e.o.blockSize)
105 }
106 if s.encoder == nil {
107 s.encoder = e.o.encoder()
108 }
109 if s.writing == nil {
110 s.writing = &blockEnc{}
111 s.writing.init()
112 }
113 s.writing.initNewEncode()
114 s.filling = s.filling[:0]
115 s.current = s.current[:0]
116 s.previous = s.previous[:0]
117 s.encoder.Reset()
118 s.headerWritten = false
119 s.eofWritten = false
120 s.w = w
121 s.err = nil
122 s.nWritten = 0
123 s.writeErr = nil
124}
125
126// Write data to the encoder.
127// Input data will be buffered and as the buffer fills up
128// content will be compressed and written to the output.
129// When done writing, use Close to flush the remaining output
130// and write CRC if requested.
131func (e *Encoder) Write(p []byte) (n int, err error) {
132 s := &e.state
133 for len(p) > 0 {
134 if len(p)+len(s.filling) < e.o.blockSize {
135 if e.o.crc {
136 _, _ = s.encoder.CRC().Write(p)
137 }
138 s.filling = append(s.filling, p...)
139 return n + len(p), nil
140 }
141 add := p
142 if len(p)+len(s.filling) > e.o.blockSize {
143 add = add[:e.o.blockSize-len(s.filling)]
144 }
145 if e.o.crc {
146 _, _ = s.encoder.CRC().Write(add)
147 }
148 s.filling = append(s.filling, add...)
149 p = p[len(add):]
150 n += len(add)
151 if len(s.filling) < e.o.blockSize {
152 return n, nil
153 }
154 err := e.nextBlock(false)
155 if err != nil {
156 return n, err
157 }
158 if debug && len(s.filling) > 0 {
159 panic(len(s.filling))
160 }
161 }
162 return n, nil
163}
164
165// nextBlock will synchronize and start compressing input in e.state.filling.
166// If an error has occurred during encoding it will be returned.
167func (e *Encoder) nextBlock(final bool) error {
168 s := &e.state
169 // Wait for current block.
170 s.wg.Wait()
171 if s.err != nil {
172 return s.err
173 }
174 if len(s.filling) > e.o.blockSize {
175 return fmt.Errorf("block > maxStoreBlockSize")
176 }
177 if !s.headerWritten {
178 var tmp [maxHeaderSize]byte
179 fh := frameHeader{
180 ContentSize: 0,
181 WindowSize: uint32(s.encoder.WindowSize(0)),
182 SingleSegment: false,
183 Checksum: e.o.crc,
184 DictID: 0,
185 }
186 dst, err := fh.appendTo(tmp[:0])
187 if err != nil {
188 return err
189 }
190 s.headerWritten = true
191 s.wWg.Wait()
192 var n2 int
193 n2, s.err = s.w.Write(dst)
194 if s.err != nil {
195 return s.err
196 }
197 s.nWritten += int64(n2)
198 }
199 if s.eofWritten {
200 // Ensure we only write it once.
201 final = false
202 }
203
204 if len(s.filling) == 0 {
205 // Final block, but no data.
206 if final {
207 enc := s.encoder
208 blk := enc.Block()
209 blk.reset(nil)
210 blk.last = true
211 blk.encodeRaw(nil)
212 s.wWg.Wait()
213 _, s.err = s.w.Write(blk.output)
214 s.nWritten += int64(len(blk.output))
215 s.eofWritten = true
216 }
217 return s.err
218 }
219
220 // Move blocks forward.
221 s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
222 s.wg.Add(1)
223 go func(src []byte) {
224 if debug {
225 println("Adding block,", len(src), "bytes, final:", final)
226 }
227 defer func() {
228 if r := recover(); r != nil {
229 s.err = fmt.Errorf("panic while encoding: %v", r)
230 rdebug.PrintStack()
231 }
232 s.wg.Done()
233 }()
234 enc := s.encoder
235 blk := enc.Block()
236 enc.Encode(blk, src)
237 blk.last = final
238 if final {
239 s.eofWritten = true
240 }
241 // Wait for pending writes.
242 s.wWg.Wait()
243 if s.writeErr != nil {
244 s.err = s.writeErr
245 return
246 }
247 // Transfer encoders from previous write block.
248 blk.swapEncoders(s.writing)
249 // Transfer recent offsets to next.
250 enc.UseBlock(s.writing)
251 s.writing = blk
252 s.wWg.Add(1)
253 go func() {
254 defer func() {
255 if r := recover(); r != nil {
256 s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
257 rdebug.PrintStack()
258 }
259 s.wWg.Done()
260 }()
261 err := errIncompressible
262 // If we got the exact same number of literals as input,
263 // assume the literals cannot be compressed.
264 if len(src) != len(blk.literals) || len(src) != e.o.blockSize {
265 err = blk.encode()
266 }
267 switch err {
268 case errIncompressible:
269 if debug {
270 println("Storing incompressible block as raw")
271 }
272 blk.encodeRaw(src)
273 // In fast mode, we do not transfer offsets, so we don't have to deal with changing the.
274 case nil:
275 default:
276 s.writeErr = err
277 return
278 }
279 _, s.writeErr = s.w.Write(blk.output)
280 s.nWritten += int64(len(blk.output))
281 }()
282 }(s.current)
283 return nil
284}
285
286// ReadFrom reads data from r until EOF or error.
287// The return value n is the number of bytes read.
288// Any error except io.EOF encountered during the read is also returned.
289//
290// The Copy function uses ReaderFrom if available.
291func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
292 if debug {
293 println("Using ReadFrom")
294 }
295 // Maybe handle stuff queued?
296 e.state.filling = e.state.filling[:e.o.blockSize]
297 src := e.state.filling
298 for {
299 n2, err := r.Read(src)
300 _, _ = e.state.encoder.CRC().Write(src[:n2])
301 // src is now the unfilled part...
302 src = src[n2:]
303 n += int64(n2)
304 switch err {
305 case io.EOF:
306 e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
307 if debug {
308 println("ReadFrom: got EOF final block:", len(e.state.filling))
309 }
310 return n, e.nextBlock(true)
311 default:
312 if debug {
313 println("ReadFrom: got error:", err)
314 }
315 e.state.err = err
316 return n, err
317 case nil:
318 }
319 if len(src) > 0 {
320 if debug {
321 println("ReadFrom: got space left in source:", len(src))
322 }
323 continue
324 }
325 err = e.nextBlock(false)
326 if err != nil {
327 return n, err
328 }
329 e.state.filling = e.state.filling[:e.o.blockSize]
330 src = e.state.filling
331 }
332}
333
334// Flush will send the currently written data to output
335// and block until everything has been written.
336// This should only be used on rare occasions where pushing the currently queued data is critical.
337func (e *Encoder) Flush() error {
338 s := &e.state
339 if len(s.filling) > 0 {
340 err := e.nextBlock(false)
341 if err != nil {
342 return err
343 }
344 }
345 s.wg.Wait()
346 s.wWg.Wait()
347 if s.err != nil {
348 return s.err
349 }
350 return s.writeErr
351}
352
353// Close will flush the final output and close the stream.
354// The function will block until everything has been written.
355// The Encoder can still be re-used after calling this.
356func (e *Encoder) Close() error {
357 s := &e.state
358 if s.encoder == nil {
359 return nil
360 }
361 err := e.nextBlock(true)
362 if err != nil {
363 return err
364 }
365 s.wg.Wait()
366 s.wWg.Wait()
367
368 if s.err != nil {
369 return s.err
370 }
371 if s.writeErr != nil {
372 return s.writeErr
373 }
374
375 // Write CRC
376 if e.o.crc && s.err == nil {
377 // heap alloc.
378 var tmp [4]byte
379 _, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
380 s.nWritten += 4
381 }
382
383 // Add padding with content from crypto/rand.Reader
384 if s.err == nil && e.o.pad > 0 {
385 add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
386 frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
387 if err != nil {
388 return err
389 }
390 _, s.err = s.w.Write(frame)
391 }
392 return s.err
393}
394
395// EncodeAll will encode all input in src and append it to dst.
396// This function can be called concurrently, but each call will only run on a single goroutine.
397// If empty input is given, nothing is returned, unless WithZeroFrames is specified.
398// Encoded blocks can be concatenated and the result will be the combined input stream.
399// Data compressed with EncodeAll can be decoded with the Decoder,
400// using either a stream or DecodeAll.
401func (e *Encoder) EncodeAll(src, dst []byte) []byte {
402 if len(src) == 0 {
403 if e.o.fullZero {
404 // Add frame header.
405 fh := frameHeader{
406 ContentSize: 0,
407 WindowSize: minWindowSize,
408 SingleSegment: true,
409 // Adding a checksum would be a waste of space.
410 Checksum: false,
411 DictID: 0,
412 }
413 dst, _ = fh.appendTo(dst)
414
415 // Write raw block as last one only.
416 var blk blockHeader
417 blk.setSize(0)
418 blk.setType(blockTypeRaw)
419 blk.setLast(true)
420 dst = blk.appendTo(dst)
421 }
422 return dst
423 }
424 e.init.Do(func() {
425 e.o.setDefault()
426 e.initialize()
427 })
428 enc := <-e.encoders
429 defer func() {
430 // Release encoder reference to last block.
431 enc.Reset()
432 e.encoders <- enc
433 }()
434 enc.Reset()
435 blk := enc.Block()
436 single := len(src) > 1<<20
437 if e.o.single != nil {
438 single = *e.o.single
439 }
440 fh := frameHeader{
441 ContentSize: uint64(len(src)),
442 WindowSize: uint32(enc.WindowSize(len(src))),
443 SingleSegment: single,
444 Checksum: e.o.crc,
445 DictID: 0,
446 }
447
448 // If less than 1MB, allocate a buffer up front.
449 if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 {
450 dst = make([]byte, 0, len(src))
451 }
452 dst, err := fh.appendTo(dst)
453 if err != nil {
454 panic(err)
455 }
456
457 for len(src) > 0 {
458 todo := src
459 if len(todo) > e.o.blockSize {
460 todo = todo[:e.o.blockSize]
461 }
462 src = src[len(todo):]
463 if e.o.crc {
464 _, _ = enc.CRC().Write(todo)
465 }
466 blk.reset(nil)
467 blk.pushOffsets()
468 enc.Encode(blk, todo)
469 if len(src) == 0 {
470 blk.last = true
471 }
472 err := errIncompressible
473 // If we got the exact same number of literals as input,
474 // assume the literals cannot be compressed.
475 if len(blk.literals) != len(todo) || len(todo) != e.o.blockSize {
476 err = blk.encode()
477 }
478
479 switch err {
480 case errIncompressible:
481 if debug {
482 println("Storing incompressible block as raw")
483 }
484 blk.encodeRaw(todo)
485 blk.popOffsets()
486 case nil:
487 default:
488 panic(err)
489 }
490 dst = append(dst, blk.output...)
491 }
492 if e.o.crc {
493 dst = enc.AppendCRC(dst)
494 }
495 // Add padding with content from crypto/rand.Reader
496 if e.o.pad > 0 {
497 add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
498 dst, err = skippableFrame(dst, add, rand.Reader)
499 if err != nil {
500 panic(err)
501 }
502 }
503 return dst
504}