blob: ea2c01fdadbf12b846b298a460675e185b0dbc2b [file] [log] [blame]
khenaidoo26721882021-08-11 17:42:52 -04001package thrift
2
3import (
4 "context"
5 "fmt"
6)
7
8// ResponseMeta represents the metadata attached to the response.
9type ResponseMeta struct {
10 // The headers in the response, if any.
11 // If the underlying transport/protocol is not THeader, this will always be nil.
12 Headers THeaderMap
13}
14
15type TClient interface {
16 Call(ctx context.Context, method string, args, result TStruct) (ResponseMeta, error)
17}
18
19type TStandardClient struct {
20 seqId int32
21 iprot, oprot TProtocol
22}
23
24// TStandardClient implements TClient, and uses the standard message format for Thrift.
25// It is not safe for concurrent use.
26func NewTStandardClient(inputProtocol, outputProtocol TProtocol) *TStandardClient {
27 return &TStandardClient{
28 iprot: inputProtocol,
29 oprot: outputProtocol,
30 }
31}
32
33func (p *TStandardClient) Send(ctx context.Context, oprot TProtocol, seqId int32, method string, args TStruct) error {
34 // Set headers from context object on THeaderProtocol
35 if headerProt, ok := oprot.(*THeaderProtocol); ok {
36 headerProt.ClearWriteHeaders()
37 for _, key := range GetWriteHeaderList(ctx) {
38 if value, ok := GetHeader(ctx, key); ok {
39 headerProt.SetWriteHeader(key, value)
40 }
41 }
42 }
43
44 if err := oprot.WriteMessageBegin(ctx, method, CALL, seqId); err != nil {
45 return err
46 }
47 if err := args.Write(ctx, oprot); err != nil {
48 return err
49 }
50 if err := oprot.WriteMessageEnd(ctx); err != nil {
51 return err
52 }
53 return oprot.Flush(ctx)
54}
55
56func (p *TStandardClient) Recv(ctx context.Context, iprot TProtocol, seqId int32, method string, result TStruct) error {
57 rMethod, rTypeId, rSeqId, err := iprot.ReadMessageBegin(ctx)
58 if err != nil {
59 return err
60 }
61
62 if method != rMethod {
63 return NewTApplicationException(WRONG_METHOD_NAME, fmt.Sprintf("%s: wrong method name", method))
64 } else if seqId != rSeqId {
65 return NewTApplicationException(BAD_SEQUENCE_ID, fmt.Sprintf("%s: out of order sequence response", method))
66 } else if rTypeId == EXCEPTION {
67 var exception tApplicationException
68 if err := exception.Read(ctx, iprot); err != nil {
69 return err
70 }
71
72 if err := iprot.ReadMessageEnd(ctx); err != nil {
73 return err
74 }
75
76 return &exception
77 } else if rTypeId != REPLY {
78 return NewTApplicationException(INVALID_MESSAGE_TYPE_EXCEPTION, fmt.Sprintf("%s: invalid message type", method))
79 }
80
81 if err := result.Read(ctx, iprot); err != nil {
82 return err
83 }
84
85 return iprot.ReadMessageEnd(ctx)
86}
87
88func (p *TStandardClient) Call(ctx context.Context, method string, args, result TStruct) (ResponseMeta, error) {
89 p.seqId++
90 seqId := p.seqId
91
92 if err := p.Send(ctx, p.oprot, seqId, method, args); err != nil {
93 return ResponseMeta{}, err
94 }
95
96 // method is oneway
97 if result == nil {
98 return ResponseMeta{}, nil
99 }
100
101 err := p.Recv(ctx, p.iprot, seqId, method, result)
102 var headers THeaderMap
103 if hp, ok := p.iprot.(*THeaderProtocol); ok {
104 headers = hp.transport.readHeaders
105 }
106 return ResponseMeta{
107 Headers: headers,
108 }, err
109}