| package zstd |
| |
| /* |
| #define ZSTD_STATIC_LINKING_ONLY |
| #define ZBUFF_DISABLE_DEPRECATE_WARNINGS |
| #include "zstd.h" |
| #include "zbuff.h" |
| */ |
| import "C" |
| import ( |
| "errors" |
| "fmt" |
| "io" |
| "runtime" |
| "unsafe" |
| ) |
| |
// errShortRead is the sentinel TryReadFull reports when the source hits
// io.EOF before the destination buffer has been completely filled.
var (
	errShortRead = errors.New("short read")
)
| |
| // Writer is an io.WriteCloser that zstd-compresses its input. |
| type Writer struct { |
| CompressionLevel int |
| |
| ctx *C.ZSTD_CCtx |
| dict []byte |
| dstBuffer []byte |
| firstError error |
| underlyingWriter io.Writer |
| } |
| |
// resize returns a slice of length newSize derived from in.
//
// A nil input yields a freshly allocated slice. When the existing capacity is
// sufficient the input is re-sliced in place (sharing its backing array);
// otherwise the input is extended with zero bytes up to newSize.
func resize(in []byte, newSize int) []byte {
	switch {
	case in == nil:
		return make([]byte, newSize)
	case cap(in) >= newSize:
		return in[:newSize]
	}
	// Not enough room: pad the current contents with zeroes up to newSize.
	return append(in, make([]byte, newSize-len(in))...)
}
| |
| // NewWriter creates a new Writer with default compression options. Writes to |
| // the writer will be written in compressed form to w. |
| func NewWriter(w io.Writer) *Writer { |
| return NewWriterLevelDict(w, DefaultCompression, nil) |
| } |
| |
| // NewWriterLevel is like NewWriter but specifies the compression level instead |
| // of assuming default compression. |
| // |
| // The level can be DefaultCompression or any integer value between BestSpeed |
| // and BestCompression inclusive. |
| func NewWriterLevel(w io.Writer, level int) *Writer { |
| return NewWriterLevelDict(w, level, nil) |
| |
| } |
| |
| // NewWriterLevelDict is like NewWriterLevel but specifies a dictionary to |
| // compress with. If the dictionary is empty or nil it is ignored. The dictionary |
| // should not be modified until the writer is closed. |
| func NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer { |
| var err error |
| ctx := C.ZSTD_createCCtx() |
| |
| if dict == nil { |
| err = getError(int(C.ZSTD_compressBegin(ctx, |
| C.int(level)))) |
| } else { |
| err = getError(int(C.ZSTD_compressBegin_usingDict( |
| ctx, |
| unsafe.Pointer(&dict[0]), |
| C.size_t(len(dict)), |
| C.int(level)))) |
| } |
| |
| return &Writer{ |
| CompressionLevel: level, |
| ctx: ctx, |
| dict: dict, |
| dstBuffer: make([]byte, CompressBound(1024)), |
| firstError: err, |
| underlyingWriter: w, |
| } |
| } |
| |
| // Write writes a compressed form of p to the underlying io.Writer. |
| func (w *Writer) Write(p []byte) (int, error) { |
| if w.firstError != nil { |
| return 0, w.firstError |
| } |
| if len(p) == 0 { |
| return 0, nil |
| } |
| // Check if dstBuffer is enough |
| if len(w.dstBuffer) < CompressBound(len(p)) { |
| w.dstBuffer = make([]byte, CompressBound(len(p))) |
| } |
| |
| retCode := C.ZSTD_compressContinue( |
| w.ctx, |
| unsafe.Pointer(&w.dstBuffer[0]), |
| C.size_t(len(w.dstBuffer)), |
| unsafe.Pointer(&p[0]), |
| C.size_t(len(p))) |
| |
| if err := getError(int(retCode)); err != nil { |
| return 0, err |
| } |
| written := int(retCode) |
| |
| // Write to underlying buffer |
| _, err := w.underlyingWriter.Write(w.dstBuffer[:written]) |
| |
| // Same behaviour as zlib, we can't know how much data we wrote, only |
| // if there was an error |
| if err != nil { |
| return 0, err |
| } |
| return len(p), err |
| } |
| |
| // Close closes the Writer, flushing any unwritten data to the underlying |
| // io.Writer and freeing objects, but does not close the underlying io.Writer. |
| func (w *Writer) Close() error { |
| retCode := C.ZSTD_compressEnd( |
| w.ctx, |
| unsafe.Pointer(&w.dstBuffer[0]), |
| C.size_t(len(w.dstBuffer)), |
| unsafe.Pointer(nil), |
| C.size_t(0)) |
| |
| if err := getError(int(retCode)); err != nil { |
| return err |
| } |
| written := int(retCode) |
| retCode = C.ZSTD_freeCCtx(w.ctx) // Safely close buffer before writing the end |
| |
| if err := getError(int(retCode)); err != nil { |
| return err |
| } |
| |
| _, err := w.underlyingWriter.Write(w.dstBuffer[:written]) |
| if err != nil { |
| return err |
| } |
| return nil |
| } |
| |
| // reader is an io.ReadCloser that decompresses when read from. |
| type reader struct { |
| ctx *C.ZBUFF_DCtx |
| compressionBuffer []byte |
| compressionLeft int |
| decompressionBuffer []byte |
| decompOff int |
| decompSize int |
| dict []byte |
| firstError error |
| recommendedSrcSize int |
| underlyingReader io.Reader |
| } |
| |
| // NewReader creates a new io.ReadCloser. Reads from the returned ReadCloser |
| // read and decompress data from r. It is the caller's responsibility to call |
| // Close on the ReadCloser when done. If this is not done, underlying objects |
| // in the zstd library will not be freed. |
| func NewReader(r io.Reader) io.ReadCloser { |
| return NewReaderDict(r, nil) |
| } |
| |
| // NewReaderDict is like NewReader but uses a preset dictionary. NewReaderDict |
| // ignores the dictionary if it is nil. |
| func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser { |
| var err error |
| ctx := C.ZBUFF_createDCtx() |
| if len(dict) == 0 { |
| err = getError(int(C.ZBUFF_decompressInit(ctx))) |
| } else { |
| err = getError(int(C.ZBUFF_decompressInitDictionary( |
| ctx, |
| unsafe.Pointer(&dict[0]), |
| C.size_t(len(dict))))) |
| } |
| cSize := int(C.ZBUFF_recommendedDInSize()) |
| dSize := int(C.ZBUFF_recommendedDOutSize()) |
| if cSize <= 0 { |
| panic(fmt.Errorf("ZBUFF_recommendedDInSize() returned invalid size: %v", cSize)) |
| } |
| if dSize <= 0 { |
| panic(fmt.Errorf("ZBUFF_recommendedDOutSize() returned invalid size: %v", dSize)) |
| } |
| |
| compressionBuffer := make([]byte, cSize) |
| decompressionBuffer := make([]byte, dSize) |
| return &reader{ |
| ctx: ctx, |
| dict: dict, |
| compressionBuffer: compressionBuffer, |
| decompressionBuffer: decompressionBuffer, |
| firstError: err, |
| recommendedSrcSize: cSize, |
| underlyingReader: r, |
| } |
| } |
| |
| // Close frees the allocated C objects |
| func (r *reader) Close() error { |
| return getError(int(C.ZBUFF_freeDCtx(r.ctx))) |
| } |
| |
| func (r *reader) Read(p []byte) (int, error) { |
| |
| // If we already have enough bytes, return |
| if r.decompSize-r.decompOff >= len(p) { |
| copy(p, r.decompressionBuffer[r.decompOff:]) |
| r.decompOff += len(p) |
| return len(p), nil |
| } |
| |
| copy(p, r.decompressionBuffer[r.decompOff:r.decompSize]) |
| got := r.decompSize - r.decompOff |
| r.decompSize = 0 |
| r.decompOff = 0 |
| |
| for got < len(p) { |
| // Populate src |
| src := r.compressionBuffer |
| reader := r.underlyingReader |
| n, err := TryReadFull(reader, src[r.compressionLeft:]) |
| if err != nil && err != errShortRead { // Handle underlying reader errors first |
| return 0, fmt.Errorf("failed to read from underlying reader: %s", err) |
| } else if n == 0 && r.compressionLeft == 0 { |
| return got, io.EOF |
| } |
| src = src[:r.compressionLeft+n] |
| |
| // C code |
| cSrcSize := C.size_t(len(src)) |
| cDstSize := C.size_t(len(r.decompressionBuffer)) |
| retCode := int(C.ZBUFF_decompressContinue( |
| r.ctx, |
| unsafe.Pointer(&r.decompressionBuffer[0]), |
| &cDstSize, |
| unsafe.Pointer(&src[0]), |
| &cSrcSize)) |
| |
| // Keep src here eventhough, we reuse later, the code might be deleted at some point |
| runtime.KeepAlive(src) |
| if err = getError(retCode); err != nil { |
| return 0, fmt.Errorf("failed to decompress: %s", err) |
| } |
| |
| // Put everything in buffer |
| if int(cSrcSize) < len(src) { |
| left := src[int(cSrcSize):] |
| copy(r.compressionBuffer, left) |
| } |
| r.compressionLeft = len(src) - int(cSrcSize) |
| r.decompSize = int(cDstSize) |
| r.decompOff = copy(p[got:], r.decompressionBuffer[:r.decompSize]) |
| got += r.decompOff |
| |
| // Resize buffers |
| nsize := retCode // Hint for next src buffer size |
| if nsize <= 0 { |
| // Reset to recommended size |
| nsize = r.recommendedSrcSize |
| } |
| if nsize < r.compressionLeft { |
| nsize = r.compressionLeft |
| } |
| r.compressionBuffer = resize(r.compressionBuffer, nsize) |
| } |
| return got, nil |
| } |
| |
| // TryReadFull reads buffer just as ReadFull does |
| // Here we expect that buffer may end and we do not return ErrUnexpectedEOF as ReadAtLeast does. |
| // We return errShortRead instead to distinguish short reads and failures. |
| // We cannot use ReadFull/ReadAtLeast because it masks Reader errors, such as network failures |
| // and causes panic instead of error. |
| func TryReadFull(r io.Reader, buf []byte) (n int, err error) { |
| for n < len(buf) && err == nil { |
| var nn int |
| nn, err = r.Read(buf[n:]) |
| n += nn |
| } |
| if n == len(buf) && err == io.EOF { |
| err = nil // EOF at the end is somewhat expected |
| } else if err == io.EOF { |
| err = errShortRead |
| } |
| return |
| } |