blob: d5d1336deb6e37a9126a62149aae1d97414333ba [file] [log] [blame]
William Kurkianea869482019-04-09 15:16:11 -04001package zstd
2
3/*
4#define ZSTD_STATIC_LINKING_ONLY
5#define ZBUFF_DISABLE_DEPRECATE_WARNINGS
6#include "zstd.h"
7#include "zbuff.h"
8*/
9import "C"
10import (
11 "errors"
12 "fmt"
13 "io"
14 "unsafe"
15)
16
var (
	// errShortRead is the sentinel TryReadFull returns when the source hits
	// io.EOF before the requested buffer has been completely filled.
	errShortRead = errors.New("short read")
)
18
19// Writer is an io.WriteCloser that zstd-compresses its input.
20type Writer struct {
21 CompressionLevel int
22
23 ctx *C.ZSTD_CCtx
24 dict []byte
25 dstBuffer []byte
26 firstError error
27 underlyingWriter io.Writer
28}
29
// resize returns a slice of length newSize that reuses the storage of in
// whenever its capacity allows. A nil input yields a freshly allocated slice;
// when the capacity is too small, the existing contents are kept and the
// slice is extended with zero bytes up to newSize.
func resize(in []byte, newSize int) []byte {
	switch {
	case in == nil:
		return make([]byte, newSize)
	case cap(in) >= newSize:
		// Enough room already: just re-slice, no allocation.
		return in[:newSize]
	}
	padding := make([]byte, newSize-len(in))
	return append(in, padding...)
}
40
41// NewWriter creates a new Writer with default compression options. Writes to
42// the writer will be written in compressed form to w.
43func NewWriter(w io.Writer) *Writer {
44 return NewWriterLevelDict(w, DefaultCompression, nil)
45}
46
47// NewWriterLevel is like NewWriter but specifies the compression level instead
48// of assuming default compression.
49//
50// The level can be DefaultCompression or any integer value between BestSpeed
51// and BestCompression inclusive.
52func NewWriterLevel(w io.Writer, level int) *Writer {
53 return NewWriterLevelDict(w, level, nil)
54
55}
56
57// NewWriterLevelDict is like NewWriterLevel but specifies a dictionary to
58// compress with. If the dictionary is empty or nil it is ignored. The dictionary
59// should not be modified until the writer is closed.
60func NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer {
61 var err error
62 ctx := C.ZSTD_createCCtx()
63
64 if dict == nil {
65 err = getError(int(C.ZSTD_compressBegin(ctx,
66 C.int(level))))
67 } else {
68 err = getError(int(C.ZSTD_compressBegin_usingDict(
69 ctx,
70 unsafe.Pointer(&dict[0]),
71 C.size_t(len(dict)),
72 C.int(level))))
73 }
74
75 return &Writer{
76 CompressionLevel: level,
77 ctx: ctx,
78 dict: dict,
79 dstBuffer: make([]byte, CompressBound(1024)),
80 firstError: err,
81 underlyingWriter: w,
82 }
83}
84
85// Write writes a compressed form of p to the underlying io.Writer.
86func (w *Writer) Write(p []byte) (int, error) {
87 if w.firstError != nil {
88 return 0, w.firstError
89 }
90 if len(p) == 0 {
91 return 0, nil
92 }
93 // Check if dstBuffer is enough
94 if len(w.dstBuffer) < CompressBound(len(p)) {
95 w.dstBuffer = make([]byte, CompressBound(len(p)))
96 }
97
98 retCode := C.ZSTD_compressContinue(
99 w.ctx,
100 unsafe.Pointer(&w.dstBuffer[0]),
101 C.size_t(len(w.dstBuffer)),
102 unsafe.Pointer(&p[0]),
103 C.size_t(len(p)))
104
105 if err := getError(int(retCode)); err != nil {
106 return 0, err
107 }
108 written := int(retCode)
109
110 // Write to underlying buffer
111 _, err := w.underlyingWriter.Write(w.dstBuffer[:written])
112
113 // Same behaviour as zlib, we can't know how much data we wrote, only
114 // if there was an error
115 if err != nil {
116 return 0, err
117 }
118 return len(p), err
119}
120
121// Close closes the Writer, flushing any unwritten data to the underlying
122// io.Writer and freeing objects, but does not close the underlying io.Writer.
123func (w *Writer) Close() error {
124 retCode := C.ZSTD_compressEnd(
125 w.ctx,
126 unsafe.Pointer(&w.dstBuffer[0]),
127 C.size_t(len(w.dstBuffer)),
128 unsafe.Pointer(nil),
129 C.size_t(0))
130
131 if err := getError(int(retCode)); err != nil {
132 return err
133 }
134 written := int(retCode)
135 retCode = C.ZSTD_freeCCtx(w.ctx) // Safely close buffer before writing the end
136
137 if err := getError(int(retCode)); err != nil {
138 return err
139 }
140
141 _, err := w.underlyingWriter.Write(w.dstBuffer[:written])
142 if err != nil {
143 return err
144 }
145 return nil
146}
147
148// reader is an io.ReadCloser that decompresses when read from.
149type reader struct {
150 ctx *C.ZBUFF_DCtx
151 compressionBuffer []byte
152 compressionLeft int
153 decompressionBuffer []byte
154 decompOff int
155 decompSize int
156 dict []byte
157 firstError error
158 recommendedSrcSize int
159 underlyingReader io.Reader
160}
161
162// NewReader creates a new io.ReadCloser. Reads from the returned ReadCloser
163// read and decompress data from r. It is the caller's responsibility to call
164// Close on the ReadCloser when done. If this is not done, underlying objects
165// in the zstd library will not be freed.
166func NewReader(r io.Reader) io.ReadCloser {
167 return NewReaderDict(r, nil)
168}
169
170// NewReaderDict is like NewReader but uses a preset dictionary. NewReaderDict
171// ignores the dictionary if it is nil.
172func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {
173 var err error
174 ctx := C.ZBUFF_createDCtx()
175 if len(dict) == 0 {
176 err = getError(int(C.ZBUFF_decompressInit(ctx)))
177 } else {
178 err = getError(int(C.ZBUFF_decompressInitDictionary(
179 ctx,
180 unsafe.Pointer(&dict[0]),
181 C.size_t(len(dict)))))
182 }
183 cSize := int(C.ZBUFF_recommendedDInSize())
184 dSize := int(C.ZBUFF_recommendedDOutSize())
185 if cSize <= 0 {
186 panic(fmt.Errorf("ZBUFF_recommendedDInSize() returned invalid size: %v", cSize))
187 }
188 if dSize <= 0 {
189 panic(fmt.Errorf("ZBUFF_recommendedDOutSize() returned invalid size: %v", dSize))
190 }
191
192 compressionBuffer := make([]byte, cSize)
193 decompressionBuffer := make([]byte, dSize)
194 return &reader{
195 ctx: ctx,
196 dict: dict,
197 compressionBuffer: compressionBuffer,
198 decompressionBuffer: decompressionBuffer,
199 firstError: err,
200 recommendedSrcSize: cSize,
201 underlyingReader: r,
202 }
203}
204
205// Close frees the allocated C objects
206func (r *reader) Close() error {
207 return getError(int(C.ZBUFF_freeDCtx(r.ctx)))
208}
209
210func (r *reader) Read(p []byte) (int, error) {
211
212 // If we already have enough bytes, return
213 if r.decompSize-r.decompOff >= len(p) {
214 copy(p, r.decompressionBuffer[r.decompOff:])
215 r.decompOff += len(p)
216 return len(p), nil
217 }
218
219 copy(p, r.decompressionBuffer[r.decompOff:r.decompSize])
220 got := r.decompSize - r.decompOff
221 r.decompSize = 0
222 r.decompOff = 0
223
224 for got < len(p) {
225 // Populate src
226 src := r.compressionBuffer
227 reader := r.underlyingReader
228 n, err := TryReadFull(reader, src[r.compressionLeft:])
229 if err != nil && err != errShortRead { // Handle underlying reader errors first
230 return 0, fmt.Errorf("failed to read from underlying reader: %s", err)
231 } else if n == 0 && r.compressionLeft == 0 {
232 return got, io.EOF
233 }
234 src = src[:r.compressionLeft+n]
235
236 // C code
237 cSrcSize := C.size_t(len(src))
238 cDstSize := C.size_t(len(r.decompressionBuffer))
239 retCode := int(C.ZBUFF_decompressContinue(
240 r.ctx,
241 unsafe.Pointer(&r.decompressionBuffer[0]),
242 &cDstSize,
243 unsafe.Pointer(&src[0]),
244 &cSrcSize))
245
246 if err = getError(retCode); err != nil {
247 return 0, fmt.Errorf("failed to decompress: %s", err)
248 }
249
250 // Put everything in buffer
251 if int(cSrcSize) < len(src) {
252 left := src[int(cSrcSize):]
253 copy(r.compressionBuffer, left)
254 }
255 r.compressionLeft = len(src) - int(cSrcSize)
256 r.decompSize = int(cDstSize)
257 r.decompOff = copy(p[got:], r.decompressionBuffer[:r.decompSize])
258 got += r.decompOff
259
260 // Resize buffers
261 nsize := retCode // Hint for next src buffer size
262 if nsize <= 0 {
263 // Reset to recommended size
264 nsize = r.recommendedSrcSize
265 }
266 if nsize < r.compressionLeft {
267 nsize = r.compressionLeft
268 }
269 r.compressionBuffer = resize(r.compressionBuffer, nsize)
270 }
271 return got, nil
272}
273
274// TryReadFull reads buffer just as ReadFull does
275// Here we expect that buffer may end and we do not return ErrUnexpectedEOF as ReadAtLeast does.
276// We return errShortRead instead to distinguish short reads and failures.
277// We cannot use ReadFull/ReadAtLeast because it masks Reader errors, such as network failures
278// and causes panic instead of error.
279func TryReadFull(r io.Reader, buf []byte) (n int, err error) {
280 for n < len(buf) && err == nil {
281 var nn int
282 nn, err = r.Read(buf[n:])
283 n += nn
284 }
285 if n == len(buf) && err == io.EOF {
286 err = nil // EOF at the end is somewhat expected
287 } else if err == io.EOF {
288 err = errShortRead
289 }
290 return
291}