blob: 28dfd3e504e1c83deb5bc3109a30a99fd8ebaa8e [file] [log] [blame]
package zstd
/*
#define ZSTD_STATIC_LINKING_ONLY
#define ZBUFF_DISABLE_DEPRECATE_WARNINGS
#include "zstd.h"
#include "zbuff.h"
*/
import "C"
import (
"fmt"
"io"
"unsafe"
)
// Writer is an io.WriteCloser that zstd-compresses its input.
//
// Instances are built by NewWriter, NewWriterLevel or NewWriterLevelDict.
// Close must be called when done: it is the only place the C compression
// context held in ctx is freed.
type Writer struct {
// CompressionLevel is the level that was passed to the constructor.
CompressionLevel int
// ctx is the C compression context obtained from ZSTD_createCCtx.
ctx *C.ZSTD_CCtx
// dict is the dictionary slice given to the constructor (may be nil).
dict []byte
// dstBuffer is scratch space that receives compressed output before it is
// forwarded to underlyingWriter; Write replaces it with a larger buffer
// when an input needs more room.
dstBuffer []byte
// firstError holds the error (if any) returned when the compression
// session was started; Write returns it instead of compressing.
firstError error
// underlyingWriter is the destination for the compressed bytes.
underlyingWriter io.Writer
}
// resize returns a slice of length newSize that preserves the leading
// contents of in.
//
//   - A nil input yields a freshly allocated, zeroed slice.
//   - When the existing capacity suffices, the input is re-sliced in place
//     (no allocation, same backing array).
//   - Otherwise the input is extended with zero bytes via append.
func resize(in []byte, newSize int) []byte {
	switch {
	case in == nil:
		return make([]byte, newSize)
	case cap(in) >= newSize:
		return in[:newSize]
	default:
		padding := make([]byte, newSize-len(in))
		return append(in, padding...)
	}
}
// NewWriter returns a Writer that compresses everything written to it and
// forwards the compressed bytes to w. It uses DefaultCompression and no
// dictionary.
func NewWriter(w io.Writer) *Writer {
	var noDict []byte
	return NewWriterLevelDict(w, DefaultCompression, noDict)
}
// NewWriterLevel behaves like NewWriter except that the caller chooses the
// compression level rather than getting the default.
//
// The level can be DefaultCompression or any integer value between BestSpeed
// and BestCompression inclusive.
func NewWriterLevel(w io.Writer, level int) *Writer {
	var noDict []byte
	return NewWriterLevelDict(w, level, noDict)
}
// NewWriterLevelDict is like NewWriterLevel but specifies a dictionary to
// compress with. If the dictionary is empty or nil it is ignored. The dictionary
// should not be modified until the writer is closed.
//
// Any error reported while starting the compression session is stored in the
// returned Writer and surfaced by Write.
func NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer {
	ctx := C.ZSTD_createCCtx()

	var err error
	if len(dict) == 0 {
		// len() covers both nil and empty-but-non-nil dictionaries; taking
		// &dict[0] on an empty slice would panic.
		err = getError(int(C.ZSTD_compressBegin(ctx, C.int(level))))
	} else {
		err = getError(int(C.ZSTD_compressBegin_usingDict(
			ctx,
			unsafe.Pointer(&dict[0]),
			C.size_t(len(dict)),
			C.int(level))))
	}

	return &Writer{
		CompressionLevel: level,
		ctx:              ctx,
		dict:             dict,
		dstBuffer:        make([]byte, CompressBound(1024)),
		firstError:       err,
		underlyingWriter: w,
	}
}
// Write compresses p and hands the compressed bytes to the underlying
// io.Writer.
//
// It returns the stored start-up error, if there is one, without doing any
// work. An empty p is a no-op. On success it reports len(p); on any failure
// it reports 0 because, as with zlib, the amount of input that actually
// reached the destination cannot be determined.
func (w *Writer) Write(p []byte) (int, error) {
	switch {
	case w.firstError != nil:
		return 0, w.firstError
	case len(p) == 0:
		return 0, nil
	}

	// Swap in a larger scratch buffer when the current one is smaller than
	// the bound CompressBound reports for this input.
	if bound := CompressBound(len(p)); len(w.dstBuffer) < bound {
		w.dstBuffer = make([]byte, bound)
	}

	result := C.ZSTD_compressContinue(
		w.ctx,
		unsafe.Pointer(&w.dstBuffer[0]),
		C.size_t(len(w.dstBuffer)),
		unsafe.Pointer(&p[0]),
		C.size_t(len(p)))
	if cErr := getError(int(result)); cErr != nil {
		return 0, cErr
	}

	// result is the number of compressed bytes placed in dstBuffer.
	compressed := w.dstBuffer[:int(result)]
	if _, wErr := w.underlyingWriter.Write(compressed); wErr != nil {
		return 0, wErr
	}
	return len(p), nil
}
// Close closes the Writer, flushing any unwritten data to the underlying
// io.Writer and freeing objects, but does not close the underlying io.Writer.
//
// The C compression context is released on every path, including when
// finishing the frame fails. Calling Close again after the context has been
// released is a no-op that returns nil.
func (w *Writer) Close() error {
	if w.ctx == nil {
		// Already closed (or never initialised): nothing to flush or free.
		return nil
	}

	endCode := C.ZSTD_compressEnd(
		w.ctx,
		unsafe.Pointer(&w.dstBuffer[0]),
		C.size_t(len(w.dstBuffer)),
		unsafe.Pointer(nil),
		C.size_t(0))

	// Free the context before inspecting endCode so an error while ending the
	// frame cannot leak it, and clear the pointer so a repeated Close cannot
	// free it twice.
	freeCode := C.ZSTD_freeCCtx(w.ctx)
	w.ctx = nil

	if err := getError(int(endCode)); err != nil {
		return err
	}
	if err := getError(int(freeCode)); err != nil {
		return err
	}

	// endCode is the size of the frame epilogue written into dstBuffer.
	_, err := w.underlyingWriter.Write(w.dstBuffer[:int(endCode)])
	return err
}
// reader is an io.ReadCloser that decompresses when read from.
//
// It is created by NewReaderDict (and NewReader, which delegates to it).
type reader struct {
// ctx is the C streaming decompression context from ZBUFF_createDCtx;
// Close frees it.
ctx *C.ZBUFF_DCtx
// compressionBuffer holds compressed bytes read from underlyingReader.
compressionBuffer []byte
// compressionLeft is the number of bytes at the front of
// compressionBuffer that the decompressor has not consumed yet.
compressionLeft int
// decompressionBuffer receives the output of each decompression call.
decompressionBuffer []byte
// decompOff is the offset in decompressionBuffer of the next byte that
// has not yet been handed to a caller of Read.
decompOff int
// decompSize is the number of valid decompressed bytes currently in
// decompressionBuffer.
decompSize int
// dict is the dictionary slice given to NewReaderDict (may be nil).
dict []byte
// firstError is the error, if any, returned by the decompression init
// call in NewReaderDict.
firstError error
// recommendedSrcSize is the value of ZBUFF_recommendedDInSize captured at
// construction; Read falls back to it when sizing compressionBuffer.
recommendedSrcSize int
// underlyingReader is the source of compressed data.
underlyingReader io.Reader
}
// NewReader returns an io.ReadCloser whose Read method pulls compressed data
// from r and yields the decompressed bytes. No dictionary is used.
//
// The caller is responsible for calling Close on the result when finished;
// otherwise the underlying objects in the zstd library are never freed.
func NewReader(r io.Reader) io.ReadCloser {
	var noDict []byte
	return NewReaderDict(r, noDict)
}
// NewReaderDict is like NewReader but decompresses with a preset dictionary.
// A nil or empty dictionary is ignored.
//
// The result of the init call is recorded on the returned reader. The
// function panics if the zstd library reports a non-positive recommended
// input or output buffer size.
func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {
	ctx := C.ZBUFF_createDCtx()

	var initErr error
	if len(dict) > 0 {
		initErr = getError(int(C.ZBUFF_decompressInitDictionary(
			ctx,
			unsafe.Pointer(&dict[0]),
			C.size_t(len(dict)))))
	} else {
		initErr = getError(int(C.ZBUFF_decompressInit(ctx)))
	}

	inSize := int(C.ZBUFF_recommendedDInSize())
	outSize := int(C.ZBUFF_recommendedDOutSize())
	if inSize <= 0 {
		panic(fmt.Errorf("ZBUFF_recommendedDInSize() returned invalid size: %v", inSize))
	}
	if outSize <= 0 {
		panic(fmt.Errorf("ZBUFF_recommendedDOutSize() returned invalid size: %v", outSize))
	}

	return &reader{
		ctx:                 ctx,
		dict:                dict,
		compressionBuffer:   make([]byte, inSize),
		decompressionBuffer: make([]byte, outSize),
		firstError:          initErr,
		recommendedSrcSize:  inSize,
		underlyingReader:    r,
	}
}
// Close frees the C decompression context owned by the reader and returns
// any error reported by the zstd library while doing so.
//
// Once the context has been freed, further calls are no-ops that return nil,
// so a repeated Close cannot free the same context twice.
func (r *reader) Close() error {
	if r.ctx == nil {
		return nil
	}
	code := C.ZBUFF_freeDCtx(r.ctx)
	r.ctx = nil
	return getError(int(code))
}
// Read fills p with decompressed data, pulling more compressed input from the
// underlying reader as needed.
//
// It returns the number of bytes placed in p. If the decompression context
// failed to initialise, that stored error is returned immediately. io.EOF is
// returned once the underlying reader is exhausted and no compressed input is
// left to process. When a read or decompression error occurs, the bytes
// already copied into p are still reported alongside the error, as the
// io.Reader contract requires.
func (r *reader) Read(p []byte) (int, error) {
	// Surface an initialisation failure instead of driving a context that
	// was never set up correctly.
	if r.firstError != nil {
		return 0, r.firstError
	}

	// Fast path: previously decompressed data is enough to satisfy p.
	if r.decompSize-r.decompOff >= len(p) {
		copy(p, r.decompressionBuffer[r.decompOff:])
		r.decompOff += len(p)
		return len(p), nil
	}

	// Drain whatever decompressed data is pending, then refill.
	got := copy(p, r.decompressionBuffer[r.decompOff:r.decompSize])
	r.decompSize = 0
	r.decompOff = 0

	for got < len(p) {
		// Top up the compressed-input buffer behind the unconsumed prefix.
		src := r.compressionBuffer
		n, readErr := io.ReadFull(r.underlyingReader, src[r.compressionLeft:])
		if readErr == io.EOF && r.compressionLeft == 0 {
			return got, io.EOF
		} else if readErr != nil && readErr != io.EOF && readErr != io.ErrUnexpectedEOF {
			return got, fmt.Errorf("failed to read from underlying reader: %s", readErr)
		}
		src = src[:r.compressionLeft+n]

		// On return, cSrcSize and cDstSize hold the number of bytes consumed
		// from src and produced into decompressionBuffer respectively.
		cSrcSize := C.size_t(len(src))
		cDstSize := C.size_t(len(r.decompressionBuffer))
		retCode := int(C.ZBUFF_decompressContinue(
			r.ctx,
			unsafe.Pointer(&r.decompressionBuffer[0]),
			&cDstSize,
			unsafe.Pointer(&src[0]),
			&cSrcSize))
		if decErr := getError(retCode); decErr != nil {
			return got, fmt.Errorf("failed to decompress: %s", decErr)
		}

		// Move unconsumed compressed bytes to the front for the next round.
		consumed := int(cSrcSize)
		if consumed < len(src) {
			copy(r.compressionBuffer, src[consumed:])
		}
		r.compressionLeft = len(src) - consumed

		// Hand over as much fresh output as fits; the rest stays buffered
		// for the next Read call.
		r.decompSize = int(cDstSize)
		r.decompOff = copy(p[got:], r.decompressionBuffer[:r.decompSize])
		got += r.decompOff

		// Size the input buffer for the next iteration: use the library's
		// hint when positive, otherwise the recommended size, and never less
		// than what is needed to keep the unconsumed bytes.
		nsize := retCode
		if nsize <= 0 {
			nsize = r.recommendedSrcSize
		}
		if nsize < r.compressionLeft {
			nsize = r.compressionLeft
		}
		r.compressionBuffer = resize(r.compressionBuffer, nsize)
	}
	return got, nil
}