blob: aab71c6cf851b922691d5e13f23fcc34b1d06ceb [file] [log] [blame]
kesavandc71914f2022-03-25 11:19:03 +05301// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8 "fmt"
9 "io"
10 "io/ioutil"
11)
12
// byteBuffer is the byte source the decoder reads from. In this file it is
// implemented by byteBuf (an in-memory slice) and readerWrapper (an
// io.Reader adapter).
type byteBuffer interface {
	// readByte reads a single byte.
	readByte() (byte, error)

	// readSmall reads n bytes, where n must not exceed 8.
	// It returns io.ErrUnexpectedEOF if the read cannot be satisfied.
	// The returned slice may reference storage that a later call reuses,
	// so callers that keep it must copy it.
	readSmall(n int) ([]byte, error)

	// readBig reads n bytes and is meant for reads larger than 8 bytes.
	// Implementations MAY place the result in dst.
	readBig(n int, dst []byte) ([]byte, error)

	// skipN discards the next n bytes.
	skipN(n int) error
}
28
// byteBuf is an in-memory byte source. Each read consumes bytes from the
// front by re-slicing the buffer, so results alias the original memory.
type byteBuf []byte
31
32func (b *byteBuf) readSmall(n int) ([]byte, error) {
33 if debugAsserts && n > 8 {
34 panic(fmt.Errorf("small read > 8 (%d). use readBig", n))
35 }
36 bb := *b
37 if len(bb) < n {
38 return nil, io.ErrUnexpectedEOF
39 }
40 r := bb[:n]
41 *b = bb[n:]
42 return r, nil
43}
44
45func (b *byteBuf) readBig(n int, dst []byte) ([]byte, error) {
46 bb := *b
47 if len(bb) < n {
48 return nil, io.ErrUnexpectedEOF
49 }
50 r := bb[:n]
51 *b = bb[n:]
52 return r, nil
53}
54
55func (b *byteBuf) remain() []byte {
56 return *b
57}
58
59func (b *byteBuf) readByte() (byte, error) {
60 bb := *b
61 if len(bb) < 1 {
62 return 0, nil
63 }
64 r := bb[0]
65 *b = bb[1:]
66 return r, nil
67}
68
69func (b *byteBuf) skipN(n int) error {
70 bb := *b
71 if len(bb) < n {
72 return io.ErrUnexpectedEOF
73 }
74 *b = bb[n:]
75 return nil
76}
77
// readerWrapper adapts an io.Reader to the byteBuffer methods.
type readerWrapper struct {
	r   io.Reader // underlying source for every read
	tmp [8]byte   // scratch space reused by readSmall and readByte
}
83
84func (r *readerWrapper) readSmall(n int) ([]byte, error) {
85 if debugAsserts && n > 8 {
86 panic(fmt.Errorf("small read > 8 (%d). use readBig", n))
87 }
88 n2, err := io.ReadFull(r.r, r.tmp[:n])
89 // We only really care about the actual bytes read.
90 if err != nil {
91 if err == io.EOF {
92 return nil, io.ErrUnexpectedEOF
93 }
94 if debugDecoder {
95 println("readSmall: got", n2, "want", n, "err", err)
96 }
97 return nil, err
98 }
99 return r.tmp[:n], nil
100}
101
102func (r *readerWrapper) readBig(n int, dst []byte) ([]byte, error) {
103 if cap(dst) < n {
104 dst = make([]byte, n)
105 }
106 n2, err := io.ReadFull(r.r, dst[:n])
107 if err == io.EOF && n > 0 {
108 err = io.ErrUnexpectedEOF
109 }
110 return dst[:n2], err
111}
112
113func (r *readerWrapper) readByte() (byte, error) {
114 n2, err := r.r.Read(r.tmp[:1])
115 if err != nil {
116 return 0, err
117 }
118 if n2 != 1 {
119 return 0, io.ErrUnexpectedEOF
120 }
121 return r.tmp[0], nil
122}
123
124func (r *readerWrapper) skipN(n int) error {
125 n2, err := io.CopyN(ioutil.Discard, r.r, int64(n))
126 if n2 != int64(n) {
127 err = io.ErrUnexpectedEOF
128 }
129 return err
130}