blob: 5ec2ea078b3bdc8b07739316ae06638c253359d6 [file] [log] [blame]
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package compressor
8
9import (
10 "bytes"
11 "compress/zlib"
12
13 "io"
14
15 "github.com/golang/snappy"
16 "github.com/mongodb/mongo-go-driver/x/network/wiremessage"
17)
18
19// Compressor is the interface implemented by types that can compress and decompress wire messages. This is used
20// when sending and receiving messages to and from the server.
21type Compressor interface {
22 CompressBytes(src, dest []byte) ([]byte, error)
23 UncompressBytes(src, dest []byte) ([]byte, error)
24 CompressorID() wiremessage.CompressorID
25 Name() string
26}
27
// writer is an io.Writer that accumulates everything written to it in buf.
// Seeding buf with a zero-length slice that has spare capacity lets that
// slice's backing array be filled before any new allocation happens.
type writer struct {
	buf []byte
}

// Write appends p to the writer's buffer. It always consumes all of p and
// never returns an error.
func (w *writer) Write(p []byte) (n int, err error) {
	// The built-in append reuses spare capacity when p fits and otherwise
	// grows the backing array with amortized cost, which replaces the
	// hand-rolled "allocate 2*cap+len(p), copy, re-slice" sequence.
	w.buf = append(w.buf, p...)
	return len(p), nil
}
45
// SnappyCompressor compresses and decompresses data with the snappy method.
// It has no fields.
type SnappyCompressor struct{}
49
// ZlibCompressor compresses and decompresses data with the zlib method.
type ZlibCompressor struct {
	// level is the compression level the zlib writer was created with.
	level int
	// zlibWriter is the writer that CompressBytes resets and reuses on
	// every call.
	zlibWriter *zlib.Writer
}
55
56// CompressBytes uses snappy to compress a slice of bytes.
57func (s *SnappyCompressor) CompressBytes(src, dest []byte) ([]byte, error) {
58 dest = dest[:0]
59 dest = snappy.Encode(dest, src)
60 return dest, nil
61}
62
63// UncompressBytes uses snappy to uncompress a slice of bytes.
64func (s *SnappyCompressor) UncompressBytes(src, dest []byte) ([]byte, error) {
65 var err error
66 dest, err = snappy.Decode(dest, src)
67 if err != nil {
68 return dest, err
69 }
70
71 return dest, nil
72}
73
74// CompressorID returns the ID for the snappy compressor.
75func (s *SnappyCompressor) CompressorID() wiremessage.CompressorID {
76 return wiremessage.CompressorSnappy
77}
78
79// Name returns the string name for the snappy compressor.
80func (s *SnappyCompressor) Name() string {
81 return "snappy"
82}
83
84// CompressBytes uses zlib to compress a slice of bytes.
85func (z *ZlibCompressor) CompressBytes(src, dest []byte) ([]byte, error) {
86 dest = dest[:0]
87 z.zlibWriter.Reset(&writer{
88 buf: dest,
89 })
90
91 _, err := z.zlibWriter.Write(src)
92 if err != nil {
93 _ = z.zlibWriter.Close()
94 return dest, err
95 }
96
97 err = z.zlibWriter.Close()
98 if err != nil {
99 return dest, err
100 }
101 return dest, nil
102}
103
104// UncompressBytes uses zlib to uncompress a slice of bytes. It assumes dest is empty and is the exact size that it
105// needs to be.
106func (z *ZlibCompressor) UncompressBytes(src, dest []byte) ([]byte, error) {
107 reader := bytes.NewReader(src)
108 zlibReader, err := zlib.NewReader(reader)
109
110 if err != nil {
111 return dest, err
112 }
113 defer func() {
114 _ = zlibReader.Close()
115 }()
116
117 _, err = io.ReadFull(zlibReader, dest)
118 if err != nil {
119 return dest, err
120 }
121
122 return dest, nil
123}
124
125// CompressorID returns the ID for the zlib compressor.
126func (z *ZlibCompressor) CompressorID() wiremessage.CompressorID {
127 return wiremessage.CompressorZLib
128}
129
130// Name returns the name for the zlib compressor.
131func (z *ZlibCompressor) Name() string {
132 return "zlib"
133}
134
135// CreateSnappy creates a snappy compressor
136func CreateSnappy() Compressor {
137 return &SnappyCompressor{}
138}
139
140// CreateZlib creates a zlib compressor
141func CreateZlib(level int) (Compressor, error) {
142 if level < 0 {
143 level = wiremessage.DefaultZlibLevel
144 }
145
146 var compressBuf bytes.Buffer
147 zlibWriter, err := zlib.NewWriterLevel(&compressBuf, level)
148
149 if err != nil {
150 return &ZlibCompressor{}, err
151 }
152
153 return &ZlibCompressor{
154 level: level,
155 zlibWriter: zlibWriter,
156 }, nil
157}