blob: 9f3e9f79e2472c9c20fb86d381b64c75eb2784ad [file] [log] [blame]
Akash Reddy Kankanalac0014632025-05-21 17:12:20 +05301//go:build amd64 && !appengine && !noasm && gc
2// +build amd64,!appengine,!noasm,gc
3
// This file contains the specialisation of Decoder.Decompress4X
// and Decoder.Decompress1X that use an asm implementation of their main loops.
6package huff0
7
8import (
9 "errors"
10 "fmt"
11
12 "github.com/klauspost/compress/internal/cpuinfo"
13)
14
// decompress4x_main_loop_amd64 is an amd64 assembler implementation
// of the Decompress4X main loop, used when tablelog > 8.
// Decompress4X passes its state in ctx and reads the number of decoded
// bytes back from ctx.decoded once the call returns.
//
//go:noescape
func decompress4x_main_loop_amd64(ctx *decompress4xContext)
19
// decompress4x_8b_main_loop_amd64 is an amd64 assembler implementation
// of the Decompress4X main loop, used when tablelog <= 8, which decodes
// 4 entries per loop.
// Decompress4X passes its state in ctx and reads the number of decoded
// bytes back from ctx.decoded once the call returns.
//
//go:noescape
func decompress4x_8b_main_loop_amd64(ctx *decompress4xContext)
25
// fallback8BitSize is the destination capacity, in bytes, below which
// Decompress4X hands decoding over to decompress4X8bit when the table log
// is at most 8, because the Go version is faster at that size.
const fallback8BitSize = 800
28
// decompress4xContext carries the arguments and the result exchanged between
// Decompress4X and the assembler 4X main loops.
// NOTE(review): the assembler presumably addresses these fields by offset —
// confirm against the .s source before reordering or resizing any of them.
type decompress4xContext struct {
	pbr      *[4]bitReaderShifted // the four per-stream bit readers
	peekBits uint8                // (64 - actualTableLog) & 63; see bitReaderShifted.peekBitsFast()
	out      *byte                // first byte of the destination buffer
	dstEvery int                  // (dstSize + 3) / 4, as computed by Decompress4X
	tbl      *dEntrySingle        // first entry of the single-symbol decoding table
	decoded  int                  // result: read by Decompress4X as the byte count decoded over all four streams
	limit    *byte                // &out[dstEvery-4]; decoding stops when the first buffer gets here
}
38
39// Decompress4X will decompress a 4X encoded stream.
40// The length of the supplied input must match the end of a block exactly.
41// The *capacity* of the dst slice must match the destination size of
42// the uncompressed data exactly.
43func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
44 if len(d.dt.single) == 0 {
45 return nil, errors.New("no table loaded")
46 }
47 if len(src) < 6+(4*1) {
48 return nil, errors.New("input too small")
49 }
50
51 use8BitTables := d.actualTableLog <= 8
52 if cap(dst) < fallback8BitSize && use8BitTables {
53 return d.decompress4X8bit(dst, src)
54 }
55
56 var br [4]bitReaderShifted
57 // Decode "jump table"
58 start := 6
59 for i := 0; i < 3; i++ {
60 length := int(src[i*2]) | (int(src[i*2+1]) << 8)
61 if start+length >= len(src) {
62 return nil, errors.New("truncated input (or invalid offset)")
63 }
64 err := br[i].init(src[start : start+length])
65 if err != nil {
66 return nil, err
67 }
68 start += length
69 }
70 err := br[3].init(src[start:])
71 if err != nil {
72 return nil, err
73 }
74
75 // destination, offset to match first output
76 dstSize := cap(dst)
77 dst = dst[:dstSize]
78 out := dst
79 dstEvery := (dstSize + 3) / 4
80
81 const tlSize = 1 << tableLogMax
82 const tlMask = tlSize - 1
83 single := d.dt.single[:tlSize]
84
85 var decoded int
86
87 if len(out) > 4*4 && !(br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4) {
88 ctx := decompress4xContext{
89 pbr: &br,
90 peekBits: uint8((64 - d.actualTableLog) & 63), // see: bitReaderShifted.peekBitsFast()
91 out: &out[0],
92 dstEvery: dstEvery,
93 tbl: &single[0],
94 limit: &out[dstEvery-4], // Always stop decoding when first buffer gets here to avoid writing OOB on last.
95 }
96 if use8BitTables {
97 decompress4x_8b_main_loop_amd64(&ctx)
98 } else {
99 decompress4x_main_loop_amd64(&ctx)
100 }
101
102 decoded = ctx.decoded
103 out = out[decoded/4:]
104 }
105
106 // Decode remaining.
107 remainBytes := dstEvery - (decoded / 4)
108 for i := range br {
109 offset := dstEvery * i
110 endsAt := offset + remainBytes
111 if endsAt > len(out) {
112 endsAt = len(out)
113 }
114 br := &br[i]
115 bitsLeft := br.remaining()
116 for bitsLeft > 0 {
117 br.fill()
118 if offset >= endsAt {
119 return nil, errors.New("corruption detected: stream overrun 4")
120 }
121
122 // Read value and increment offset.
123 val := br.peekBitsFast(d.actualTableLog)
124 v := single[val&tlMask].entry
125 nBits := uint8(v)
126 br.advance(nBits)
127 bitsLeft -= uint(nBits)
128 out[offset] = uint8(v >> 8)
129 offset++
130 }
131 if offset != endsAt {
132 return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt)
133 }
134 decoded += offset - dstEvery*i
135 err = br.close()
136 if err != nil {
137 return nil, err
138 }
139 }
140 if dstSize != decoded {
141 return nil, errors.New("corruption detected: short output block")
142 }
143 return dst, nil
144}
145
// decompress1x_main_loop_amd64 is an amd64 assembler implementation
// of the Decompress1X main loop.
// Decompress1X calls it when cpuinfo.HasBMI2() reports false and reads the
// outcome from ctx.decoded, which holds either the number of decoded bytes
// or error_max_decoded_size_exeeded.
//
//go:noescape
func decompress1x_main_loop_amd64(ctx *decompress1xContext)
150
// decompress1x_main_loop_bmi2 is an amd64 with BMI2 assembler implementation
// of the Decompress1X main loop.
// Decompress1X calls it when cpuinfo.HasBMI2() reports true and reads the
// outcome from ctx.decoded, which holds either the number of decoded bytes
// or error_max_decoded_size_exeeded.
//
//go:noescape
func decompress1x_main_loop_bmi2(ctx *decompress1xContext)
155
// decompress1xContext carries the arguments and the result exchanged between
// Decompress1X and the assembler 1X main loops.
// NOTE(review): the assembler presumably addresses these fields by offset —
// confirm against the .s source before reordering or resizing any of them.
type decompress1xContext struct {
	pbr      *bitReaderShifted // bit reader over the compressed input
	peekBits uint8             // (64 - actualTableLog) & 63; see bitReaderShifted.peekBitsFast()
	out      *byte             // first byte of the destination buffer
	outCap   int               // capacity of the destination buffer, i.e. the maximum decoded size
	tbl      *dEntrySingle     // first entry of the single-symbol decoding table
	decoded  int               // result: number of bytes decoded, or error_max_decoded_size_exeeded
}
164
// error_max_decoded_size_exeeded is the error value reported by the asm
// implementations through decompress1xContext.decoded; Decompress1X turns it
// into ErrMaxDecodedSizeExceeded.
// NOTE(review): the misspelled name is presumably referenced from the
// assembler source as well — confirm before renaming.
const error_max_decoded_size_exeeded = -1
167
168// Decompress1X will decompress a 1X encoded stream.
169// The cap of the output buffer will be the maximum decompressed size.
170// The length of the supplied input must match the end of a block exactly.
171func (d *Decoder) Decompress1X(dst, src []byte) ([]byte, error) {
172 if len(d.dt.single) == 0 {
173 return nil, errors.New("no table loaded")
174 }
175 var br bitReaderShifted
176 err := br.init(src)
177 if err != nil {
178 return dst, err
179 }
180 maxDecodedSize := cap(dst)
181 dst = dst[:maxDecodedSize]
182
183 const tlSize = 1 << tableLogMax
184 const tlMask = tlSize - 1
185
186 if maxDecodedSize >= 4 {
187 ctx := decompress1xContext{
188 pbr: &br,
189 out: &dst[0],
190 outCap: maxDecodedSize,
191 peekBits: uint8((64 - d.actualTableLog) & 63), // see: bitReaderShifted.peekBitsFast()
192 tbl: &d.dt.single[0],
193 }
194
195 if cpuinfo.HasBMI2() {
196 decompress1x_main_loop_bmi2(&ctx)
197 } else {
198 decompress1x_main_loop_amd64(&ctx)
199 }
200 if ctx.decoded == error_max_decoded_size_exeeded {
201 return nil, ErrMaxDecodedSizeExceeded
202 }
203
204 dst = dst[:ctx.decoded]
205 }
206
207 // br < 8, so uint8 is fine
208 bitsLeft := uint8(br.off)*8 + 64 - br.bitsRead
209 for bitsLeft > 0 {
210 br.fill()
211 if len(dst) >= maxDecodedSize {
212 br.close()
213 return nil, ErrMaxDecodedSizeExceeded
214 }
215 v := d.dt.single[br.peekBitsFast(d.actualTableLog)&tlMask]
216 nBits := uint8(v.entry)
217 br.advance(nBits)
218 bitsLeft -= nBits
219 dst = append(dst, uint8(v.entry>>8))
220 }
221 return dst, br.close()
222}