blob: 233035352aa6954c520beab5c0d43e9a2f9cfa58 [file] [log] [blame]
Takahiro Suzukid7bf8202020-12-17 20:21:59 +09001package zstd
2
3/*
4#define ZSTD_STATIC_LINKING_ONLY
5#define ZBUFF_DISABLE_DEPRECATE_WARNINGS
6#include "zstd.h"
7#include "zbuff.h"
8*/
9import "C"
10import (
11 "errors"
12 "fmt"
13 "io"
14 "runtime"
15 "unsafe"
16)
17
// errShortRead is returned by TryReadFull when the reader reaches EOF before
// the buffer is full. It lets callers tell an expected end of input apart
// from a genuine read failure.
var errShortRead error = errors.New("short read")
19
20// Writer is an io.WriteCloser that zstd-compresses its input.
21type Writer struct {
22 CompressionLevel int
23
24 ctx *C.ZSTD_CCtx
25 dict []byte
26 dstBuffer []byte
27 firstError error
28 underlyingWriter io.Writer
29}
30
// resize returns a slice of length newSize that starts with the contents of
// in. The backing array of in is reused when its capacity suffices; otherwise
// the slice is grown by appending zeroed bytes. A nil input gets a fresh
// zeroed slice.
func resize(in []byte, newSize int) []byte {
	switch {
	case in == nil:
		return make([]byte, newSize)
	case cap(in) >= newSize:
		return in[:newSize]
	default:
		missing := newSize - len(in)
		return append(in, make([]byte, missing)...)
	}
}
41
42// NewWriter creates a new Writer with default compression options. Writes to
43// the writer will be written in compressed form to w.
44func NewWriter(w io.Writer) *Writer {
45 return NewWriterLevelDict(w, DefaultCompression, nil)
46}
47
48// NewWriterLevel is like NewWriter but specifies the compression level instead
49// of assuming default compression.
50//
51// The level can be DefaultCompression or any integer value between BestSpeed
52// and BestCompression inclusive.
53func NewWriterLevel(w io.Writer, level int) *Writer {
54 return NewWriterLevelDict(w, level, nil)
55
56}
57
58// NewWriterLevelDict is like NewWriterLevel but specifies a dictionary to
59// compress with. If the dictionary is empty or nil it is ignored. The dictionary
60// should not be modified until the writer is closed.
61func NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer {
62 var err error
63 ctx := C.ZSTD_createCCtx()
64
65 if dict == nil {
66 err = getError(int(C.ZSTD_compressBegin(ctx,
67 C.int(level))))
68 } else {
69 err = getError(int(C.ZSTD_compressBegin_usingDict(
70 ctx,
71 unsafe.Pointer(&dict[0]),
72 C.size_t(len(dict)),
73 C.int(level))))
74 }
75
76 return &Writer{
77 CompressionLevel: level,
78 ctx: ctx,
79 dict: dict,
80 dstBuffer: make([]byte, CompressBound(1024)),
81 firstError: err,
82 underlyingWriter: w,
83 }
84}
85
86// Write writes a compressed form of p to the underlying io.Writer.
87func (w *Writer) Write(p []byte) (int, error) {
88 if w.firstError != nil {
89 return 0, w.firstError
90 }
91 if len(p) == 0 {
92 return 0, nil
93 }
94 // Check if dstBuffer is enough
95 if len(w.dstBuffer) < CompressBound(len(p)) {
96 w.dstBuffer = make([]byte, CompressBound(len(p)))
97 }
98
99 retCode := C.ZSTD_compressContinue(
100 w.ctx,
101 unsafe.Pointer(&w.dstBuffer[0]),
102 C.size_t(len(w.dstBuffer)),
103 unsafe.Pointer(&p[0]),
104 C.size_t(len(p)))
105
106 if err := getError(int(retCode)); err != nil {
107 return 0, err
108 }
109 written := int(retCode)
110
111 // Write to underlying buffer
112 _, err := w.underlyingWriter.Write(w.dstBuffer[:written])
113
114 // Same behaviour as zlib, we can't know how much data we wrote, only
115 // if there was an error
116 if err != nil {
117 return 0, err
118 }
119 return len(p), err
120}
121
122// Close closes the Writer, flushing any unwritten data to the underlying
123// io.Writer and freeing objects, but does not close the underlying io.Writer.
124func (w *Writer) Close() error {
125 retCode := C.ZSTD_compressEnd(
126 w.ctx,
127 unsafe.Pointer(&w.dstBuffer[0]),
128 C.size_t(len(w.dstBuffer)),
129 unsafe.Pointer(nil),
130 C.size_t(0))
131
132 if err := getError(int(retCode)); err != nil {
133 return err
134 }
135 written := int(retCode)
136 retCode = C.ZSTD_freeCCtx(w.ctx) // Safely close buffer before writing the end
137
138 if err := getError(int(retCode)); err != nil {
139 return err
140 }
141
142 _, err := w.underlyingWriter.Write(w.dstBuffer[:written])
143 if err != nil {
144 return err
145 }
146 return nil
147}
148
149// reader is an io.ReadCloser that decompresses when read from.
150type reader struct {
151 ctx *C.ZBUFF_DCtx
152 compressionBuffer []byte
153 compressionLeft int
154 decompressionBuffer []byte
155 decompOff int
156 decompSize int
157 dict []byte
158 firstError error
159 recommendedSrcSize int
160 underlyingReader io.Reader
161}
162
163// NewReader creates a new io.ReadCloser. Reads from the returned ReadCloser
164// read and decompress data from r. It is the caller's responsibility to call
165// Close on the ReadCloser when done. If this is not done, underlying objects
166// in the zstd library will not be freed.
167func NewReader(r io.Reader) io.ReadCloser {
168 return NewReaderDict(r, nil)
169}
170
171// NewReaderDict is like NewReader but uses a preset dictionary. NewReaderDict
172// ignores the dictionary if it is nil.
173func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {
174 var err error
175 ctx := C.ZBUFF_createDCtx()
176 if len(dict) == 0 {
177 err = getError(int(C.ZBUFF_decompressInit(ctx)))
178 } else {
179 err = getError(int(C.ZBUFF_decompressInitDictionary(
180 ctx,
181 unsafe.Pointer(&dict[0]),
182 C.size_t(len(dict)))))
183 }
184 cSize := int(C.ZBUFF_recommendedDInSize())
185 dSize := int(C.ZBUFF_recommendedDOutSize())
186 if cSize <= 0 {
187 panic(fmt.Errorf("ZBUFF_recommendedDInSize() returned invalid size: %v", cSize))
188 }
189 if dSize <= 0 {
190 panic(fmt.Errorf("ZBUFF_recommendedDOutSize() returned invalid size: %v", dSize))
191 }
192
193 compressionBuffer := make([]byte, cSize)
194 decompressionBuffer := make([]byte, dSize)
195 return &reader{
196 ctx: ctx,
197 dict: dict,
198 compressionBuffer: compressionBuffer,
199 decompressionBuffer: decompressionBuffer,
200 firstError: err,
201 recommendedSrcSize: cSize,
202 underlyingReader: r,
203 }
204}
205
206// Close frees the allocated C objects
207func (r *reader) Close() error {
208 return getError(int(C.ZBUFF_freeDCtx(r.ctx)))
209}
210
211func (r *reader) Read(p []byte) (int, error) {
212
213 // If we already have enough bytes, return
214 if r.decompSize-r.decompOff >= len(p) {
215 copy(p, r.decompressionBuffer[r.decompOff:])
216 r.decompOff += len(p)
217 return len(p), nil
218 }
219
220 copy(p, r.decompressionBuffer[r.decompOff:r.decompSize])
221 got := r.decompSize - r.decompOff
222 r.decompSize = 0
223 r.decompOff = 0
224
225 for got < len(p) {
226 // Populate src
227 src := r.compressionBuffer
228 reader := r.underlyingReader
229 n, err := TryReadFull(reader, src[r.compressionLeft:])
230 if err != nil && err != errShortRead { // Handle underlying reader errors first
231 return 0, fmt.Errorf("failed to read from underlying reader: %s", err)
232 } else if n == 0 && r.compressionLeft == 0 {
233 return got, io.EOF
234 }
235 src = src[:r.compressionLeft+n]
236
237 // C code
238 cSrcSize := C.size_t(len(src))
239 cDstSize := C.size_t(len(r.decompressionBuffer))
240 retCode := int(C.ZBUFF_decompressContinue(
241 r.ctx,
242 unsafe.Pointer(&r.decompressionBuffer[0]),
243 &cDstSize,
244 unsafe.Pointer(&src[0]),
245 &cSrcSize))
246
247 // Keep src here eventhough, we reuse later, the code might be deleted at some point
248 runtime.KeepAlive(src)
249 if err = getError(retCode); err != nil {
250 return 0, fmt.Errorf("failed to decompress: %s", err)
251 }
252
253 // Put everything in buffer
254 if int(cSrcSize) < len(src) {
255 left := src[int(cSrcSize):]
256 copy(r.compressionBuffer, left)
257 }
258 r.compressionLeft = len(src) - int(cSrcSize)
259 r.decompSize = int(cDstSize)
260 r.decompOff = copy(p[got:], r.decompressionBuffer[:r.decompSize])
261 got += r.decompOff
262
263 // Resize buffers
264 nsize := retCode // Hint for next src buffer size
265 if nsize <= 0 {
266 // Reset to recommended size
267 nsize = r.recommendedSrcSize
268 }
269 if nsize < r.compressionLeft {
270 nsize = r.compressionLeft
271 }
272 r.compressionBuffer = resize(r.compressionBuffer, nsize)
273 }
274 return got, nil
275}
276
277// TryReadFull reads buffer just as ReadFull does
278// Here we expect that buffer may end and we do not return ErrUnexpectedEOF as ReadAtLeast does.
279// We return errShortRead instead to distinguish short reads and failures.
280// We cannot use ReadFull/ReadAtLeast because it masks Reader errors, such as network failures
281// and causes panic instead of error.
282func TryReadFull(r io.Reader, buf []byte) (n int, err error) {
283 for n < len(buf) && err == nil {
284 var nn int
285 nn, err = r.Read(buf[n:])
286 n += nn
287 }
288 if n == len(buf) && err == io.EOF {
289 err = nil // EOF at the end is somewhat expected
290 } else if err == io.EOF {
291 err = errShortRead
292 }
293 return
294}