blob: 9d9d1d567e6885cb37e79b4653c484b877a22d32 [file] [log] [blame]
khenaidoo106c61a2021-08-11 18:05:46 -04001// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8 "encoding/binary"
9 "errors"
10 "hash/crc32"
11 "io"
12
13 "github.com/golang/snappy"
14 "github.com/klauspost/compress/huff0"
15)
16
// Element tags of the Snappy block format. decodeSnappy selects on the low
// two bits of the first byte of every element (src[s] & 0x03) and compares
// the result against these values.
const (
	snappyTagLiteral = 0x00 // literal run; the length is carried in the upper six bits or in following bytes
	snappyTagCopy1   = 0x01 // copy with a 2-byte element (3 length bits, 11 offset bits)
	snappyTagCopy2   = 0x02 // copy with a 3-byte element (16-bit little-endian offset)
	snappyTagCopy4   = 0x03 // copy with a 5-byte element (32-bit little-endian offset)
)
23
const (
	// snappyChecksumSize is the number of checksum bytes that lead the
	// payload of compressed and uncompressed data chunks.
	snappyChecksumSize = 4
	// snappyMagicBody is the exact payload required in a stream identifier chunk.
	snappyMagicBody = "sNaPpY"

	// snappyMaxBlockSize is the maximum size of the input to encodeBlock. It is not
	// part of the wire format per se, but some parts of the encoder assume
	// that an offset fits into a uint16.
	//
	// Also, for the framing format (Writer type instead of Encode function),
	// https://github.com/google/snappy/blob/master/framing_format.txt says
	// that "the uncompressed data in a chunk must be no longer than 65536
	// bytes".
	snappyMaxBlockSize = 65536

	// snappyMaxEncodedLenOfMaxBlockSize equals MaxEncodedLen(snappyMaxBlockSize), but is
	// hard coded to be a const instead of a variable, so that obufLen can also
	// be a const. Their equivalence is confirmed by
	// TestMaxEncodedLenOfMaxBlockSize.
	//
	// Convert sizes its scratch buffer as this value plus snappyChecksumSize.
	snappyMaxEncodedLenOfMaxBlockSize = 76490
)
44
// Chunk type identifiers of the Snappy framing format; this is the first byte
// of each 4-byte chunk header read by Convert. See
// https://github.com/google/snappy/blob/master/framing_format.txt
const (
	chunkTypeCompressedData   = 0x00 // section 4.2: checksum followed by a Snappy block
	chunkTypeUncompressedData = 0x01 // section 4.3: checksum followed by raw bytes
	chunkTypePadding          = 0xfe // section 4.4: skipped
	chunkTypeStreamIdentifier = 0xff // section 4.1: must be the first chunk of a stream
)
51
var (
	// ErrSnappyCorrupt reports that the input is invalid.
	ErrSnappyCorrupt = errors.New("snappy: corrupt input")
	// ErrSnappyTooLarge reports that the uncompressed length is too large.
	ErrSnappyTooLarge = errors.New("snappy: decoded block is too large")
	// ErrSnappyUnsupported reports that the input isn't supported.
	ErrSnappyUnsupported = errors.New("snappy: unsupported input")

	// errUnsupportedLiteralLength is returned by decodeSnappy when a literal
	// length does not fit in a positive int.
	errUnsupportedLiteralLength = errors.New("snappy: unsupported literal length")
)
62
// SnappyConverter can read Snappy-compressed streams and convert them to zstd.
// Conversion is done by converting the stream directly from Snappy without intermediate
// full decoding.
// Therefore the compression ratio is much less than what can be done by a full decompression
// and compression, and a faulty Snappy stream may lead to a faulty Zstandard stream without
// any errors being generated.
// No CRC value is being generated and not all CRC values of the Snappy stream are checked.
// However, it provides really fast recompression of Snappy streams.
// The converter can be reused to avoid allocations, even after errors.
type SnappyConverter struct {
	// r is the input of the conversion currently in progress; set by Convert.
	r io.Reader
	// err holds the most recent read/write error; Convert clears it on entry.
	err error
	// buf is scratch space for chunk headers and chunk payloads. Convert
	// (re)allocates it to snappyMaxEncodedLenOfMaxBlockSize+snappyChecksumSize bytes.
	buf []byte
	// block is the block encoder, created on first use and kept for later calls.
	block *blockEnc
}
78
79// Convert the Snappy stream supplied in 'in' and write the zStandard stream to 'w'.
80// If any error is detected on the Snappy stream it is returned.
81// The number of bytes written is returned.
82func (r *SnappyConverter) Convert(in io.Reader, w io.Writer) (int64, error) {
83 initPredefined()
84 r.err = nil
85 r.r = in
86 if r.block == nil {
87 r.block = &blockEnc{}
88 r.block.init()
89 }
90 r.block.initNewEncode()
91 if len(r.buf) != snappyMaxEncodedLenOfMaxBlockSize+snappyChecksumSize {
92 r.buf = make([]byte, snappyMaxEncodedLenOfMaxBlockSize+snappyChecksumSize)
93 }
94 r.block.litEnc.Reuse = huff0.ReusePolicyNone
95 var written int64
96 var readHeader bool
97 {
98 var header []byte
99 var n int
100 header, r.err = frameHeader{WindowSize: snappyMaxBlockSize}.appendTo(r.buf[:0])
101
102 n, r.err = w.Write(header)
103 if r.err != nil {
104 return written, r.err
105 }
106 written += int64(n)
107 }
108
109 for {
110 if !r.readFull(r.buf[:4], true) {
111 // Add empty last block
112 r.block.reset(nil)
113 r.block.last = true
114 err := r.block.encodeLits(r.block.literals, false)
115 if err != nil {
116 return written, err
117 }
118 n, err := w.Write(r.block.output)
119 if err != nil {
120 return written, err
121 }
122 written += int64(n)
123
124 return written, r.err
125 }
126 chunkType := r.buf[0]
127 if !readHeader {
128 if chunkType != chunkTypeStreamIdentifier {
129 println("chunkType != chunkTypeStreamIdentifier", chunkType)
130 r.err = ErrSnappyCorrupt
131 return written, r.err
132 }
133 readHeader = true
134 }
135 chunkLen := int(r.buf[1]) | int(r.buf[2])<<8 | int(r.buf[3])<<16
136 if chunkLen > len(r.buf) {
137 println("chunkLen > len(r.buf)", chunkType)
138 r.err = ErrSnappyUnsupported
139 return written, r.err
140 }
141
142 // The chunk types are specified at
143 // https://github.com/google/snappy/blob/master/framing_format.txt
144 switch chunkType {
145 case chunkTypeCompressedData:
146 // Section 4.2. Compressed data (chunk type 0x00).
147 if chunkLen < snappyChecksumSize {
148 println("chunkLen < snappyChecksumSize", chunkLen, snappyChecksumSize)
149 r.err = ErrSnappyCorrupt
150 return written, r.err
151 }
152 buf := r.buf[:chunkLen]
153 if !r.readFull(buf, false) {
154 return written, r.err
155 }
156 //checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
157 buf = buf[snappyChecksumSize:]
158
159 n, hdr, err := snappyDecodedLen(buf)
160 if err != nil {
161 r.err = err
162 return written, r.err
163 }
164 buf = buf[hdr:]
165 if n > snappyMaxBlockSize {
166 println("n > snappyMaxBlockSize", n, snappyMaxBlockSize)
167 r.err = ErrSnappyCorrupt
168 return written, r.err
169 }
170 r.block.reset(nil)
171 r.block.pushOffsets()
172 if err := decodeSnappy(r.block, buf); err != nil {
173 r.err = err
174 return written, r.err
175 }
176 if r.block.size+r.block.extraLits != n {
177 printf("invalid size, want %d, got %d\n", n, r.block.size+r.block.extraLits)
178 r.err = ErrSnappyCorrupt
179 return written, r.err
180 }
181 err = r.block.encode(nil, false, false)
182 switch err {
183 case errIncompressible:
184 r.block.popOffsets()
185 r.block.reset(nil)
186 r.block.literals, err = snappy.Decode(r.block.literals[:n], r.buf[snappyChecksumSize:chunkLen])
187 if err != nil {
188 return written, err
189 }
190 err = r.block.encodeLits(r.block.literals, false)
191 if err != nil {
192 return written, err
193 }
194 case nil:
195 default:
196 return written, err
197 }
198
199 n, r.err = w.Write(r.block.output)
200 if r.err != nil {
201 return written, err
202 }
203 written += int64(n)
204 continue
205 case chunkTypeUncompressedData:
206 if debug {
207 println("Uncompressed, chunklen", chunkLen)
208 }
209 // Section 4.3. Uncompressed data (chunk type 0x01).
210 if chunkLen < snappyChecksumSize {
211 println("chunkLen < snappyChecksumSize", chunkLen, snappyChecksumSize)
212 r.err = ErrSnappyCorrupt
213 return written, r.err
214 }
215 r.block.reset(nil)
216 buf := r.buf[:snappyChecksumSize]
217 if !r.readFull(buf, false) {
218 return written, r.err
219 }
220 checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
221 // Read directly into r.decoded instead of via r.buf.
222 n := chunkLen - snappyChecksumSize
223 if n > snappyMaxBlockSize {
224 println("n > snappyMaxBlockSize", n, snappyMaxBlockSize)
225 r.err = ErrSnappyCorrupt
226 return written, r.err
227 }
228 r.block.literals = r.block.literals[:n]
229 if !r.readFull(r.block.literals, false) {
230 return written, r.err
231 }
232 if snappyCRC(r.block.literals) != checksum {
233 println("literals crc mismatch")
234 r.err = ErrSnappyCorrupt
235 return written, r.err
236 }
237 err := r.block.encodeLits(r.block.literals, false)
238 if err != nil {
239 return written, err
240 }
241 n, r.err = w.Write(r.block.output)
242 if r.err != nil {
243 return written, err
244 }
245 written += int64(n)
246 continue
247
248 case chunkTypeStreamIdentifier:
249 if debug {
250 println("stream id", chunkLen, len(snappyMagicBody))
251 }
252 // Section 4.1. Stream identifier (chunk type 0xff).
253 if chunkLen != len(snappyMagicBody) {
254 println("chunkLen != len(snappyMagicBody)", chunkLen, len(snappyMagicBody))
255 r.err = ErrSnappyCorrupt
256 return written, r.err
257 }
258 if !r.readFull(r.buf[:len(snappyMagicBody)], false) {
259 return written, r.err
260 }
261 for i := 0; i < len(snappyMagicBody); i++ {
262 if r.buf[i] != snappyMagicBody[i] {
263 println("r.buf[i] != snappyMagicBody[i]", r.buf[i], snappyMagicBody[i], i)
264 r.err = ErrSnappyCorrupt
265 return written, r.err
266 }
267 }
268 continue
269 }
270
271 if chunkType <= 0x7f {
272 // Section 4.5. Reserved unskippable chunks (chunk types 0x02-0x7f).
273 println("chunkType <= 0x7f")
274 r.err = ErrSnappyUnsupported
275 return written, r.err
276 }
277 // Section 4.4 Padding (chunk type 0xfe).
278 // Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
279 if !r.readFull(r.buf[:chunkLen], false) {
280 return written, r.err
281 }
282 }
283}
284
285// decodeSnappy writes the decoding of src to dst. It assumes that the varint-encoded
286// length of the decompressed bytes has already been read.
287func decodeSnappy(blk *blockEnc, src []byte) error {
288 //decodeRef(make([]byte, snappyMaxBlockSize), src)
289 var s, length int
290 lits := blk.extraLits
291 var offset uint32
292 for s < len(src) {
293 switch src[s] & 0x03 {
294 case snappyTagLiteral:
295 x := uint32(src[s] >> 2)
296 switch {
297 case x < 60:
298 s++
299 case x == 60:
300 s += 2
301 if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
302 println("uint(s) > uint(len(src)", s, src)
303 return ErrSnappyCorrupt
304 }
305 x = uint32(src[s-1])
306 case x == 61:
307 s += 3
308 if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
309 println("uint(s) > uint(len(src)", s, src)
310 return ErrSnappyCorrupt
311 }
312 x = uint32(src[s-2]) | uint32(src[s-1])<<8
313 case x == 62:
314 s += 4
315 if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
316 println("uint(s) > uint(len(src)", s, src)
317 return ErrSnappyCorrupt
318 }
319 x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16
320 case x == 63:
321 s += 5
322 if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
323 println("uint(s) > uint(len(src)", s, src)
324 return ErrSnappyCorrupt
325 }
326 x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
327 }
328 if x > snappyMaxBlockSize {
329 println("x > snappyMaxBlockSize", x, snappyMaxBlockSize)
330 return ErrSnappyCorrupt
331 }
332 length = int(x) + 1
333 if length <= 0 {
334 println("length <= 0 ", length)
335
336 return errUnsupportedLiteralLength
337 }
338 //if length > snappyMaxBlockSize-d || uint32(length) > len(src)-s {
339 // return ErrSnappyCorrupt
340 //}
341
342 blk.literals = append(blk.literals, src[s:s+length]...)
343 //println(length, "litLen")
344 lits += length
345 s += length
346 continue
347
348 case snappyTagCopy1:
349 s += 2
350 if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
351 println("uint(s) > uint(len(src)", s, len(src))
352 return ErrSnappyCorrupt
353 }
354 length = 4 + int(src[s-2])>>2&0x7
355 offset = uint32(src[s-2])&0xe0<<3 | uint32(src[s-1])
356
357 case snappyTagCopy2:
358 s += 3
359 if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
360 println("uint(s) > uint(len(src)", s, len(src))
361 return ErrSnappyCorrupt
362 }
363 length = 1 + int(src[s-3])>>2
364 offset = uint32(src[s-2]) | uint32(src[s-1])<<8
365
366 case snappyTagCopy4:
367 s += 5
368 if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
369 println("uint(s) > uint(len(src)", s, len(src))
370 return ErrSnappyCorrupt
371 }
372 length = 1 + int(src[s-5])>>2
373 offset = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
374 }
375
376 if offset <= 0 || blk.size+lits < int(offset) /*|| length > len(blk)-d */ {
377 println("offset <= 0 || blk.size+lits < int(offset)", offset, blk.size+lits, int(offset), blk.size, lits)
378
379 return ErrSnappyCorrupt
380 }
381
382 // Check if offset is one of the recent offsets.
383 // Adjusts the output offset accordingly.
384 // Gives a tiny bit of compression, typically around 1%.
385 if false {
386 offset = blk.matchOffset(offset, uint32(lits))
387 } else {
388 offset += 3
389 }
390
391 blk.sequences = append(blk.sequences, seq{
392 litLen: uint32(lits),
393 offset: offset,
394 matchLen: uint32(length) - zstdMinMatch,
395 })
396 blk.size += length + lits
397 lits = 0
398 }
399 blk.extraLits = lits
400 return nil
401}
402
403func (r *SnappyConverter) readFull(p []byte, allowEOF bool) (ok bool) {
404 if _, r.err = io.ReadFull(r.r, p); r.err != nil {
405 if r.err == io.ErrUnexpectedEOF || (r.err == io.EOF && !allowEOF) {
406 r.err = ErrSnappyCorrupt
407 }
408 return false
409 }
410 return true
411}
412
// crcTable is the CRC-32 lookup table for the Castagnoli polynomial.
var crcTable = crc32.MakeTable(crc32.Castagnoli)

// snappyCRC implements the checksum specified in section 3 of
// https://github.com/google/snappy/blob/master/framing_format.txt
//
// It computes the CRC-32 (Castagnoli) of b, rotates it right by 15 bits and
// adds the constant 0xa282ead8, wrapping on overflow.
func snappyCRC(b []byte) uint32 {
	sum := crc32.Checksum(b, crcTable)
	rotated := sum>>15 | sum<<17
	return rotated + 0xa282ead8
}
421
422// snappyDecodedLen returns the length of the decoded block and the number of bytes
423// that the length header occupied.
424func snappyDecodedLen(src []byte) (blockLen, headerLen int, err error) {
425 v, n := binary.Uvarint(src)
426 if n <= 0 || v > 0xffffffff {
427 return 0, 0, ErrSnappyCorrupt
428 }
429
430 const wordSize = 32 << (^uint(0) >> 32 & 1)
431 if wordSize == 32 && v > 0x7fffffff {
432 return 0, 0, ErrSnappyTooLarge
433 }
434 return int(v), n, nil
435}