blob: e9953d6cc3a4785fed7cfbb3616a15834bd2ea36 [file] [log] [blame]
William Kurkianea869482019-04-09 15:16:11 -04001package zstd
2
/*
#define ZSTD_STATIC_LINKING_ONLY
#include "zstd.h"
#include "stdint.h" // for uintptr_t

// The *_wrapper functions below receive buffer addresses as uintptr_t instead
// of pointers. Calling them from Go instead of the zstd functions directly
// avoids superfluous memory allocations on the Go side.
// See https://github.com/golang/go/issues/24450 for details.

static size_t ZSTD_compress_wrapper(uintptr_t dst, size_t maxDstSize, const uintptr_t src, size_t srcSize, int compressionLevel) {
	void *out = (void *)dst;
	const void *in = (const void *)src;
	return ZSTD_compress(out, maxDstSize, in, srcSize, compressionLevel);
}

static size_t ZSTD_decompress_wrapper(uintptr_t dst, size_t maxDstSize, uintptr_t src, size_t srcSize) {
	void *out = (void *)dst;
	const void *in = (const void *)src;
	return ZSTD_decompress(out, maxDstSize, in, srcSize);
}
*/
import "C"
import (
	"bytes"
	"errors"
	"io/ioutil"
	"runtime"
	"unsafe"
)

// Compression levels that can be passed to CompressLevel.
const (
	// BestSpeed is the fastest compression level offered by this package.
	BestSpeed = 1
	// BestCompression is the strongest compression level offered by this package.
	BestCompression = 20
	// DefaultCompression is the level used by Compress.
	DefaultCompression = 5
)

// ErrEmptySlice reports that an input byte slice was empty; for example,
// Decompress returns it when src contains no bytes.
var ErrEmptySlice = errors.New("Bytes slice is empty")

// CompressBound reports the largest size that compressing srcSize bytes can
// produce, i.e. a destination capacity that is always sufficient. Use it to
// preallocate a destination buffer or to pick a suitable one from a pool.
// It is a Go port of the ZSTD_COMPRESSBOUND macro in zstd.h.
func CompressBound(srcSize int) int {
	const smallInputLimit = 128 << 10 // 128 kB

	bound := srcSize + srcSize>>8
	if srcSize < smallInputLimit {
		// Inputs below the limit get extra headroom that shrinks as they grow.
		bound += (smallInputLimit - srcSize) >> 11
	}
	return bound
}

54// cCompressBound is a cgo call to check the go implementation above against the c code.
55func cCompressBound(srcSize int) int {
56 return int(C.ZSTD_compressBound(C.size_t(srcSize)))
57}
58
59// Compress src into dst. If you have a buffer to use, you can pass it to
60// prevent allocation. If it is too small, or if nil is passed, a new buffer
61// will be allocated and returned.
62func Compress(dst, src []byte) ([]byte, error) {
63 return CompressLevel(dst, src, DefaultCompression)
64}
65
66// CompressLevel is the same as Compress but you can pass a compression level
67func CompressLevel(dst, src []byte, level int) ([]byte, error) {
68 bound := CompressBound(len(src))
69 if cap(dst) >= bound {
70 dst = dst[0:bound] // Reuse dst buffer
71 } else {
72 dst = make([]byte, bound)
73 }
74
75 srcPtr := C.uintptr_t(uintptr(0)) // Do not point anywhere, if src is empty
76 if len(src) > 0 {
77 srcPtr = C.uintptr_t(uintptr(unsafe.Pointer(&src[0])))
78 }
79
80 cWritten := C.ZSTD_compress_wrapper(
81 C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))),
82 C.size_t(len(dst)),
83 srcPtr,
84 C.size_t(len(src)),
85 C.int(level))
86
87 written := int(cWritten)
88 // Check if the return is an Error code
89 if err := getError(written); err != nil {
90 return nil, err
91 }
92 return dst[:written], nil
93}
94
95// Decompress src into dst. If you have a buffer to use, you can pass it to
96// prevent allocation. If it is too small, or if nil is passed, a new buffer
97// will be allocated and returned.
98func Decompress(dst, src []byte) ([]byte, error) {
99 if len(src) == 0 {
100 return []byte{}, ErrEmptySlice
101 }
102 decompress := func(dst, src []byte) ([]byte, error) {
103
104 cWritten := C.ZSTD_decompress_wrapper(
105 C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))),
106 C.size_t(len(dst)),
107 C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))),
108 C.size_t(len(src)))
109
110 written := int(cWritten)
111 // Check error
112 if err := getError(written); err != nil {
113 return nil, err
114 }
115 return dst[:written], nil
116 }
117
118 if dst == nil {
119 // Attempt to use zStd to determine decompressed size (may result in error or 0)
120 size := int(C.size_t(C.ZSTD_getDecompressedSize(unsafe.Pointer(&src[0]), C.size_t(len(src)))))
121
122 if err := getError(size); err != nil {
123 return nil, err
124 }
125
126 if size > 0 {
127 dst = make([]byte, size)
128 } else {
129 dst = make([]byte, len(src)*3) // starting guess
130 }
131 }
132 for i := 0; i < 3; i++ { // 3 tries to allocate a bigger buffer
133 result, err := decompress(dst, src)
134 if !IsDstSizeTooSmallError(err) {
135 return result, err
136 }
137 dst = make([]byte, len(dst)*2) // Grow buffer by 2
138 }
139
140 // We failed getting a dst buffer of correct size, use stream API
141 r := NewReader(bytes.NewReader(src))
142 defer r.Close()
143 return ioutil.ReadAll(r)
144}