package zstd

/*
#define ZSTD_STATIC_LINKING_ONLY
#define ZBUFF_DISABLE_DEPRECATE_WARNINGS
#include "zstd.h"
#include "zbuff.h"
*/
import "C"
import (
	"fmt"
	"io"
	"unsafe"
)

// Writer is an io.WriteCloser that zstd-compresses its input.
type Writer struct {
	CompressionLevel int

	ctx              *C.ZSTD_CCtx
	dict             []byte
	dstBuffer        []byte
	firstError       error
	underlyingWriter io.Writer
}

// resize returns a slice of length newSize, reusing the backing array of in
// when it is large enough and allocating or growing it otherwise.
func resize(in []byte, newSize int) []byte {
	if in == nil {
		return make([]byte, newSize)
	}
	if newSize <= cap(in) {
		return in[:newSize]
	}
	toAdd := newSize - len(in)
	return append(in, make([]byte, toAdd)...)
}

// NewWriter creates a new Writer with default compression options. Writes to
// the writer will be written in compressed form to w.
func NewWriter(w io.Writer) *Writer {
	return NewWriterLevelDict(w, DefaultCompression, nil)
}
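
// The sketch below is illustrative only and not part of the original file:
// it shows the intended NewWriter usage, streaming src into dst and calling
// Close so the final zstd frame is flushed. The helper name compressStream
// is an assumption made for the example.
func compressStream(dst io.Writer, src io.Reader) error {
	w := NewWriter(dst)
	if _, err := io.Copy(w, src); err != nil {
		w.Close() // best-effort cleanup; the copy error takes precedence
		return err
	}
	// Close writes the end-of-frame data and frees the C context.
	return w.Close()
}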

// NewWriterLevel is like NewWriter but specifies the compression level instead
// of assuming default compression.
//
// The level can be DefaultCompression or any integer value between BestSpeed
// and BestCompression inclusive.
func NewWriterLevel(w io.Writer, level int) *Writer {
	return NewWriterLevelDict(w, level, nil)
}

// NewWriterLevelDict is like NewWriterLevel but specifies a dictionary to
// compress with. If the dictionary is empty or nil it is ignored. The dictionary
// should not be modified until the writer is closed.
func NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer {
	var err error
	ctx := C.ZSTD_createCCtx()

	if len(dict) == 0 {
		err = getError(int(C.ZSTD_compressBegin(ctx,
			C.int(level))))
	} else {
		err = getError(int(C.ZSTD_compressBegin_usingDict(
			ctx,
			unsafe.Pointer(&dict[0]),
			C.size_t(len(dict)),
			C.int(level))))
	}

	return &Writer{
		CompressionLevel: level,
		ctx:              ctx,
		dict:             dict,
		dstBuffer:        make([]byte, CompressBound(1024)),
		firstError:       err,
		underlyingWriter: w,
	}
}
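
// The sketch below is illustrative only and not part of the original file:
// it compresses src with a preset dictionary at an explicit level. The helper
// name compressStreamDict is an assumption; the same dictionary must be used
// when decompressing.
func compressStreamDict(dst io.Writer, src io.Reader, dict []byte, level int) error {
	w := NewWriterLevelDict(dst, level, dict)
	if _, err := io.Copy(w, src); err != nil {
		w.Close()
		return err
	}
	return w.Close()
}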

// Write writes a compressed form of p to the underlying io.Writer.
func (w *Writer) Write(p []byte) (int, error) {
	if w.firstError != nil {
		return 0, w.firstError
	}
	if len(p) == 0 {
		return 0, nil
	}
	// Check that dstBuffer is large enough for the worst-case compressed size
	if len(w.dstBuffer) < CompressBound(len(p)) {
		w.dstBuffer = make([]byte, CompressBound(len(p)))
	}

	retCode := C.ZSTD_compressContinue(
		w.ctx,
		unsafe.Pointer(&w.dstBuffer[0]),
		C.size_t(len(w.dstBuffer)),
		unsafe.Pointer(&p[0]),
		C.size_t(len(p)))

	if err := getError(int(retCode)); err != nil {
		return 0, err
	}
	written := int(retCode)

	// Write the compressed data to the underlying writer
	_, err := w.underlyingWriter.Write(w.dstBuffer[:written])

	// Same behaviour as zlib: we cannot report how much input was consumed,
	// only whether an error occurred
	if err != nil {
		return 0, err
	}
	return len(p), nil
}

// Close closes the Writer, flushing any unwritten data to the underlying
// io.Writer and freeing objects, but does not close the underlying io.Writer.
func (w *Writer) Close() error {
	retCode := C.ZSTD_compressEnd(
		w.ctx,
		unsafe.Pointer(&w.dstBuffer[0]),
		C.size_t(len(w.dstBuffer)),
		unsafe.Pointer(nil),
		C.size_t(0))

	if err := getError(int(retCode)); err != nil {
		return err
	}
	written := int(retCode)
	retCode = C.ZSTD_freeCCtx(w.ctx) // Free the C context; the epilogue is already in dstBuffer

	if err := getError(int(retCode)); err != nil {
		return err
	}

	_, err := w.underlyingWriter.Write(w.dstBuffer[:written])
	if err != nil {
		return err
	}
	return nil
}

// reader is an io.ReadCloser that decompresses when read from.
type reader struct {
	ctx                 *C.ZBUFF_DCtx
	compressionBuffer   []byte
	compressionLeft     int
	decompressionBuffer []byte
	decompOff           int
	decompSize          int
	dict                []byte
	firstError          error
	recommendedSrcSize  int
	underlyingReader    io.Reader
}

// NewReader creates a new io.ReadCloser. Reads from the returned ReadCloser
// read and decompress data from r. It is the caller's responsibility to call
// Close on the ReadCloser when done. If this is not done, underlying objects
// in the zstd library will not be freed.
func NewReader(r io.Reader) io.ReadCloser {
	return NewReaderDict(r, nil)
}
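
// The sketch below is illustrative only and not part of the original file:
// it shows the intended NewReader usage, copying the decompressed stream into
// dst and closing the reader so the C context is freed. The helper name
// decompressStream is an assumption made for the example.
func decompressStream(dst io.Writer, src io.Reader) error {
	r := NewReader(src)
	defer r.Close()
	_, err := io.Copy(dst, r)
	return err
}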

// NewReaderDict is like NewReader but uses a preset dictionary. NewReaderDict
// ignores the dictionary if it is nil.
func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {
	var err error
	ctx := C.ZBUFF_createDCtx()
	if len(dict) == 0 {
		err = getError(int(C.ZBUFF_decompressInit(ctx)))
	} else {
		err = getError(int(C.ZBUFF_decompressInitDictionary(
			ctx,
			unsafe.Pointer(&dict[0]),
			C.size_t(len(dict)))))
	}
	cSize := int(C.ZBUFF_recommendedDInSize())
	dSize := int(C.ZBUFF_recommendedDOutSize())
	if cSize <= 0 {
		panic(fmt.Errorf("ZBUFF_recommendedDInSize() returned invalid size: %v", cSize))
	}
	if dSize <= 0 {
		panic(fmt.Errorf("ZBUFF_recommendedDOutSize() returned invalid size: %v", dSize))
	}

	compressionBuffer := make([]byte, cSize)
	decompressionBuffer := make([]byte, dSize)
	return &reader{
		ctx:                 ctx,
		dict:                dict,
		compressionBuffer:   compressionBuffer,
		decompressionBuffer: decompressionBuffer,
		firstError:          err,
		recommendedSrcSize:  cSize,
		underlyingReader:    r,
	}
}
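
// The sketch below is illustrative only and not part of the original file:
// it pairs with compressStreamDict above, decompressing a stream that was
// produced with the same dictionary. The helper name decompressStreamDict is
// an assumption made for the example.
func decompressStreamDict(dst io.Writer, src io.Reader, dict []byte) error {
	r := NewReaderDict(src, dict)
	defer r.Close()
	_, err := io.Copy(dst, r)
	return err
}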

// Close frees the allocated C objects
func (r *reader) Close() error {
	return getError(int(C.ZBUFF_freeDCtx(r.ctx)))
}

func (r *reader) Read(p []byte) (int, error) {

	// If the decompression buffer already holds enough bytes, serve p from it
	if r.decompSize-r.decompOff >= len(p) {
		copy(p, r.decompressionBuffer[r.decompOff:])
		r.decompOff += len(p)
		return len(p), nil
	}

	copy(p, r.decompressionBuffer[r.decompOff:r.decompSize])
	got := r.decompSize - r.decompOff
	r.decompSize = 0
	r.decompOff = 0

	for got < len(p) {
		// Populate src with compressed data from the underlying reader
		src := r.compressionBuffer
		reader := r.underlyingReader
		n, err := io.ReadFull(reader, src[r.compressionLeft:])
		if err == io.EOF && r.compressionLeft == 0 {
			return got, io.EOF
		} else if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
			return 0, fmt.Errorf("failed to read from underlying reader: %s", err)
		}
		src = src[:r.compressionLeft+n]

		// Decompress src into the decompression buffer
		cSrcSize := C.size_t(len(src))
		cDstSize := C.size_t(len(r.decompressionBuffer))
		retCode := int(C.ZBUFF_decompressContinue(
			r.ctx,
			unsafe.Pointer(&r.decompressionBuffer[0]),
			&cDstSize,
			unsafe.Pointer(&src[0]),
			&cSrcSize))

		if err = getError(retCode); err != nil {
			return 0, fmt.Errorf("failed to decompress: %s", err)
		}

		// Keep any compressed bytes that were not consumed for the next pass
		if int(cSrcSize) < len(src) {
			left := src[int(cSrcSize):]
			copy(r.compressionBuffer, left)
		}
		r.compressionLeft = len(src) - int(cSrcSize)
		r.decompSize = int(cDstSize)
		r.decompOff = copy(p[got:], r.decompressionBuffer[:r.decompSize])
		got += r.decompOff

		// Resize the compression buffer using zstd's hint for the next read
		nsize := retCode // Hint for next src buffer size
		if nsize <= 0 {
			// Reset to the recommended size
			nsize = r.recommendedSrcSize
		}
		if nsize < r.compressionLeft {
			nsize = r.compressionLeft
		}
		r.compressionBuffer = resize(r.compressionBuffer, nsize)
	}
	return got, nil
}
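
// The sketch below is illustrative only and not part of the original file:
// it drains a compressed stream in fixed-size chunks, exercising the internal
// buffering of Read above. The helper name readDecompressedChunks and the
// 64 KiB chunk size are assumptions made for the example.
func readDecompressedChunks(src io.Reader, handle func([]byte) error) error {
	r := NewReader(src)
	defer r.Close()
	buf := make([]byte, 64*1024)
	for {
		n, err := r.Read(buf)
		if n > 0 {
			if herr := handle(buf[:n]); herr != nil {
				return herr
			}
		}
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}
	}
}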