blob: 690d341111b5f4e6558f23de4d450f7bc20b1c02 [file] [log] [blame]
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19
20package thrift
21
22import (
23 "bytes"
24 "encoding/binary"
25 "errors"
26 "fmt"
27 "io"
28 "math"
29)
30
31type TBinaryProtocol struct {
32 trans TRichTransport
33 origTransport TTransport
34 reader io.Reader
35 writer io.Writer
36 strictRead bool
37 strictWrite bool
38 buffer [64]byte
39}
40
// TBinaryProtocolFactory creates TBinaryProtocol instances that all share the
// same strictness settings.
type TBinaryProtocolFactory struct {
	strictRead  bool // passed as strictRead to every protocol GetProtocol creates
	strictWrite bool // passed as strictWrite to every protocol GetProtocol creates
}
45
46func NewTBinaryProtocolTransport(t TTransport) *TBinaryProtocol {
47 return NewTBinaryProtocol(t, false, true)
48}
49
50func NewTBinaryProtocol(t TTransport, strictRead, strictWrite bool) *TBinaryProtocol {
51 p := &TBinaryProtocol{origTransport: t, strictRead: strictRead, strictWrite: strictWrite}
52 if et, ok := t.(TRichTransport); ok {
53 p.trans = et
54 } else {
55 p.trans = NewTRichTransport(t)
56 }
57 p.reader = p.trans
58 p.writer = p.trans
59 return p
60}
61
62func NewTBinaryProtocolFactoryDefault() *TBinaryProtocolFactory {
63 return NewTBinaryProtocolFactory(false, true)
64}
65
66func NewTBinaryProtocolFactory(strictRead, strictWrite bool) *TBinaryProtocolFactory {
67 return &TBinaryProtocolFactory{strictRead: strictRead, strictWrite: strictWrite}
68}
69
70func (p *TBinaryProtocolFactory) GetProtocol(t TTransport) TProtocol {
71 return NewTBinaryProtocol(t, p.strictRead, p.strictWrite)
72}
73
/**
 * Writing Methods
 */
