blob: 813ffb1e84336da415256244366e09bcd65b6765 [file] [log] [blame]
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package websocket
6
7import (
8 "compress/flate"
9 "errors"
10 "io"
11 "strings"
12 "sync"
13)
14
// Compression levels accepted for permessage-deflate. The range matches the
// levels compress/flate understands, from Huffman-only (-2) up to
// flate.BestCompression.
const (
	// flate.HuffmanOnly (-2) is written as a literal so the package still
	// builds with Go releases that predate that constant.
	minCompressionLevel     = -2
	maxCompressionLevel     = flate.BestCompression
	defaultCompressionLevel = flate.BestSpeed
)

var (
	// flateWriterPools holds one pool of *flate.Writer per compression level,
	// indexed by level-minCompressionLevel. The pools have no New function,
	// so Get returns nil until a writer has been Put back.
	flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool

	// flateReaderPool recycles decompressors. Each one is created with a nil
	// source and must be Reset onto real input before use.
	flateReaderPool = sync.Pool{
		New: func() interface{} { return flate.NewReader(nil) },
	}
)
27
28func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
29 const tail =
30 // Add four bytes as specified in RFC
31 "\x00\x00\xff\xff" +
32 // Add final block to squelch unexpected EOF error from flate reader.
33 "\x01\x00\x00\xff\xff"
34
35 fr, _ := flateReaderPool.Get().(io.ReadCloser)
36 fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
37 return &flateReadWrapper{fr}
38}
39
40func isValidCompressionLevel(level int) bool {
41 return minCompressionLevel <= level && level <= maxCompressionLevel
42}
43
44func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
45 p := &flateWriterPools[level-minCompressionLevel]
46 tw := &truncWriter{w: w}
47 fw, _ := p.Get().(*flate.Writer)
48 if fw == nil {
49 fw, _ = flate.NewWriter(tw, level)
50 } else {
51 fw.Reset(tw)
52 }
53 return &flateWriteWrapper{fw: fw, tw: tw, p: p}
54}
55
// truncWriter is an io.Writer that writes all but the last four bytes of the
// stream to another io.Writer. The withheld bytes stay available in p, so the
// owner can inspect them once the stream is complete.
type truncWriter struct {
	w io.WriteCloser // destination for everything except the trailing bytes
	n int            // number of bytes currently held in p (at most len(p))
	p [4]byte        // most recent bytes of the stream, not yet forwarded to w
}

// Write appends p to the stream. A byte is forwarded to w only after four
// later bytes have arrived, so the last four bytes seen are always held in
// w.p. On success Write reports len(p), as io.Writer requires, even though
// up to four of those bytes are still buffered. On error it returns fewer
// bytes than len(p).
func (w *truncWriter) Write(p []byte) (int, error) {
	total := len(p)
	if total == 0 {
		// Nothing to buffer or forward; avoid zero-length writes to w.
		return 0, nil
	}

	n := 0

	// Fill the holdback buffer first for simplicity.
	if w.n < len(w.p) {
		n = copy(w.p[w.n:], p)
		p = p[n:]
		w.n += n
		if len(p) == 0 {
			return n, nil
		}
	}

	// The buffer is full. Its oldest m bytes precede all of p in the stream,
	// so they can no longer be among the final four bytes. Forward them.
	m := len(p)
	if m > len(w.p) {
		m = len(w.p)
	}

	if _, err := w.w.Write(w.p[:m]); err != nil {
		// Only the n bytes copied into the buffer above were consumed from p.
		return n, err
	}

	// Slide the remaining buffered bytes down and retain the last m bytes of p.
	copy(w.p[:], w.p[m:])
	copy(w.p[len(w.p)-m:], p[len(p)-m:])

	nn, err := w.w.Write(p[:len(p)-m])
	if err != nil {
		return n + nn, err
	}
	return total, nil
}
91
92type flateWriteWrapper struct {
93 fw *flate.Writer
94 tw *truncWriter
95 p *sync.Pool
96}
97
98func (w *flateWriteWrapper) Write(p []byte) (int, error) {
99 if w.fw == nil {
100 return 0, errWriteClosed
101 }
102 return w.fw.Write(p)
103}
104
105func (w *flateWriteWrapper) Close() error {
106 if w.fw == nil {
107 return errWriteClosed
108 }
109 err1 := w.fw.Flush()
110 w.p.Put(w.fw)
111 w.fw = nil
112 if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
113 return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
114 }
115 err2 := w.tw.w.Close()
116 if err1 != nil {
117 return err1
118 }
119 return err2
120}
121
122type flateReadWrapper struct {
123 fr io.ReadCloser
124}
125
126func (r *flateReadWrapper) Read(p []byte) (int, error) {
127 if r.fr == nil {
128 return 0, io.ErrClosedPipe
129 }
130 n, err := r.fr.Read(p)
131 if err == io.EOF {
132 // Preemptively place the reader back in the pool. This helps with
133 // scenarios where the application does not call NextReader() soon after
134 // this final read.
135 r.Close()
136 }
137 return n, err
138}
139
140func (r *flateReadWrapper) Close() error {
141 if r.fr == nil {
142 return io.ErrClosedPipe
143 }
144 err := r.fr.Close()
145 flateReaderPool.Put(r.fr)
146 r.fr = nil
147 return err
148}