// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
| 4 | |
| 5 | package hpack |
| 6 | |
| 7 | import ( |
| 8 | "bytes" |
| 9 | "errors" |
| 10 | "io" |
| 11 | "sync" |
| 12 | ) |
| 13 | |
// bufPool recycles *bytes.Buffer scratch space used by HuffmanDecode
// and HuffmanDecodeToString. Buffers obtained from the pool are not
// guaranteed to be empty; callers reset them before use.
var bufPool = sync.Pool{
	New: func() interface{} {
		return &bytes.Buffer{}
	},
}
| 17 | |
| 18 | // HuffmanDecode decodes the string in v and writes the expanded |
| 19 | // result to w, returning the number of bytes written to w and the |
| 20 | // Write call's return value. At most one Write call is made. |
| 21 | func HuffmanDecode(w io.Writer, v []byte) (int, error) { |
| 22 | buf := bufPool.Get().(*bytes.Buffer) |
| 23 | buf.Reset() |
| 24 | defer bufPool.Put(buf) |
| 25 | if err := huffmanDecode(buf, 0, v); err != nil { |
| 26 | return 0, err |
| 27 | } |
| 28 | return w.Write(buf.Bytes()) |
| 29 | } |
| 30 | |
| 31 | // HuffmanDecodeToString decodes the string in v. |
| 32 | func HuffmanDecodeToString(v []byte) (string, error) { |
| 33 | buf := bufPool.Get().(*bytes.Buffer) |
| 34 | buf.Reset() |
| 35 | defer bufPool.Put(buf) |
| 36 | if err := huffmanDecode(buf, 0, v); err != nil { |
| 37 | return "", err |
| 38 | } |
| 39 | return buf.String(), nil |
| 40 | } |
| 41 | |
// ErrInvalidHuffman is the error reported when a Huffman-encoded
// string turns out to be malformed during decoding.
var ErrInvalidHuffman = errors.New("hpack: invalid Huffman-encoded data")
| 45 | |
| 46 | // huffmanDecode decodes v to buf. |
| 47 | // If maxLen is greater than 0, attempts to write more to buf than |
| 48 | // maxLen bytes will return ErrStringLength. |
| 49 | func huffmanDecode(buf *bytes.Buffer, maxLen int, v []byte) error { |
| 50 | n := rootHuffmanNode |
| 51 | // cur is the bit buffer that has not been fed into n. |
| 52 | // cbits is the number of low order bits in cur that are valid. |
| 53 | // sbits is the number of bits of the symbol prefix being decoded. |
| 54 | cur, cbits, sbits := uint(0), uint8(0), uint8(0) |
| 55 | for _, b := range v { |
| 56 | cur = cur<<8 | uint(b) |
| 57 | cbits += 8 |
| 58 | sbits += 8 |
| 59 | for cbits >= 8 { |
| 60 | idx := byte(cur >> (cbits - 8)) |
| 61 | n = n.children[idx] |
| 62 | if n == nil { |
| 63 | return ErrInvalidHuffman |
| 64 | } |
| 65 | if n.children == nil { |
| 66 | if maxLen != 0 && buf.Len() == maxLen { |
| 67 | return ErrStringLength |
| 68 | } |
| 69 | buf.WriteByte(n.sym) |
| 70 | cbits -= n.codeLen |
| 71 | n = rootHuffmanNode |
| 72 | sbits = cbits |
| 73 | } else { |
| 74 | cbits -= 8 |
| 75 | } |
| 76 | } |
| 77 | } |
| 78 | for cbits > 0 { |
| 79 | n = n.children[byte(cur<<(8-cbits))] |
| 80 | if n == nil { |
| 81 | return ErrInvalidHuffman |
| 82 | } |
| 83 | if n.children != nil || n.codeLen > cbits { |
| 84 | break |
| 85 | } |
| 86 | if maxLen != 0 && buf.Len() == maxLen { |
| 87 | return ErrStringLength |
| 88 | } |
| 89 | buf.WriteByte(n.sym) |
| 90 | cbits -= n.codeLen |
| 91 | n = rootHuffmanNode |
| 92 | sbits = cbits |
| 93 | } |
| 94 | if sbits > 7 { |
| 95 | // Either there was an incomplete symbol, or overlong padding. |
| 96 | // Both are decoding errors per RFC 7541 section 5.2. |
| 97 | return ErrInvalidHuffman |
| 98 | } |
| 99 | if mask := uint(1<<cbits - 1); cur&mask != mask { |
| 100 | // Trailing bits must be a prefix of EOS per RFC 7541 section 5.2. |
| 101 | return ErrInvalidHuffman |
| 102 | } |
| 103 | |
| 104 | return nil |
| 105 | } |
| 106 | |
// node is a vertex of the Huffman decoding tree. A node is either an
// internal node (children set) or a leaf (children nil).
type node struct {
	// children is non-nil exactly for internal nodes.
	children []*node

	// Leaf data; only meaningful when children is nil.
	sym     byte  // output symbol
	codeLen uint8 // number of bits that led to the output of sym
}
| 115 | |
| 116 | func newInternalNode() *node { |
| 117 | return &node{children: make([]*node, 256)} |
| 118 | } |
| 119 | |
| 120 | var rootHuffmanNode = newInternalNode() |
| 121 | |
| 122 | func init() { |
| 123 | if len(huffmanCodes) != 256 { |
| 124 | panic("unexpected size") |
| 125 | } |
| 126 | for i, code := range huffmanCodes { |
| 127 | addDecoderNode(byte(i), code, huffmanCodeLen[i]) |
| 128 | } |
| 129 | } |
| 130 | |
| 131 | func addDecoderNode(sym byte, code uint32, codeLen uint8) { |
| 132 | cur := rootHuffmanNode |
| 133 | for codeLen > 8 { |
| 134 | codeLen -= 8 |
| 135 | i := uint8(code >> codeLen) |
| 136 | if cur.children[i] == nil { |
| 137 | cur.children[i] = newInternalNode() |
| 138 | } |
| 139 | cur = cur.children[i] |
| 140 | } |
| 141 | shift := 8 - codeLen |
| 142 | start, end := int(uint8(code<<shift)), int(1<<shift) |
| 143 | for i := start; i < start+end; i++ { |
| 144 | cur.children[i] = &node{sym: sym, codeLen: codeLen} |
| 145 | } |
| 146 | } |
| 147 | |
| 148 | // AppendHuffmanString appends s, as encoded in Huffman codes, to dst |
| 149 | // and returns the extended buffer. |
| 150 | func AppendHuffmanString(dst []byte, s string) []byte { |
| 151 | rembits := uint8(8) |
| 152 | |
| 153 | for i := 0; i < len(s); i++ { |
| 154 | if rembits == 8 { |
| 155 | dst = append(dst, 0) |
| 156 | } |
| 157 | dst, rembits = appendByteToHuffmanCode(dst, rembits, s[i]) |
| 158 | } |
| 159 | |
| 160 | if rembits < 8 { |
| 161 | // special EOS symbol |
| 162 | code := uint32(0x3fffffff) |
| 163 | nbits := uint8(30) |
| 164 | |
| 165 | t := uint8(code >> (nbits - rembits)) |
| 166 | dst[len(dst)-1] |= t |
| 167 | } |
| 168 | |
| 169 | return dst |
| 170 | } |
| 171 | |
| 172 | // HuffmanEncodeLength returns the number of bytes required to encode |
| 173 | // s in Huffman codes. The result is round up to byte boundary. |
| 174 | func HuffmanEncodeLength(s string) uint64 { |
| 175 | n := uint64(0) |
| 176 | for i := 0; i < len(s); i++ { |
| 177 | n += uint64(huffmanCodeLen[s[i]]) |
| 178 | } |
| 179 | return (n + 7) / 8 |
| 180 | } |
| 181 | |
| 182 | // appendByteToHuffmanCode appends Huffman code for c to dst and |
| 183 | // returns the extended buffer and the remaining bits in the last |
| 184 | // element. The appending is not byte aligned and the remaining bits |
| 185 | // in the last element of dst is given in rembits. |
| 186 | func appendByteToHuffmanCode(dst []byte, rembits uint8, c byte) ([]byte, uint8) { |
| 187 | code := huffmanCodes[c] |
| 188 | nbits := huffmanCodeLen[c] |
| 189 | |
| 190 | for { |
| 191 | if rembits > nbits { |
| 192 | t := uint8(code << (rembits - nbits)) |
| 193 | dst[len(dst)-1] |= t |
| 194 | rembits -= nbits |
| 195 | break |
| 196 | } |
| 197 | |
| 198 | t := uint8(code >> (nbits - rembits)) |
| 199 | dst[len(dst)-1] |= t |
| 200 | |
| 201 | nbits -= rembits |
| 202 | rembits = 8 |
| 203 | |
| 204 | if nbits == 0 { |
| 205 | break |
| 206 | } |
| 207 | |
| 208 | dst = append(dst, 0) |
| 209 | } |
| 210 | |
| 211 | return dst, rembits |
| 212 | } |