| // Copyright 2019+ Klaus Post. All rights reserved. |
| // License information can be found in the LICENSE file. |
| // Based on work by Yann Collet, released under BSD License. |
| |
| package zstd |
| |
| import ( |
| "crypto/rand" |
| "fmt" |
| "io" |
| rdebug "runtime/debug" |
| "sync" |
| |
| "github.com/klauspost/compress/zstd/internal/xxhash" |
| ) |
| |
// Encoder provides encoding to Zstandard.
// An Encoder can be used either for compressing a stream via the
// io.WriteCloser interface it implements, or for multiple independent
// encoding tasks via the EncodeAll function.
// Smaller encodes are encouraged to use the EncodeAll function.
// Use NewWriter to create a new instance.
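//
// A minimal streaming sketch (dst is a placeholder io.Writer, data a
// placeholder []byte; error handling is abbreviated):
//
//	enc, err := NewWriter(dst)
//	if err != nil {
//		return err
//	}
//	if _, err = enc.Write(data); err != nil {
//		return err
//	}
//	return enc.Close()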
| type Encoder struct { |
| o encoderOptions |
| encoders chan encoder |
| state encoderState |
| init sync.Once |
| } |
| |
| type encoder interface { |
| Encode(blk *blockEnc, src []byte) |
| Block() *blockEnc |
| CRC() *xxhash.Digest |
| AppendCRC([]byte) []byte |
| WindowSize(size int) int32 |
| UseBlock(*blockEnc) |
| Reset() |
| } |
| |
| type encoderState struct { |
| w io.Writer |
| filling []byte |
| current []byte |
| previous []byte |
| encoder encoder |
| writing *blockEnc |
| err error |
| writeErr error |
| nWritten int64 |
| headerWritten bool |
| eofWritten bool |
| |
| // This waitgroup indicates an encode is running. |
| wg sync.WaitGroup |
| // This waitgroup indicates we have a block encoding/writing. |
| wWg sync.WaitGroup |
| } |
| |
// NewWriter will create a new Zstandard encoder.
// If the encoder will only be used for EncodeAll, a nil writer can be used.
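//
// A sketch of a blocks-only encoder: with a nil writer, only EncodeAll is
// used (input is a placeholder []byte):
//
//	enc, err := NewWriter(nil)
//	if err != nil {
//		return err
//	}
//	compressed := enc.EncodeAll(input, nil)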
| func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) { |
| initPredefined() |
| var e Encoder |
| e.o.setDefault() |
| for _, o := range opts { |
| err := o(&e.o) |
| if err != nil { |
| return nil, err |
| } |
| } |
| if w != nil { |
| e.Reset(w) |
| } else { |
| e.init.Do(func() { |
| e.initialize() |
| }) |
| } |
| return &e, nil |
| } |
| |
| func (e *Encoder) initialize() { |
| e.encoders = make(chan encoder, e.o.concurrent) |
| for i := 0; i < e.o.concurrent; i++ { |
| e.encoders <- e.o.encoder() |
| } |
| } |
| |
| // Reset will re-initialize the writer and new writes will encode to the supplied writer |
| // as a new, independent stream. |
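//
// A sketch of re-using one Encoder for several independent streams
// (outputs is a placeholder []io.Writer, payload a placeholder []byte;
// error handling elided):
//
//	enc, _ := NewWriter(nil)
//	for _, w := range outputs {
//		enc.Reset(w)
//		_, _ = enc.Write(payload)
//		_ = enc.Close()
//	}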
| func (e *Encoder) Reset(w io.Writer) { |
| e.init.Do(func() { |
| e.initialize() |
| }) |
| s := &e.state |
| s.wg.Wait() |
| s.wWg.Wait() |
| if cap(s.filling) == 0 { |
| s.filling = make([]byte, 0, e.o.blockSize) |
| } |
| if cap(s.current) == 0 { |
| s.current = make([]byte, 0, e.o.blockSize) |
| } |
| if cap(s.previous) == 0 { |
| s.previous = make([]byte, 0, e.o.blockSize) |
| } |
| if s.encoder == nil { |
| s.encoder = e.o.encoder() |
| } |
| if s.writing == nil { |
| s.writing = &blockEnc{} |
| s.writing.init() |
| } |
| s.writing.initNewEncode() |
| s.filling = s.filling[:0] |
| s.current = s.current[:0] |
| s.previous = s.previous[:0] |
| s.encoder.Reset() |
| s.headerWritten = false |
| s.eofWritten = false |
| s.w = w |
| s.err = nil |
| s.nWritten = 0 |
| s.writeErr = nil |
| } |
| |
| // Write data to the encoder. |
| // Input data will be buffered and as the buffer fills up |
| // content will be compressed and written to the output. |
| // When done writing, use Close to flush the remaining output |
| // and write CRC if requested. |
| func (e *Encoder) Write(p []byte) (n int, err error) { |
| s := &e.state |
| for len(p) > 0 { |
| if len(p)+len(s.filling) < e.o.blockSize { |
| if e.o.crc { |
| _, _ = s.encoder.CRC().Write(p) |
| } |
| s.filling = append(s.filling, p...) |
| return n + len(p), nil |
| } |
| add := p |
| if len(p)+len(s.filling) > e.o.blockSize { |
| add = add[:e.o.blockSize-len(s.filling)] |
| } |
| if e.o.crc { |
| _, _ = s.encoder.CRC().Write(add) |
| } |
| s.filling = append(s.filling, add...) |
| p = p[len(add):] |
| n += len(add) |
| if len(s.filling) < e.o.blockSize { |
| return n, nil |
| } |
| err := e.nextBlock(false) |
| if err != nil { |
| return n, err |
| } |
| if debug && len(s.filling) > 0 { |
| panic(len(s.filling)) |
| } |
| } |
| return n, nil |
| } |
| |
| // nextBlock will synchronize and start compressing input in e.state.filling. |
| // If an error has occurred during encoding it will be returned. |
| func (e *Encoder) nextBlock(final bool) error { |
| s := &e.state |
| // Wait for current block. |
| s.wg.Wait() |
| if s.err != nil { |
| return s.err |
| } |
| if len(s.filling) > e.o.blockSize { |
| return fmt.Errorf("block > maxStoreBlockSize") |
| } |
| if !s.headerWritten { |
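		// The frame header is written lazily, before the first block.
		// The total content size is unknown while streaming, so it is not
		// stated in the header (ContentSize is left at 0).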
| var tmp [maxHeaderSize]byte |
| fh := frameHeader{ |
| ContentSize: 0, |
| WindowSize: uint32(s.encoder.WindowSize(0)), |
| SingleSegment: false, |
| Checksum: e.o.crc, |
| DictID: 0, |
| } |
| dst, err := fh.appendTo(tmp[:0]) |
| if err != nil { |
| return err |
| } |
| s.headerWritten = true |
| s.wWg.Wait() |
| var n2 int |
| n2, s.err = s.w.Write(dst) |
| if s.err != nil { |
| return s.err |
| } |
| s.nWritten += int64(n2) |
| } |
| if s.eofWritten { |
| // Ensure we only write it once. |
| final = false |
| } |
| |
| if len(s.filling) == 0 { |
| // Final block, but no data. |
| if final { |
| enc := s.encoder |
| blk := enc.Block() |
| blk.reset(nil) |
| blk.last = true |
| blk.encodeRaw(nil) |
| s.wWg.Wait() |
| _, s.err = s.w.Write(blk.output) |
| s.nWritten += int64(len(blk.output)) |
| s.eofWritten = true |
| } |
| return s.err |
| } |
| |
	// Rotate buffers: the filled buffer becomes current, the old current
	// becomes previous, and the old previous is reused for filling.
| s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current |
| s.wg.Add(1) |
| go func(src []byte) { |
| if debug { |
| println("Adding block,", len(src), "bytes, final:", final) |
| } |
| defer func() { |
| if r := recover(); r != nil { |
| s.err = fmt.Errorf("panic while encoding: %v", r) |
| rdebug.PrintStack() |
| } |
| s.wg.Done() |
| }() |
| enc := s.encoder |
| blk := enc.Block() |
| enc.Encode(blk, src) |
| blk.last = final |
| if final { |
| s.eofWritten = true |
| } |
| // Wait for pending writes. |
| s.wWg.Wait() |
| if s.writeErr != nil { |
| s.err = s.writeErr |
| return |
| } |
| // Transfer encoders from previous write block. |
| blk.swapEncoders(s.writing) |
| // Transfer recent offsets to next. |
| enc.UseBlock(s.writing) |
| s.writing = blk |
| s.wWg.Add(1) |
| go func() { |
| defer func() { |
| if r := recover(); r != nil { |
| s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r) |
| rdebug.PrintStack() |
| } |
| s.wWg.Done() |
| }() |
| err := errIncompressible |
| // If we got the exact same number of literals as input, |
| // assume the literals cannot be compressed. |
| if len(src) != len(blk.literals) || len(src) != e.o.blockSize { |
| err = blk.encode() |
| } |
| switch err { |
| case errIncompressible: |
| if debug { |
| println("Storing incompressible block as raw") |
| } |
| blk.encodeRaw(src) |
				// In fast mode, we do not transfer offsets, so we do not have
				// to deal with resetting the recent offsets here.
| case nil: |
| default: |
| s.writeErr = err |
| return |
| } |
| _, s.writeErr = s.w.Write(blk.output) |
| s.nWritten += int64(len(blk.output)) |
| }() |
| }(s.current) |
| return nil |
| } |
| |
| // ReadFrom reads data from r until EOF or error. |
| // The return value n is the number of bytes read. |
| // Any error except io.EOF encountered during the read is also returned. |
| // |
| // The Copy function uses ReaderFrom if available. |
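//
// A sketch using io.Copy, which will use this ReadFrom when the source does
// not implement io.WriterTo (r is a placeholder io.Reader; errors elided):
//
//	enc, _ := NewWriter(dst)
//	_, _ = io.Copy(enc, r)
//	_ = enc.Close()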
| func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) { |
	if debug {
		println("Using ReadFrom")
	}
	// Flush anything already queued by Write, so buffered data is not
	// overwritten by the reads into s.filling below.
	if len(e.state.filling) > 0 {
		if err := e.nextBlock(false); err != nil {
			return 0, err
		}
	}
	e.state.filling = e.state.filling[:e.o.blockSize]
	src := e.state.filling
| for { |
| n2, err := r.Read(src) |
| _, _ = e.state.encoder.CRC().Write(src[:n2]) |
| // src is now the unfilled part... |
| src = src[n2:] |
| n += int64(n2) |
| switch err { |
| case io.EOF: |
| e.state.filling = e.state.filling[:len(e.state.filling)-len(src)] |
| if debug { |
| println("ReadFrom: got EOF final block:", len(e.state.filling)) |
| } |
| return n, e.nextBlock(true) |
| default: |
| if debug { |
| println("ReadFrom: got error:", err) |
| } |
| e.state.err = err |
| return n, err |
| case nil: |
| } |
| if len(src) > 0 { |
| if debug { |
| println("ReadFrom: got space left in source:", len(src)) |
| } |
| continue |
| } |
| err = e.nextBlock(false) |
| if err != nil { |
| return n, err |
| } |
| e.state.filling = e.state.filling[:e.o.blockSize] |
| src = e.state.filling |
| } |
| } |
| |
| // Flush will send the currently written data to output |
| // and block until everything has been written. |
| // This should only be used on rare occasions where pushing the currently queued data is critical. |
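//
// A sketch of a long-lived stream that flushes after each message so the
// receiver can make progress (msgs is a placeholder <-chan []byte):
//
//	for m := range msgs {
//		if _, err := enc.Write(m); err != nil {
//			return err
//		}
//		if err := enc.Flush(); err != nil {
//			return err
//		}
//	}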
| func (e *Encoder) Flush() error { |
| s := &e.state |
| if len(s.filling) > 0 { |
| err := e.nextBlock(false) |
| if err != nil { |
| return err |
| } |
| } |
| s.wg.Wait() |
| s.wWg.Wait() |
| if s.err != nil { |
| return s.err |
| } |
| return s.writeErr |
| } |
| |
| // Close will flush the final output and close the stream. |
| // The function will block until everything has been written. |
| // The Encoder can still be re-used after calling this. |
| func (e *Encoder) Close() error { |
| s := &e.state |
| if s.encoder == nil { |
| return nil |
| } |
| err := e.nextBlock(true) |
| if err != nil { |
| return err |
| } |
| s.wg.Wait() |
| s.wWg.Wait() |
| |
| if s.err != nil { |
| return s.err |
| } |
| if s.writeErr != nil { |
| return s.writeErr |
| } |
| |
| // Write CRC |
| if e.o.crc && s.err == nil { |
		// tmp escapes to the heap, since the CRC slice is passed through the
		// io.Writer interface.
| var tmp [4]byte |
| _, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0])) |
| s.nWritten += 4 |
| } |
| |
| // Add padding with content from crypto/rand.Reader |
| if s.err == nil && e.o.pad > 0 { |
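		// calcSkippableFrame determines how many bytes must be added so the
		// total written becomes a multiple of e.o.pad; the padding is emitted
		// as a skippable frame filled from crypto/rand.Reader.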
| add := calcSkippableFrame(s.nWritten, int64(e.o.pad)) |
| frame, err := skippableFrame(s.filling[:0], add, rand.Reader) |
| if err != nil { |
| return err |
| } |
| _, s.err = s.w.Write(frame) |
| } |
| return s.err |
| } |
| |
| // EncodeAll will encode all input in src and append it to dst. |
| // This function can be called concurrently, but each call will only run on a single goroutine. |
| // If empty input is given, nothing is returned, unless WithZeroFrames is specified. |
| // Encoded blocks can be concatenated and the result will be the combined input stream. |
| // Data compressed with EncodeAll can be decoded with the Decoder, |
| // using either a stream or DecodeAll. |
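//
// A sketch of sharing one Encoder across goroutines (inputs is a placeholder
// [][]byte; each call gets its own destination buffer):
//
//	enc, _ := NewWriter(nil)
//	out := make([][]byte, len(inputs))
//	var wg sync.WaitGroup
//	for i := range inputs {
//		wg.Add(1)
//		go func(i int) {
//			defer wg.Done()
//			out[i] = enc.EncodeAll(inputs[i], nil)
//		}(i)
//	}
//	wg.Wait()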
| func (e *Encoder) EncodeAll(src, dst []byte) []byte { |
| if len(src) == 0 { |
| if e.o.fullZero { |
| // Add frame header. |
| fh := frameHeader{ |
| ContentSize: 0, |
| WindowSize: minWindowSize, |
| SingleSegment: true, |
| // Adding a checksum would be a waste of space. |
| Checksum: false, |
| DictID: 0, |
| } |
| dst, _ = fh.appendTo(dst) |
| |
			// Write a single zero-length raw block, marked as the last block.
| var blk blockHeader |
| blk.setSize(0) |
| blk.setType(blockTypeRaw) |
| blk.setLast(true) |
| dst = blk.appendTo(dst) |
| } |
| return dst |
| } |
| e.init.Do(func() { |
| e.o.setDefault() |
| e.initialize() |
| }) |
| enc := <-e.encoders |
| defer func() { |
| // Release encoder reference to last block. |
| enc.Reset() |
| e.encoders <- enc |
| }() |
| enc.Reset() |
| blk := enc.Block() |
| single := len(src) > 1<<20 |
| if e.o.single != nil { |
| single = *e.o.single |
| } |
| fh := frameHeader{ |
| ContentSize: uint64(len(src)), |
| WindowSize: uint32(enc.WindowSize(len(src))), |
| SingleSegment: single, |
| Checksum: e.o.crc, |
| DictID: 0, |
| } |
| |
| // If less than 1MB, allocate a buffer up front. |
| if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 { |
| dst = make([]byte, 0, len(src)) |
| } |
| dst, err := fh.appendTo(dst) |
| if err != nil { |
| panic(err) |
| } |
| |
| for len(src) > 0 { |
| todo := src |
| if len(todo) > e.o.blockSize { |
| todo = todo[:e.o.blockSize] |
| } |
| src = src[len(todo):] |
| if e.o.crc { |
| _, _ = enc.CRC().Write(todo) |
| } |
| blk.reset(nil) |
| blk.pushOffsets() |
| enc.Encode(blk, todo) |
| if len(src) == 0 { |
| blk.last = true |
| } |
| err := errIncompressible |
| // If we got the exact same number of literals as input, |
| // assume the literals cannot be compressed. |
| if len(blk.literals) != len(todo) || len(todo) != e.o.blockSize { |
| err = blk.encode() |
| } |
| |
| switch err { |
| case errIncompressible: |
| if debug { |
| println("Storing incompressible block as raw") |
| } |
| blk.encodeRaw(todo) |
| blk.popOffsets() |
| case nil: |
| default: |
| panic(err) |
| } |
| dst = append(dst, blk.output...) |
| } |
| if e.o.crc { |
| dst = enc.AppendCRC(dst) |
| } |
| // Add padding with content from crypto/rand.Reader |
| if e.o.pad > 0 { |
| add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad)) |
| dst, err = skippableFrame(dst, add, rand.Reader) |
| if err != nil { |
| panic(err) |
| } |
| } |
| return dst |
| } |