blob: b6af4eb19df67466f73cc354f0b8a7559431cdfc [file] [log] [blame]
khenaidooac637102019-01-14 15:44:34 -05001package zstd
2
3/*
4#define ZSTD_STATIC_LINKING_ONLY
5#include "zstd.h"
6#include "stdint.h" // for uintptr_t
7
// The *_wrapper functions below exist to avoid superfluous memory
// allocations when the wrapped zstd functions are called from Go code.
// See https://github.com/golang/go/issues/24450 for details.
//
// Buffer addresses arrive as uintptr_t integers and are cast back to
// pointers here, so the Go caller must keep the buffers alive for the call.

static size_t ZSTD_compress_wrapper(uintptr_t dstAddr, size_t maxDstSize,
	const uintptr_t srcAddr, size_t srcSize,
	int compressionLevel) {
	void *out = (void *)dstAddr;
	const void *in = (const void *)srcAddr;
	return ZSTD_compress(out, maxDstSize, in, srcSize, compressionLevel);
}
15
// Casts the uintptr_t buffer addresses back to pointers and forwards to
// ZSTD_decompress, returning its result unchanged.
static size_t ZSTD_decompress_wrapper(uintptr_t dstAddr, size_t maxDstSize,
	uintptr_t srcAddr, size_t srcSize) {
	void *out = (void *)dstAddr;
	const void *in = (const void *)srcAddr;
	return ZSTD_decompress(out, maxDstSize, in, srcSize);
}
19
20*/
21import "C"
22import (
23 "bytes"
24 "errors"
25 "io/ioutil"
Scott Bakerbeb3cfa2019-10-01 14:44:30 -070026 "runtime"
khenaidooac637102019-01-14 15:44:34 -050027 "unsafe"
28)
29
// Compression levels for Compress/CompressLevel, intended to mirror the best
// and standard values of the zstd command-line tool.
const (
	// BestSpeed is the fastest compression level defined here.
	BestSpeed = 1
	// DefaultCompression is the level Compress uses.
	DefaultCompression = 5
	// BestCompression is the strongest level offered by these constants.
	BestCompression = 20
)

// ErrEmptySlice is returned when the input slice is empty, e.g. by
// Decompress when src has no bytes.
var ErrEmptySlice = errors.New("Bytes slice is empty")
41
// CompressBound reports the worst-case compressed size for srcSize bytes of
// input. Use it to size a destination buffer up front or to pick a suitably
// large buffer from a pool. The formula mirrors the ZSTD_COMPRESSBOUND macro
// in zstd.h: srcSize + srcSize/256, plus extra slack for inputs below 128 KiB.
func CompressBound(srcSize int) int {
	const smallInputLimit = 128 << 10 // 128 KiB

	bound := srcSize + srcSize>>8
	if srcSize < smallInputLimit {
		// Small inputs get proportionally more headroom.
		bound += (smallInputLimit - srcSize) >> 11
	}
	return bound
}
54
55// cCompressBound is a cgo call to check the go implementation above against the c code.
56func cCompressBound(srcSize int) int {
57 return int(C.ZSTD_compressBound(C.size_t(srcSize)))
58}
59
60// Compress src into dst. If you have a buffer to use, you can pass it to
61// prevent allocation. If it is too small, or if nil is passed, a new buffer
62// will be allocated and returned.
63func Compress(dst, src []byte) ([]byte, error) {
64 return CompressLevel(dst, src, DefaultCompression)
65}
66
67// CompressLevel is the same as Compress but you can pass a compression level
68func CompressLevel(dst, src []byte, level int) ([]byte, error) {
69 bound := CompressBound(len(src))
70 if cap(dst) >= bound {
71 dst = dst[0:bound] // Reuse dst buffer
72 } else {
73 dst = make([]byte, bound)
74 }
75
76 srcPtr := C.uintptr_t(uintptr(0)) // Do not point anywhere, if src is empty
77 if len(src) > 0 {
78 srcPtr = C.uintptr_t(uintptr(unsafe.Pointer(&src[0])))
79 }
80
81 cWritten := C.ZSTD_compress_wrapper(
82 C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))),
83 C.size_t(len(dst)),
84 srcPtr,
85 C.size_t(len(src)),
86 C.int(level))
87
Scott Bakerbeb3cfa2019-10-01 14:44:30 -070088 runtime.KeepAlive(src)
khenaidooac637102019-01-14 15:44:34 -050089 written := int(cWritten)
90 // Check if the return is an Error code
91 if err := getError(written); err != nil {
92 return nil, err
93 }
94 return dst[:written], nil
95}
96
97// Decompress src into dst. If you have a buffer to use, you can pass it to
98// prevent allocation. If it is too small, or if nil is passed, a new buffer
99// will be allocated and returned.
100func Decompress(dst, src []byte) ([]byte, error) {
101 if len(src) == 0 {
102 return []byte{}, ErrEmptySlice
103 }
104 decompress := func(dst, src []byte) ([]byte, error) {
105
106 cWritten := C.ZSTD_decompress_wrapper(
107 C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))),
108 C.size_t(len(dst)),
109 C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))),
110 C.size_t(len(src)))
111
Scott Bakerbeb3cfa2019-10-01 14:44:30 -0700112 runtime.KeepAlive(src)
khenaidooac637102019-01-14 15:44:34 -0500113 written := int(cWritten)
114 // Check error
115 if err := getError(written); err != nil {
116 return nil, err
117 }
118 return dst[:written], nil
119 }
120
Scott Bakerbeb3cfa2019-10-01 14:44:30 -0700121 if len(dst) == 0 {
khenaidooac637102019-01-14 15:44:34 -0500122 // Attempt to use zStd to determine decompressed size (may result in error or 0)
123 size := int(C.size_t(C.ZSTD_getDecompressedSize(unsafe.Pointer(&src[0]), C.size_t(len(src)))))
124
125 if err := getError(size); err != nil {
126 return nil, err
127 }
128
129 if size > 0 {
130 dst = make([]byte, size)
131 } else {
132 dst = make([]byte, len(src)*3) // starting guess
133 }
134 }
135 for i := 0; i < 3; i++ { // 3 tries to allocate a bigger buffer
136 result, err := decompress(dst, src)
137 if !IsDstSizeTooSmallError(err) {
138 return result, err
139 }
140 dst = make([]byte, len(dst)*2) // Grow buffer by 2
141 }
142
143 // We failed getting a dst buffer of correct size, use stream API
144 r := NewReader(bytes.NewReader(src))
145 defer r.Close()
146 return ioutil.ReadAll(r)
147}