blob: 97299d499cf0a206ab5340ece617c3e447cb077d [file] [log] [blame]
// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8 "encoding/binary"
9 "errors"
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +053010 "fmt"
khenaidoo7d3c5582021-08-11 18:09:44 -040011 "io"
12 "math/bits"
13)
14
// bitReader consumes a bit stream backwards, from the end of its input
// towards the start. The highest set bit of the final input byte is a marker:
// it and the zero bits above it are padding, and the first real bit of the
// stream sits directly below it.
type bitReader struct {
	in       []byte // input buffer, consumed from its last byte towards its first
	off      uint   // count of unconsumed bytes; the next one to load is in[off-1]
	value    uint64 // bit register; a [16]byte could hold more but is awkward to shift
	bitsRead uint8  // bits of value already consumed; 64 means nothing is left in it
}
24
25// init initializes and resets the bit reader.
26func (b *bitReader) init(in []byte) error {
27 if len(in) < 1 {
28 return errors.New("corrupt stream: too short")
29 }
30 b.in = in
31 b.off = uint(len(in))
32 // The highest bit of the last byte indicates where to start
33 v := in[len(in)-1]
34 if v == 0 {
35 return errors.New("corrupt stream, did not find end of stream")
36 }
37 b.bitsRead = 64
38 b.value = 0
39 if len(in) >= 8 {
40 b.fillFastStart()
41 } else {
42 b.fill()
43 b.fill()
44 }
45 b.bitsRead += 8 - uint8(highBits(uint32(v)))
46 return nil
47}
48
49// getBits will return n bits. n can be 0.
50func (b *bitReader) getBits(n uint8) int {
51 if n == 0 /*|| b.bitsRead >= 64 */ {
52 return 0
53 }
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +053054 return int(b.get32BitsFast(n))
khenaidoo7d3c5582021-08-11 18:09:44 -040055}
56
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +053057// get32BitsFast requires that at least one bit is requested every time.
khenaidoo7d3c5582021-08-11 18:09:44 -040058// There are no checks if the buffer is filled.
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +053059func (b *bitReader) get32BitsFast(n uint8) uint32 {
khenaidoo7d3c5582021-08-11 18:09:44 -040060 const regMask = 64 - 1
61 v := uint32((b.value << (b.bitsRead & regMask)) >> ((regMask + 1 - n) & regMask))
62 b.bitsRead += n
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +053063 return v
khenaidoo7d3c5582021-08-11 18:09:44 -040064}
65
66// fillFast() will make sure at least 32 bits are available.
67// There must be at least 4 bytes available.
68func (b *bitReader) fillFast() {
69 if b.bitsRead < 32 {
70 return
71 }
72 // 2 bounds checks.
73 v := b.in[b.off-4:]
74 v = v[:4]
75 low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
76 b.value = (b.value << 32) | uint64(low)
77 b.bitsRead -= 32
78 b.off -= 4
79}
80
81// fillFastStart() assumes the bitreader is empty and there is at least 8 bytes to read.
82func (b *bitReader) fillFastStart() {
83 // Do single re-slice to avoid bounds checks.
84 b.value = binary.LittleEndian.Uint64(b.in[b.off-8:])
85 b.bitsRead = 0
86 b.off -= 8
87}
88
89// fill() will make sure at least 32 bits are available.
90func (b *bitReader) fill() {
91 if b.bitsRead < 32 {
92 return
93 }
94 if b.off >= 4 {
95 v := b.in[b.off-4:]
96 v = v[:4]
97 low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
98 b.value = (b.value << 32) | uint64(low)
99 b.bitsRead -= 32
100 b.off -= 4
101 return
102 }
103 for b.off > 0 {
104 b.value = (b.value << 8) | uint64(b.in[b.off-1])
105 b.bitsRead -= 8
106 b.off--
107 }
108}
109
110// finished returns true if all bits have been read from the bit stream.
111func (b *bitReader) finished() bool {
112 return b.off == 0 && b.bitsRead >= 64
113}
114
115// overread returns true if more bits have been requested than is on the stream.
116func (b *bitReader) overread() bool {
117 return b.bitsRead > 64
118}
119
120// remain returns the number of bits remaining.
121func (b *bitReader) remain() uint {
122 return b.off*8 + 64 - uint(b.bitsRead)
123}
124
125// close the bitstream and returns an error if out-of-buffer reads occurred.
126func (b *bitReader) close() error {
127 // Release reference.
128 b.in = nil
Akash Reddy Kankanalac28f0e22025-06-16 11:00:55 +0530129 if !b.finished() {
130 return fmt.Errorf("%d extra bits on block, should be 0", b.remain())
131 }
khenaidoo7d3c5582021-08-11 18:09:44 -0400132 if b.bitsRead > 64 {
133 return io.ErrUnexpectedEOF
134 }
135 return nil
136}
137
// highBits returns the zero-based position of the most significant set bit of
// val (0 for 1, 7 for 0x80, 31 for 0x80000000). Passing 0 yields 0xFFFFFFFF,
// the uint32 form of -1, so callers should only pass non-zero values.
func highBits(val uint32) (n uint32) {
	n = uint32(bits.Len32(val)) - 1
	return n
}