blob: 1eaedfa004a7ad74c72d25d4602466a77bc46749 [file] [log] [blame]
// Package grpcdynamic provides a dynamic RPC stub. It can be used to invoke RPC
// methods for which only method descriptors are known. The actual request and
// response messages may be dynamic messages.
4package grpcdynamic
5
6import (
7 "fmt"
8 "io"
9
10 "github.com/golang/protobuf/proto"
11 "golang.org/x/net/context"
12 "google.golang.org/grpc"
13 "google.golang.org/grpc/metadata"
14
15 "github.com/jhump/protoreflect/desc"
16 "github.com/jhump/protoreflect/dynamic"
17)
18
19// Stub is an RPC client stub, used for dynamically dispatching RPCs to a server.
20type Stub struct {
21 channel Channel
22 mf *dynamic.MessageFactory
23}
24
25// Channel represents the operations necessary to issue RPCs via gRPC. The
26// *grpc.ClientConn type provides this interface and will typically the concrete
27// type used to construct Stubs. But the use of this interface allows
28// construction of stubs that use alternate concrete types as the transport for
29// RPC operations.
30type Channel interface {
31 Invoke(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error
32 NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error)
33}
34
35var _ Channel = (*grpc.ClientConn)(nil)
36
37// NewStub creates a new RPC stub that uses the given channel for dispatching RPCs.
38func NewStub(channel Channel) Stub {
39 return NewStubWithMessageFactory(channel, nil)
40}
41
42// NewStubWithMessageFactory creates a new RPC stub that uses the given channel for
43// dispatching RPCs and the given MessageFactory for creating response messages.
44func NewStubWithMessageFactory(channel Channel, mf *dynamic.MessageFactory) Stub {
45 return Stub{channel: channel, mf: mf}
46}
47
48func requestMethod(md *desc.MethodDescriptor) string {
49 return fmt.Sprintf("/%s/%s", md.GetService().GetFullyQualifiedName(), md.GetName())
50}
51
// InvokeRpc sends a unary RPC and returns the response. Use this for unary methods.
//
// It returns an error without contacting the server if method is not unary or
// if request is not of the method's input type. The response message is created
// by the stub's message factory from the method's output type.
func (s Stub) InvokeRpc(ctx context.Context, method *desc.MethodDescriptor, request proto.Message, opts ...grpc.CallOption) (proto.Message, error) {
	unary := !method.IsClientStreaming() && !method.IsServerStreaming()
	if !unary {
		return nil, fmt.Errorf("InvokeRpc is for unary methods; %q is %s", method.GetFullyQualifiedName(), methodType(method))
	}
	if err := checkMessageType(method.GetInputType(), request); err != nil {
		return nil, err
	}
	reply := s.mf.NewMessage(method.GetOutputType())
	path := requestMethod(method)
	if err := s.channel.Invoke(ctx, path, request, reply, opts...); err != nil {
		return nil, err
	}
	return reply, nil
}
66
// InvokeRpcServerStream sends a single request message and returns the response
// stream. Use this for server-streaming methods.
//
// The request is sent and the send side of the stream is closed before this
// returns, so the caller only reads responses from the returned ServerStream.
// It returns an error without contacting the server if method is not
// server-streaming (bidi-streaming methods are rejected too) or if request is
// not of the method's input type.
//
// The stream runs under a cancelable context derived from ctx. That context is
// canceled on every error path. On success it is canceled once the stream's own
// context is done, or once ctx is done, whichever happens first. grpc-go client
// streams cancel their context when the RPC finishes. This keeps the derived
// context from being retained for the whole lifetime of ctx.
func (s Stub) InvokeRpcServerStream(ctx context.Context, method *desc.MethodDescriptor, request proto.Message, opts ...grpc.CallOption) (*ServerStream, error) {
	if method.IsClientStreaming() || !method.IsServerStreaming() {
		return nil, fmt.Errorf("InvokeRpcServerStream is for server-streaming methods; %q is %s", method.GetFullyQualifiedName(), methodType(method))
	}
	if err := checkMessageType(method.GetInputType(), request); err != nil {
		return nil, err
	}
	ctx, cancel := context.WithCancel(ctx)
	sd := grpc.StreamDesc{
		StreamName:    method.GetName(),
		ServerStreams: method.IsServerStreaming(),
		ClientStreams: method.IsClientStreaming(),
	}
	cs, err := s.channel.NewStream(ctx, &sd, requestMethod(method), opts...)
	if err != nil {
		// No stream will ever use the derived context; release it now.
		cancel()
		return nil, err
	}
	if err = cs.SendMsg(request); err != nil {
		cancel()
		return nil, err
	}
	if err = cs.CloseSend(); err != nil {
		cancel()
		return nil, err
	}
	// ServerStream has no field to hold cancel, so a goroutine releases the
	// derived context once the stream is finished. It also watches ctx, so it
	// exits when the caller's context ends even if a custom Channel's stream
	// context is never done.
	go func() {
		select {
		case <-cs.Context().Done():
		case <-ctx.Done():
		}
		cancel()
	}()
	return &ServerStream{cs, method.GetOutputType(), s.mf}, nil
}
97
// InvokeRpcClientStream creates a new stream that is used to send request messages and, at the end,
// receive the response message. Use this for client-streaming methods.
//
// It returns an error without contacting the server if method is not
// client-streaming (bidi-streaming methods are rejected too). The stream runs
// under a cancelable context derived from ctx. The returned ClientStream holds
// the cancel function for that context. If the stream cannot be created, the
// derived context is canceled before the error is returned.
func (s Stub) InvokeRpcClientStream(ctx context.Context, method *desc.MethodDescriptor, opts ...grpc.CallOption) (*ClientStream, error) {
	if !method.IsClientStreaming() || method.IsServerStreaming() {
		return nil, fmt.Errorf("InvokeRpcClientStream is for client-streaming methods; %q is %s", method.GetFullyQualifiedName(), methodType(method))
	}
	ctx, cancel := context.WithCancel(ctx)
	sd := grpc.StreamDesc{
		StreamName:    method.GetName(),
		ServerStreams: method.IsServerStreaming(),
		ClientStreams: method.IsClientStreaming(),
	}
	cs, err := s.channel.NewStream(ctx, &sd, requestMethod(method), opts...)
	if err != nil {
		// Nothing will ever use the derived context; release it instead of
		// leaking it until ctx is done.
		cancel()
		return nil, err
	}
	return &ClientStream{cs, method, s.mf, cancel}, nil
}
116
// InvokeRpcBidiStream creates a new stream that is used to both send request messages and receive response
// messages. Use this for bidi-streaming methods.
//
// The stream is opened directly with ctx. An error is returned without
// contacting the server if method is not bidi-streaming.
func (s Stub) InvokeRpcBidiStream(ctx context.Context, method *desc.MethodDescriptor, opts ...grpc.CallOption) (*BidiStream, error) {
	bidi := method.IsClientStreaming() && method.IsServerStreaming()
	if !bidi {
		return nil, fmt.Errorf("InvokeRpcBidiStream is for bidi-streaming methods; %q is %s", method.GetFullyQualifiedName(), methodType(method))
	}
	// Both directions stream; this is guaranteed by the check above.
	sd := grpc.StreamDesc{
		StreamName:    method.GetName(),
		ServerStreams: true,
		ClientStreams: true,
	}
	cs, err := s.channel.NewStream(ctx, &sd, requestMethod(method), opts...)
	if err != nil {
		return nil, err
	}
	return &BidiStream{
		stream:   cs,
		reqType:  method.GetInputType(),
		respType: method.GetOutputType(),
		mf:       s.mf,
	}, nil
}
134
135func methodType(md *desc.MethodDescriptor) string {
136 if md.IsClientStreaming() && md.IsServerStreaming() {
137 return "bidi-streaming"
138 } else if md.IsClientStreaming() {
139 return "client-streaming"
140 } else if md.IsServerStreaming() {
141 return "server-streaming"
142 } else {
143 return "unary"
144 }
145}
146
// checkMessageType verifies that msg is a message of the type described by md.
// If it is not, it returns an error naming both the expected and actual types.
//
// The type name of a *dynamic.Message comes from its message descriptor. For
// any other message, it comes from proto.MessageName. A nil msg, including a
// typed-nil *dynamic.Message, is reported as "<nil>". A typed-nil dynamic
// message would otherwise cause a nil-pointer panic when its descriptor is read.
func checkMessageType(md *desc.MessageDescriptor, msg proto.Message) error {
	var typeName string
	switch m := msg.(type) {
	case nil:
		typeName = "<nil>"
	case *dynamic.Message:
		if m == nil {
			typeName = "<nil>"
		} else {
			typeName = m.GetMessageDescriptor().GetFullyQualifiedName()
		}
	default:
		typeName = proto.MessageName(msg)
	}
	if typeName != md.GetFullyQualifiedName() {
		return fmt.Errorf("expecting message of type %s; got %s", md.GetFullyQualifiedName(), typeName)
	}
	return nil
}
159
160// ServerStream represents a response stream from a server. Messages in the stream can be queried
161// as can header and trailer metadata sent by the server.
162type ServerStream struct {
163 stream grpc.ClientStream
164 respType *desc.MessageDescriptor
165 mf *dynamic.MessageFactory
166}
167
168// Header returns any header metadata sent by the server (blocks if necessary until headers are
169// received).
170func (s *ServerStream) Header() (metadata.MD, error) {
171 return s.stream.Header()
172}
173
174// Trailer returns the trailer metadata sent by the server. It must only be called after
175// RecvMsg returns a non-nil error (which may be EOF for normal completion of stream).
176func (s *ServerStream) Trailer() metadata.MD {
177 return s.stream.Trailer()
178}
179
// Context returns the context under which this streaming operation runs.
func (s *ServerStream) Context() context.Context {
	streamCtx := s.stream.Context()
	return streamCtx
}
184
185// RecvMsg returns the next message in the response stream or an error. If the stream
186// has completed normally, the error is io.EOF. Otherwise, the error indicates the
187// nature of the abnormal termination of the stream.
188func (s *ServerStream) RecvMsg() (proto.Message, error) {
189 resp := s.mf.NewMessage(s.respType)
190 if err := s.stream.RecvMsg(resp); err != nil {
191 return nil, err
192 } else {
193 return resp, nil
194 }
195}
196
197// ClientStream represents a response stream from a client. Messages in the stream can be sent
198// and, when done, the unary server message and header and trailer metadata can be queried.
199type ClientStream struct {
200 stream grpc.ClientStream
201 method *desc.MethodDescriptor
202 mf *dynamic.MessageFactory
203 cancel context.CancelFunc
204}
205
206// Header returns any header metadata sent by the server (blocks if necessary until headers are
207// received).
208func (s *ClientStream) Header() (metadata.MD, error) {
209 return s.stream.Header()
210}
211
212// Trailer returns the trailer metadata sent by the server. It must only be called after
213// RecvMsg returns a non-nil error (which may be EOF for normal completion of stream).
214func (s *ClientStream) Trailer() metadata.MD {
215 return s.stream.Trailer()
216}
217
// Context returns the context under which this streaming operation runs.
func (s *ClientStream) Context() context.Context {
	streamCtx := s.stream.Context()
	return streamCtx
}
222
223// SendMsg sends a request message to the server.
224func (s *ClientStream) SendMsg(m proto.Message) error {
225 if err := checkMessageType(s.method.GetInputType(), m); err != nil {
226 return err
227 }
228 return s.stream.SendMsg(m)
229}
230
// CloseAndReceive closes the outgoing request stream and then blocks for the server's response.
//
// A client-streaming method must produce exactly one response. After reading
// it, CloseAndReceive reads once more and expects io.EOF. If the server sends a
// second message instead, an error is returned. The RPC is finished (or
// abandoned) once this method returns, so the context the stream was created
// with is canceled on every return path.
func (s *ClientStream) CloseAndReceive() (proto.Message, error) {
	// Previously the context was canceled only on the "more than one response"
	// path, which left it registered with its parent on every other exit.
	if s.cancel != nil {
		defer s.cancel()
	}
	if err := s.stream.CloseSend(); err != nil {
		return nil, err
	}
	resp := s.mf.NewMessage(s.method.GetOutputType())
	if err := s.stream.RecvMsg(resp); err != nil {
		return nil, err
	}
	// Make sure we get EOF for a second message.
	switch err := s.stream.RecvMsg(resp); err {
	case io.EOF:
		return resp, nil
	case nil:
		return nil, fmt.Errorf("client-streaming method %q returned more than one response message", s.method.GetFullyQualifiedName())
	default:
		return nil, err
	}
}
251
252// BidiStream represents a bi-directional stream for sending messages to and receiving
253// messages from a server. The header and trailer metadata sent by the server can also be
254// queried.
255type BidiStream struct {
256 stream grpc.ClientStream
257 reqType *desc.MessageDescriptor
258 respType *desc.MessageDescriptor
259 mf *dynamic.MessageFactory
260}
261
262// Header returns any header metadata sent by the server (blocks if necessary until headers are
263// received).
264func (s *BidiStream) Header() (metadata.MD, error) {
265 return s.stream.Header()
266}
267
268// Trailer returns the trailer metadata sent by the server. It must only be called after
269// RecvMsg returns a non-nil error (which may be EOF for normal completion of stream).
270func (s *BidiStream) Trailer() metadata.MD {
271 return s.stream.Trailer()
272}
273
// Context returns the context under which this streaming operation runs.
func (s *BidiStream) Context() context.Context {
	streamCtx := s.stream.Context()
	return streamCtx
}
278
279// SendMsg sends a request message to the server.
280func (s *BidiStream) SendMsg(m proto.Message) error {
281 if err := checkMessageType(s.reqType, m); err != nil {
282 return err
283 }
284 return s.stream.SendMsg(m)
285}
286
287// CloseSend indicates the request stream has ended. Invoke this after all request messages
288// are sent (even if there are zero such messages).
289func (s *BidiStream) CloseSend() error {
290 return s.stream.CloseSend()
291}
292
293// RecvMsg returns the next message in the response stream or an error. If the stream
294// has completed normally, the error is io.EOF. Otherwise, the error indicates the
295// nature of the abnormal termination of the stream.
296func (s *BidiStream) RecvMsg() (proto.Message, error) {
297 resp := s.mf.NewMessage(s.respType)
298 if err := s.stream.RecvMsg(resp); err != nil {
299 return nil, err
300 } else {
301 return resp, nil
302 }
303}