William Kurkian | ea86948 | 2019-04-09 15:16:11 -0400 | [diff] [blame] | 1 | // Copyright 2014 The Go Authors. All rights reserved. |
| 2 | // Use of this source code is governed by a BSD-style |
| 3 | // license that can be found in the LICENSE file. |
| 4 | |
| 5 | package hpack |
| 6 | |
| 7 | import ( |
| 8 | "bytes" |
| 9 | "errors" |
| 10 | "io" |
| 11 | "sync" |
| 12 | ) |
| 13 | |
// bufPool recycles the scratch buffers that the Huffman decoders fill, so a
// decode does not have to allocate a fresh bytes.Buffer every time.
var bufPool = sync.Pool{
	New: func() interface{} {
		return &bytes.Buffer{}
	},
}
| 17 | |
| 18 | // HuffmanDecode decodes the string in v and writes the expanded |
| 19 | // result to w, returning the number of bytes written to w and the |
| 20 | // Write call's return value. At most one Write call is made. |
| 21 | func HuffmanDecode(w io.Writer, v []byte) (int, error) { |
| 22 | buf := bufPool.Get().(*bytes.Buffer) |
| 23 | buf.Reset() |
| 24 | defer bufPool.Put(buf) |
| 25 | if err := huffmanDecode(buf, 0, v); err != nil { |
| 26 | return 0, err |
| 27 | } |
| 28 | return w.Write(buf.Bytes()) |
| 29 | } |
| 30 | |
| 31 | // HuffmanDecodeToString decodes the string in v. |
| 32 | func HuffmanDecodeToString(v []byte) (string, error) { |
| 33 | buf := bufPool.Get().(*bytes.Buffer) |
| 34 | buf.Reset() |
| 35 | defer bufPool.Put(buf) |
| 36 | if err := huffmanDecode(buf, 0, v); err != nil { |
| 37 | return "", err |
| 38 | } |
| 39 | return buf.String(), nil |
| 40 | } |
| 41 | |
// ErrInvalidHuffman is the error reported when Huffman-encoded input is
// malformed, for example an unknown code, a truncated symbol, or bad padding.
var ErrInvalidHuffman error = errors.New("hpack: invalid Huffman-encoded data")
| 45 | |
// huffmanDecode decodes the Huffman-encoded bytes in v and appends the
// decoded symbols to buf.
//
// If maxLen is non-zero, writing a symbol when buf.Len() already equals
// maxLen makes the function return ErrStringLength (declared elsewhere in the
// package). The check compares against buf.Len(), so bytes already in buf
// before the call count toward the limit. A negative maxLen never equals
// buf.Len(), so in practice it imposes no limit.
//
// It returns ErrInvalidHuffman when v contains a bit sequence with no
// matching code, ends in the middle of a symbol, or has padding that is
// longer than 7 bits or not all ones (RFC 7541 section 5.2).
func huffmanDecode(buf *bytes.Buffer, maxLen int, v []byte) error {
	rootHuffmanNode := getRootHuffmanNode()
	// n is the current position in the 256-way decoding trie.
	n := rootHuffmanNode
	// cur is the bit buffer that has not been fed into n.
	// cbits is the number of low order bits in cur that are valid.
	// sbits is the number of bits of the symbol prefix being decoded.
	cur, cbits, sbits := uint(0), uint8(0), uint8(0)
	for _, b := range v {
		// Shift the next input byte into the bit buffer. Older high bits
		// may be shifted out; only the low cbits bits are ever read.
		cur = cur<<8 | uint(b)
		cbits += 8
		sbits += 8
		for cbits >= 8 {
			// Use the top 8 unconsumed bits as the index into n's table.
			idx := byte(cur >> (cbits - 8))
			n = n.children[idx]
			if n == nil {
				// No code starts with this bit pattern.
				return ErrInvalidHuffman
			}
			if n.children == nil {
				// Leaf: one symbol is complete.
				if maxLen != 0 && buf.Len() == maxLen {
					return ErrStringLength
				}
				// bytes.Buffer.WriteByte always returns a nil error.
				buf.WriteByte(n.sym)
				// Only the leaf's codeLen bits belonged to this symbol. The
				// rest of the 8-bit index starts the next symbol.
				cbits -= n.codeLen
				n = rootHuffmanNode
				sbits = cbits
			} else {
				// Internal node: all 8 bits were consumed by this level.
				cbits -= 8
			}
		}
	}
	// Fewer than 8 bits remain. Decode any complete symbols hidden in them.
	for cbits > 0 {
		// Left-align the remaining cbits bits into an 8-bit index and
		// zero-fill the low bits. This works because addDecoderNode fills
		// every index that shares a leaf's code prefix.
		n = n.children[byte(cur<<(8-cbits))]
		if n == nil {
			return ErrInvalidHuffman
		}
		if n.children != nil || n.codeLen > cbits {
			// The remaining bits do not form a whole symbol. They must be
			// padding; the checks below decide.
			break
		}
		if maxLen != 0 && buf.Len() == maxLen {
			return ErrStringLength
		}
		buf.WriteByte(n.sym)
		cbits -= n.codeLen
		n = rootHuffmanNode
		sbits = cbits
	}
	if sbits > 7 {
		// Either there was an incomplete symbol, or overlong padding.
		// Both are decoding errors per RFC 7541 section 5.2.
		return ErrInvalidHuffman
	}
	// mask selects the cbits leftover bits, all of which must be 1.
	if mask := uint(1<<cbits - 1); cur&mask != mask {
		// Trailing bits must be a prefix of EOS per RFC 7541 section 5.2.
		return ErrInvalidHuffman
	}

	return nil
}
| 107 | |
// incomparable is a zero-width, non-comparable type. Adding it to a struct
// makes that struct also non-comparable, and generally doesn't add
// any size (as long as it's first).
//
// Function types are not comparable, and an array type is comparable only
// if its element type is. So [0]func() cannot be compared with ==, yet it
// occupies no storage. The node type in this file declares a blank
// `_ incomparable` field, which stops nodes from being compared with ==
// or used as map keys.
// NOTE(review): the "as long as it's first" caveat presumably refers to the
// gc compiler padding a trailing zero-size field. Keep the field first.
type incomparable [0]func()
| 112 | |
William Kurkian | ea86948 | 2019-04-09 15:16:11 -0400 | [diff] [blame] | 113 | type node struct { |
khenaidoo | 106c61a | 2021-08-11 18:05:46 -0400 | [diff] [blame^] | 114 | _ incomparable |
| 115 | |
William Kurkian | ea86948 | 2019-04-09 15:16:11 -0400 | [diff] [blame] | 116 | // children is non-nil for internal nodes |
| 117 | children *[256]*node |
| 118 | |
| 119 | // The following are only valid if children is nil: |
| 120 | codeLen uint8 // number of bits that led to the output of sym |
| 121 | sym byte // output symbol |
| 122 | } |
| 123 | |
| 124 | func newInternalNode() *node { |
| 125 | return &node{children: new([256]*node)} |
| 126 | } |
| 127 | |
| 128 | var ( |
| 129 | buildRootOnce sync.Once |
| 130 | lazyRootHuffmanNode *node |
| 131 | ) |
| 132 | |
| 133 | func getRootHuffmanNode() *node { |
| 134 | buildRootOnce.Do(buildRootHuffmanNode) |
| 135 | return lazyRootHuffmanNode |
| 136 | } |
| 137 | |
| 138 | func buildRootHuffmanNode() { |
| 139 | if len(huffmanCodes) != 256 { |
| 140 | panic("unexpected size") |
| 141 | } |
| 142 | lazyRootHuffmanNode = newInternalNode() |
| 143 | for i, code := range huffmanCodes { |
| 144 | addDecoderNode(byte(i), code, huffmanCodeLen[i]) |
| 145 | } |
| 146 | } |
| 147 | |
| 148 | func addDecoderNode(sym byte, code uint32, codeLen uint8) { |
| 149 | cur := lazyRootHuffmanNode |
| 150 | for codeLen > 8 { |
| 151 | codeLen -= 8 |
| 152 | i := uint8(code >> codeLen) |
| 153 | if cur.children[i] == nil { |
| 154 | cur.children[i] = newInternalNode() |
| 155 | } |
| 156 | cur = cur.children[i] |
| 157 | } |
| 158 | shift := 8 - codeLen |
| 159 | start, end := int(uint8(code<<shift)), int(1<<shift) |
| 160 | for i := start; i < start+end; i++ { |
| 161 | cur.children[i] = &node{sym: sym, codeLen: codeLen} |
| 162 | } |
| 163 | } |
| 164 | |
| 165 | // AppendHuffmanString appends s, as encoded in Huffman codes, to dst |
| 166 | // and returns the extended buffer. |
| 167 | func AppendHuffmanString(dst []byte, s string) []byte { |
| 168 | rembits := uint8(8) |
| 169 | |
| 170 | for i := 0; i < len(s); i++ { |
| 171 | if rembits == 8 { |
| 172 | dst = append(dst, 0) |
| 173 | } |
| 174 | dst, rembits = appendByteToHuffmanCode(dst, rembits, s[i]) |
| 175 | } |
| 176 | |
| 177 | if rembits < 8 { |
| 178 | // special EOS symbol |
| 179 | code := uint32(0x3fffffff) |
| 180 | nbits := uint8(30) |
| 181 | |
| 182 | t := uint8(code >> (nbits - rembits)) |
| 183 | dst[len(dst)-1] |= t |
| 184 | } |
| 185 | |
| 186 | return dst |
| 187 | } |
| 188 | |
| 189 | // HuffmanEncodeLength returns the number of bytes required to encode |
| 190 | // s in Huffman codes. The result is round up to byte boundary. |
| 191 | func HuffmanEncodeLength(s string) uint64 { |
| 192 | n := uint64(0) |
| 193 | for i := 0; i < len(s); i++ { |
| 194 | n += uint64(huffmanCodeLen[s[i]]) |
| 195 | } |
| 196 | return (n + 7) / 8 |
| 197 | } |
| 198 | |
| 199 | // appendByteToHuffmanCode appends Huffman code for c to dst and |
| 200 | // returns the extended buffer and the remaining bits in the last |
| 201 | // element. The appending is not byte aligned and the remaining bits |
| 202 | // in the last element of dst is given in rembits. |
| 203 | func appendByteToHuffmanCode(dst []byte, rembits uint8, c byte) ([]byte, uint8) { |
| 204 | code := huffmanCodes[c] |
| 205 | nbits := huffmanCodeLen[c] |
| 206 | |
| 207 | for { |
| 208 | if rembits > nbits { |
| 209 | t := uint8(code << (rembits - nbits)) |
| 210 | dst[len(dst)-1] |= t |
| 211 | rembits -= nbits |
| 212 | break |
| 213 | } |
| 214 | |
| 215 | t := uint8(code >> (nbits - rembits)) |
| 216 | dst[len(dst)-1] |= t |
| 217 | |
| 218 | nbits -= rembits |
| 219 | rembits = 8 |
| 220 | |
| 221 | if nbits == 0 { |
| 222 | break |
| 223 | } |
| 224 | |
| 225 | dst = append(dst, 0) |
| 226 | } |
| 227 | |
| 228 | return dst, rembits |
| 229 | } |