77
78func (p *TBinaryProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error {
79 if p.strictWrite {
80 version := uint32(VERSION_1) | uint32(typeId)
81 e := p.WriteI32(int32(version))
82 if e != nil {
83 return e
84 }
85 e = p.WriteString(name)
86 if e != nil {
87 return e
88 }
89 e = p.WriteI32(seqId)
90 return e
91 } else {
92 e := p.WriteString(name)
93 if e != nil {
94 return e
95 }
96 e = p.WriteByte(int8(typeId))
97 if e != nil {
98 return e
99 }
100 e = p.WriteI32(seqId)
101 return e
102 }
103 return nil
104}
105
106func (p *TBinaryProtocol) WriteMessageEnd() error {
107 return nil
108}
109
110func (p *TBinaryProtocol) WriteStructBegin(name string) error {
111 return nil
112}
113
114func (p *TBinaryProtocol) WriteStructEnd() error {
115 return nil
116}
117
118func (p *TBinaryProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
119 e := p.WriteByte(int8(typeId))
120 if e != nil {
121 return e
122 }
123 e = p.WriteI16(id)
124 return e
125}
126
127func (p *TBinaryProtocol) WriteFieldEnd() error {
128 return nil
129}
130
131func (p *TBinaryProtocol) WriteFieldStop() error {
132 e := p.WriteByte(STOP)
133 return e
134}
135
136func (p *TBinaryProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
137 e := p.WriteByte(int8(keyType))
138 if e != nil {
139 return e
140 }
141 e = p.WriteByte(int8(valueType))
142 if e != nil {
143 return e
144 }
145 e = p.WriteI32(int32(size))
146 return e
147}
148
149func (p *TBinaryProtocol) WriteMapEnd() error {
150 return nil
151}
152
153func (p *TBinaryProtocol) WriteListBegin(elemType TType, size int) error {
154 e := p.WriteByte(int8(elemType))
155 if e != nil {
156 return e
157 }
158 e = p.WriteI32(int32(size))
159 return e
160}
161
162func (p *TBinaryProtocol) WriteListEnd() error {
163 return nil
164}
165
166func (p *TBinaryProtocol) WriteSetBegin(elemType TType, size int) error {
167 e := p.WriteByte(int8(elemType))
168 if e != nil {
169 return e
170 }
171 e = p.WriteI32(int32(size))
172 return e
173}
174
175func (p *TBinaryProtocol) WriteSetEnd() error {
176 return nil
177}
178
179func (p *TBinaryProtocol) WriteBool(value bool) error {
180 if value {
181 return p.WriteByte(1)
182 }
183 return p.WriteByte(0)
184}
185
186func (p *TBinaryProtocol) WriteByte(value int8) error {
187 e := p.trans.WriteByte(byte(value))
188 return NewTProtocolException(e)
189}
190
191func (p *TBinaryProtocol) WriteI16(value int16) error {
192 v := p.buffer[0:2]
193 binary.BigEndian.PutUint16(v, uint16(value))
194 _, e := p.writer.Write(v)
195 return NewTProtocolException(e)
196}
197
198func (p *TBinaryProtocol) WriteI32(value int32) error {
199 v := p.buffer[0:4]
200 binary.BigEndian.PutUint32(v, uint32(value))
201 _, e := p.writer.Write(v)
202 return NewTProtocolException(e)
203}
204
205func (p *TBinaryProtocol) WriteI64(value int64) error {
206 v := p.buffer[0:8]
207 binary.BigEndian.PutUint64(v, uint64(value))
208 _, err := p.writer.Write(v)
209 return NewTProtocolException(err)
210}
211
212func (p *TBinaryProtocol) WriteDouble(value float64) error {
213 return p.WriteI64(int64(math.Float64bits(value)))
214}
215
216func (p *TBinaryProtocol) WriteString(value string) error {
217 e := p.WriteI32(int32(len(value)))
218 if e != nil {
219 return e
220 }
221 _, err := p.trans.WriteString(value)
222 return NewTProtocolException(err)
223}
224
225func (p *TBinaryProtocol) WriteBinary(value []byte) error {
226 e := p.WriteI32(int32(len(value)))
227 if e != nil {
228 return e
229 }
230 _, err := p.writer.Write(value)
231 return NewTProtocolException(err)
232}
233
/**
 * Reading Methods
 */
237
238func (p *TBinaryProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
239 size, e := p.ReadI32()
240 if e != nil {
241 return "", typeId, 0, NewTProtocolException(e)
242 }
243 if size < 0 {
244 typeId = TMessageType(size & 0x0ff)
245 version := int64(int64(size) & VERSION_MASK)
246 if version != VERSION_1 {
247 return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Bad version in ReadMessageBegin"))
248 }
249 name, e = p.ReadString()
250 if e != nil {
251 return name, typeId, seqId, NewTProtocolException(e)
252 }
253 seqId, e = p.ReadI32()
254 if e != nil {
255 return name, typeId, seqId, NewTProtocolException(e)
256 }
257 return name, typeId, seqId, nil
258 }
259 if p.strictRead {
260 return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Missing version in ReadMessageBegin"))
261 }
262 name, e2 := p.readStringBody(size)
263 if e2 != nil {
264 return name, typeId, seqId, e2
265 }
266 b, e3 := p.ReadByte()
267 if e3 != nil {
268 return name, typeId, seqId, e3
269 }
270 typeId = TMessageType(b)
271 seqId, e4 := p.ReadI32()
272 if e4 != nil {
273 return name, typeId, seqId, e4
274 }
275 return name, typeId, seqId, nil
276}
277
278func (p *TBinaryProtocol) ReadMessageEnd() error {
279 return nil
280}
281
282func (p *TBinaryProtocol) ReadStructBegin() (name string, err error) {
283 return
284}
285
286func (p *TBinaryProtocol) ReadStructEnd() error {
287 return nil
288}
289
290func (p *TBinaryProtocol) ReadFieldBegin() (name string, typeId TType, seqId int16, err error) {
291 t, err := p.ReadByte()
292 typeId = TType(t)
293 if err != nil {
294 return name, typeId, seqId, err
295 }
296 if t != STOP {
297 seqId, err = p.ReadI16()
298 }
299 return name, typeId, seqId, err
300}
301
302func (p *TBinaryProtocol) ReadFieldEnd() error {
303 return nil
304}
305
306var invalidDataLength = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Invalid data length"))
307
308func (p *TBinaryProtocol) ReadMapBegin() (kType, vType TType, size int, err error) {
309 k, e := p.ReadByte()
310 if e != nil {
311 err = NewTProtocolException(e)
312 return
313 }
314 kType = TType(k)
315 v, e := p.ReadByte()
316 if e != nil {
317 err = NewTProtocolException(e)
318 return
319 }
320 vType = TType(v)
321 size32, e := p.ReadI32()
322 if e != nil {
323 err = NewTProtocolException(e)
324 return
325 }
326 if size32 < 0 {
327 err = invalidDataLength
328 return
329 }
330 size = int(size32)
331 return kType, vType, size, nil
332}
333
334func (p *TBinaryProtocol) ReadMapEnd() error {
335 return nil
336}
337
338func (p *TBinaryProtocol) ReadListBegin() (elemType TType, size int, err error) {
339 b, e := p.ReadByte()
340 if e != nil {
341 err = NewTProtocolException(e)
342 return
343 }
344 elemType = TType(b)
345 size32, e := p.ReadI32()
346 if e != nil {
347 err = NewTProtocolException(e)
348 return
349 }
350 if size32 < 0 {
351 err = invalidDataLength
352 return
353 }
354 size = int(size32)
355
356 return
357}
358
359func (p *TBinaryProtocol) ReadListEnd() error {
360 return nil
361}
362
363func (p *TBinaryProtocol) ReadSetBegin() (elemType TType, size int, err error) {
364 b, e := p.ReadByte()
365 if e != nil {
366 err = NewTProtocolException(e)
367 return
368 }
369 elemType = TType(b)
370 size32, e := p.ReadI32()
371 if e != nil {
372 err = NewTProtocolException(e)
373 return
374 }
375 if size32 < 0 {
376 err = invalidDataLength
377 return
378 }
379 size = int(size32)
380 return elemType, size, nil
381}
382
383func (p *TBinaryProtocol) ReadSetEnd() error {
384 return nil
385}
386
387func (p *TBinaryProtocol) ReadBool() (bool, error) {
388 b, e := p.ReadByte()
389 v := true
390 if b != 1 {
391 v = false
392 }
393 return v, e
394}
395
396func (p *TBinaryProtocol) ReadByte() (int8, error) {
397 v, err := p.trans.ReadByte()
398 return int8(v), err
399}
400
401func (p *TBinaryProtocol) ReadI16() (value int16, err error) {
402 buf := p.buffer[0:2]
403 err = p.readAll(buf)
404 value = int16(binary.BigEndian.Uint16(buf))
405 return value, err
406}
407
408func (p *TBinaryProtocol) ReadI32() (value int32, err error) {
409 buf := p.buffer[0:4]
410 err = p.readAll(buf)
411 value = int32(binary.BigEndian.Uint32(buf))
412 return value, err
413}
414
415func (p *TBinaryProtocol) ReadI64() (value int64, err error) {
416 buf := p.buffer[0:8]
417 err = p.readAll(buf)
418 value = int64(binary.BigEndian.Uint64(buf))
419 return value, err
420}
421
422func (p *TBinaryProtocol) ReadDouble() (value float64, err error) {
423 buf := p.buffer[0:8]
424 err = p.readAll(buf)
425 value = math.Float64frombits(binary.BigEndian.Uint64(buf))
426 return value, err
427}
428
429func (p *TBinaryProtocol) ReadString() (value string, err error) {
430 size, e := p.ReadI32()
431 if e != nil {
432 return "", e
433 }
434 if size < 0 {
435 err = invalidDataLength
436 return
437 }
438
439 return p.readStringBody(size)
440}
441
442func (p *TBinaryProtocol) ReadBinary() ([]byte, error) {
443 size, e := p.ReadI32()
444 if e != nil {
445 return nil, e
446 }
447 if size < 0 {
448 return nil, invalidDataLength
449 }
450 if uint64(size) > p.trans.RemainingBytes() {
451 return nil, invalidDataLength
452 }
453
454 isize := int(size)
455 buf := make([]byte, isize)
456 _, err := io.ReadFull(p.trans, buf)
457 return buf, NewTProtocolException(err)
458}
459
460func (p *TBinaryProtocol) Flush() (err error) {
461 return NewTProtocolException(p.trans.Flush())
462}
463
464func (p *TBinaryProtocol) Skip(fieldType TType) (err error) {
465 return SkipDefaultDepth(p, fieldType)
466}
467
468func (p *TBinaryProtocol) Transport() TTransport {
469 return p.origTransport
470}
471
472func (p *TBinaryProtocol) readAll(buf []byte) error {
473 _, err := io.ReadFull(p.reader, buf)
474 return NewTProtocolException(err)
475}
476
477const readLimit = 32768
478
479func (p *TBinaryProtocol) readStringBody(size int32) (value string, err error) {
480 if size < 0 {
481 return "", nil
482 }
483 if uint64(size) > p.trans.RemainingBytes() {
484 return "", invalidDataLength
485 }
486
487 var (
488 buf bytes.Buffer
489 e error
490 b []byte
491 )
492
493 switch {
494 case int(size) <= len(p.buffer):
495 b = p.buffer[:size] // avoids allocation for small reads
496 case int(size) < readLimit:
497 b = make([]byte, size)
498 default:
499 b = make([]byte, readLimit)
500 }
501
502 for size > 0 {
503 _, e = io.ReadFull(p.trans, b)
504 buf.Write(b)
505 if e != nil {
506 break
507 }
508 size -= readLimit
509 if size < readLimit && size > 0 {
510 b = b[:size]
511 }
512 }
513 return buf.String(), NewTProtocolException(e)
514}