package sarama

import (
	"sync"

	"github.com/klauspost/compress/zstd"
)

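// ZstdEncoderParams identifies a cached encoder by its compression level.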
type ZstdEncoderParams struct {
	Level int
}
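
// ZstdDecoderParams carries no options yet; every decoder uses the default configuration.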
type ZstdDecoderParams struct {
}

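// zstdEncMap and zstdDecMap cache encoders and decoders keyed by their params
// so they can be reused across calls.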
var zstdEncMap, zstdDecMap sync.Map

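// getEncoder returns a cached *zstd.Encoder for the given params, creating and
// caching a new one on first use.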
func getEncoder(params ZstdEncoderParams) *zstd.Encoder {
	if ret, ok := zstdEncMap.Load(params); ok {
		return ret.(*zstd.Encoder)
	}
	// It's possible to race and create multiple new writers.
	// Only one will survive GC after use.
	encoderLevel := zstd.SpeedDefault
	if params.Level != CompressionLevelDefault {
		encoderLevel = zstd.EncoderLevelFromZstd(params.Level)
	}
	zstdEnc, _ := zstd.NewWriter(nil, zstd.WithZeroFrames(true),
		zstd.WithEncoderLevel(encoderLevel))
	zstdEncMap.Store(params, zstdEnc)
	return zstdEnc
}

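// getDecoder returns a cached *zstd.Decoder for the given params, creating and
// caching a new one on first use.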
func getDecoder(params ZstdDecoderParams) *zstd.Decoder {
	if ret, ok := zstdDecMap.Load(params); ok {
		return ret.(*zstd.Decoder)
	}
	// It's possible to race and create multiple new readers.
	// Only one will survive GC after use.
	zstdDec, _ := zstd.NewReader(nil)
	zstdDecMap.Store(params, zstdDec)
	return zstdDec
}

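// zstdDecompress decompresses src and appends the result to dst.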
func zstdDecompress(params ZstdDecoderParams, dst, src []byte) ([]byte, error) {
	return getDecoder(params).DecodeAll(src, dst)
}

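// zstdCompress compresses src with the encoder selected by params and appends
// the result to dst. A minimal usage sketch (payload is any []byte to compress):
//
//	compressed, _ := zstdCompress(ZstdEncoderParams{Level: CompressionLevelDefault}, nil, payload)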
func zstdCompress(params ZstdEncoderParams, dst, src []byte) ([]byte, error) {
	return getEncoder(params).EncodeAll(src, dst), nil
}