blob: 658ef78380e1dcd6a9aaa39924fcf7ea616d227d [file] [log] [blame]
// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8 "fmt"
9 "io"
10 "io/ioutil"
11)
12
// byteBuffer is the input abstraction shared by the in-memory buffer and the
// io.Reader wrapper. All methods consume input from the front.
type byteBuffer interface {
	// readByte consumes and returns a single byte.
	readByte() (byte, error)

	// readSmall consumes n bytes, where n is at most 8.
	// It returns nil when n bytes are not available.
	// The returned slice may be overwritten by a later read.
	readSmall(n int) []byte

	// readBig consumes n bytes; intended for reads larger than 8 bytes.
	// The implementation MAY use dst as storage for the result.
	readBig(n int, dst []byte) ([]byte, error)

	// skipN discards n bytes.
	skipN(n int) error
}
28
// byteBuf is a byteBuffer backed by a byte slice held entirely in memory.
// Its methods consume input by re-slicing, so they need a pointer receiver.
type byteBuf []byte
31
32func (b *byteBuf) readSmall(n int) []byte {
33 if debugAsserts && n > 8 {
34 panic(fmt.Errorf("small read > 8 (%d). use readBig", n))
35 }
36 bb := *b
37 if len(bb) < n {
38 return nil
39 }
40 r := bb[:n]
41 *b = bb[n:]
42 return r
43}
44
45func (b *byteBuf) readBig(n int, dst []byte) ([]byte, error) {
46 bb := *b
47 if len(bb) < n {
48 return nil, io.ErrUnexpectedEOF
49 }
50 r := bb[:n]
51 *b = bb[n:]
52 return r, nil
53}
54
55func (b *byteBuf) remain() []byte {
56 return *b
57}
58
59func (b *byteBuf) readByte() (byte, error) {
60 bb := *b
61 if len(bb) < 1 {
62 return 0, nil
63 }
64 r := bb[0]
65 *b = bb[1:]
66 return r, nil
67}
68
69func (b *byteBuf) skipN(n int) error {
70 bb := *b
71 if len(bb) < n {
72 return io.ErrUnexpectedEOF
73 }
74 *b = bb[n:]
75 return nil
76}
77
// readerWrapper is a byteBuffer that pulls its input from an io.Reader.
type readerWrapper struct {
	// r is the underlying input stream.
	r io.Reader
	// tmp is scratch space for the small reads (readSmall, readByte);
	// slices returned from it are overwritten by the next such read.
	tmp [8]byte
}
83
84func (r *readerWrapper) readSmall(n int) []byte {
85 if debugAsserts && n > 8 {
86 panic(fmt.Errorf("small read > 8 (%d). use readBig", n))
87 }
88 n2, err := io.ReadFull(r.r, r.tmp[:n])
89 // We only really care about the actual bytes read.
90 if n2 != n {
91 if debug {
92 println("readSmall: got", n2, "want", n, "err", err)
93 }
94 return nil
95 }
96 return r.tmp[:n]
97}
98
99func (r *readerWrapper) readBig(n int, dst []byte) ([]byte, error) {
100 if cap(dst) < n {
101 dst = make([]byte, n)
102 }
103 n2, err := io.ReadFull(r.r, dst[:n])
104 if err == io.EOF && n > 0 {
105 err = io.ErrUnexpectedEOF
106 }
107 return dst[:n2], err
108}
109
110func (r *readerWrapper) readByte() (byte, error) {
111 n2, err := r.r.Read(r.tmp[:1])
112 if err != nil {
113 return 0, err
114 }
115 if n2 != 1 {
116 return 0, io.ErrUnexpectedEOF
117 }
118 return r.tmp[0], nil
119}
120
121func (r *readerWrapper) skipN(n int) error {
122 n2, err := io.CopyN(ioutil.Discard, r.r, int64(n))
123 if n2 != int64(n) {
124 err = io.ErrUnexpectedEOF
125 }
126 return err
127}