blob: ac3839f48d7f3e8933209c66617b455ca90c2762 [file] [log] [blame]
Matteo Scandolo9a2772a2018-11-19 14:56:26 -08001package zstd
2
3/*
4#define ZSTD_STATIC_LINKING_ONLY
5#include "zstd.h"
6#include "stdint.h" // for uintptr_t
7
// The following *_wrapper functions are used to avoid superfluous
// memory allocations when calling the wrapped functions from Go code.
// See https://github.com/golang/go/issues/24450 for details.
11
12static size_t ZSTD_compress_wrapper(uintptr_t dst, size_t maxDstSize, const uintptr_t src, size_t srcSize, int compressionLevel) {
13 return ZSTD_compress((void*)dst, maxDstSize, (const void*)src, srcSize, compressionLevel);
14}
15
16static size_t ZSTD_decompress_wrapper(uintptr_t dst, size_t maxDstSize, uintptr_t src, size_t srcSize) {
17 return ZSTD_decompress((void*)dst, maxDstSize, (const void *)src, srcSize);
18}
19
20*/
21import "C"
import (
	"bytes"
	"errors"
	"io/ioutil"
	"runtime"
	"unsafe"
)
28
// Predefined compression levels, modelled on the values used by the zstd CLI.
const (
	// BestSpeed is the level to use when speed matters most.
	BestSpeed = 1
	// BestCompression is the level to use when output size matters most.
	BestCompression = 20
	// DefaultCompression is the level Compress uses.
	DefaultCompression = 5
)
35
// ErrEmptySlice is returned when an input slice is empty, so there is
// nothing to process.
var ErrEmptySlice = errors.New("Bytes slice is empty")
40
// CompressBound returns the worst-case compressed size for srcSize input
// bytes. Use it to size a destination buffer up front, or to pick a large
// enough buffer from a pool.
//
// The formula mirrors the ZSTD_COMPRESSBOUND macro in zstd.h:
// srcSize + srcSize>>8, plus extra headroom for inputs below 128 KiB.
func CompressBound(srcSize int) int {
	const smallInputLimit = 128 << 10 // 128 KiB

	bound := srcSize + srcSize>>8
	if srcSize < smallInputLimit {
		// Small inputs get additional room that shrinks linearly to zero
		// as srcSize approaches 128 KiB.
		bound += (smallInputLimit - srcSize) >> 11
	}
	return bound
}
53
54// cCompressBound is a cgo call to check the go implementation above against the c code.
55func cCompressBound(srcSize int) int {
56 return int(C.ZSTD_compressBound(C.size_t(srcSize)))
57}
58
59// Compress src into dst. If you have a buffer to use, you can pass it to
60// prevent allocation. If it is too small, or if nil is passed, a new buffer
61// will be allocated and returned.
62func Compress(dst, src []byte) ([]byte, error) {
63 return CompressLevel(dst, src, DefaultCompression)
64}
65
66// CompressLevel is the same as Compress but you can pass a compression level
67func CompressLevel(dst, src []byte, level int) ([]byte, error) {
68 if len(src) == 0 {
69 return []byte{}, ErrEmptySlice
70 }
71 bound := CompressBound(len(src))
72 if cap(dst) >= bound {
73 dst = dst[0:bound] // Reuse dst buffer
74 } else {
75 dst = make([]byte, bound)
76 }
77
78 cWritten := C.ZSTD_compress_wrapper(
79 C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))),
80 C.size_t(len(dst)),
81 C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))),
82 C.size_t(len(src)),
83 C.int(level))
84
85 written := int(cWritten)
86 // Check if the return is an Error code
87 if err := getError(written); err != nil {
88 return nil, err
89 }
90 return dst[:written], nil
91}
92
93// Decompress src into dst. If you have a buffer to use, you can pass it to
94// prevent allocation. If it is too small, or if nil is passed, a new buffer
95// will be allocated and returned.
96func Decompress(dst, src []byte) ([]byte, error) {
97 decompress := func(dst, src []byte) ([]byte, error) {
98
99 cWritten := C.ZSTD_decompress_wrapper(
100 C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))),
101 C.size_t(len(dst)),
102 C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))),
103 C.size_t(len(src)))
104
105 written := int(cWritten)
106 // Check error
107 if err := getError(written); err != nil {
108 return nil, err
109 }
110 return dst[:written], nil
111 }
112
113 if dst == nil {
114 // Attempt to use zStd to determine decompressed size (may result in error or 0)
115 size := int(C.size_t(C.ZSTD_getDecompressedSize(unsafe.Pointer(&src[0]), C.size_t(len(src)))))
116
117 if err := getError(size); err != nil {
118 return nil, err
119 }
120
121 if size > 0 {
122 dst = make([]byte, size)
123 } else {
124 dst = make([]byte, len(src)*3) // starting guess
125 }
126 }
127 for i := 0; i < 3; i++ { // 3 tries to allocate a bigger buffer
128 result, err := decompress(dst, src)
129 if !IsDstSizeTooSmallError(err) {
130 return result, err
131 }
132 dst = make([]byte, len(dst)*2) // Grow buffer by 2
133 }
134
135 // We failed getting a dst buffer of correct size, use stream API
136 r := NewReader(bytes.NewReader(src))
137 defer r.Close()
138 return ioutil.ReadAll(r)
139}