blob: e6e8f286e1294928404cbd7bcd06bd06cab59308 [file] [log] [blame]
khenaidooab1f7bd2019-11-14 14:00:27 -05001package runtime
2
3import (
khenaidood948f772021-08-11 17:49:24 -04004 "context"
khenaidooab1f7bd2019-11-14 14:00:27 -05005 "errors"
6 "fmt"
7 "io"
8 "net/http"
9 "net/textproto"
10
khenaidooab1f7bd2019-11-14 14:00:27 -050011 "github.com/golang/protobuf/proto"
12 "github.com/grpc-ecosystem/grpc-gateway/internal"
13 "google.golang.org/grpc/grpclog"
14)
15
var (
	// errEmptyResponse is the error turned into a stream error chunk when a
	// streaming recv function yields a nil message together with a nil error.
	errEmptyResponse = errors.New("empty response")
)
17
18// ForwardResponseStream forwards the stream from gRPC server to REST client.
19func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
20 f, ok := w.(http.Flusher)
21 if !ok {
22 grpclog.Infof("Flush not supported in %T", w)
23 http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
24 return
25 }
26
27 md, ok := ServerMetadataFromContext(ctx)
28 if !ok {
29 grpclog.Infof("Failed to extract ServerMetadata from context")
30 http.Error(w, "unexpected error", http.StatusInternalServerError)
31 return
32 }
33 handleForwardResponseServerMetadata(w, mux, md)
34
35 w.Header().Set("Transfer-Encoding", "chunked")
36 w.Header().Set("Content-Type", marshaler.ContentType())
37 if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
38 HTTPError(ctx, mux, marshaler, w, req, err)
39 return
40 }
41
42 var delimiter []byte
43 if d, ok := marshaler.(Delimited); ok {
44 delimiter = d.Delimiter()
45 } else {
46 delimiter = []byte("\n")
47 }
48
49 var wroteHeader bool
50 for {
51 resp, err := recv()
52 if err == io.EOF {
53 return
54 }
55 if err != nil {
56 handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
57 return
58 }
59 if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
60 handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
61 return
62 }
63
khenaidood948f772021-08-11 17:49:24 -040064 var buf []byte
65 switch {
66 case resp == nil:
67 buf, err = marshaler.Marshal(errorChunk(streamError(ctx, mux.streamErrorHandler, errEmptyResponse)))
68 default:
69 result := map[string]interface{}{"result": resp}
70 if rb, ok := resp.(responseBody); ok {
71 result["result"] = rb.XXX_ResponseBody()
72 }
73
74 buf, err = marshaler.Marshal(result)
75 }
76
khenaidooab1f7bd2019-11-14 14:00:27 -050077 if err != nil {
78 grpclog.Infof("Failed to marshal response chunk: %v", err)
79 handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
80 return
81 }
82 if _, err = w.Write(buf); err != nil {
83 grpclog.Infof("Failed to send response chunk: %v", err)
84 return
85 }
86 wroteHeader = true
87 if _, err = w.Write(delimiter); err != nil {
88 grpclog.Infof("Failed to send delimiter chunk: %v", err)
89 return
90 }
91 f.Flush()
92 }
93}
94
95func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
96 for k, vs := range md.HeaderMD {
97 if h, ok := mux.outgoingHeaderMatcher(k); ok {
98 for _, v := range vs {
99 w.Header().Add(h, v)
100 }
101 }
102 }
103}
104
105func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) {
106 for k := range md.TrailerMD {
107 tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k))
108 w.Header().Add("Trailer", tKey)
109 }
110}
111
112func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) {
113 for k, vs := range md.TrailerMD {
114 tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)
115 for _, v := range vs {
116 w.Header().Add(tKey, v)
117 }
118 }
119}
120
// responseBody is implemented by generated response messages whose
// `google.api.HttpRule` sets `response_body`.
//
// When a response implements it, the value it returns is marshaled as the HTTP
// response body instead of the whole message.
type responseBody interface {
	// XXX_ResponseBody returns the field selected by `response_body`.
	XXX_ResponseBody() interface{}
}
126
127// ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
128func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
129 md, ok := ServerMetadataFromContext(ctx)
130 if !ok {
131 grpclog.Infof("Failed to extract ServerMetadata from context")
132 }
133
134 handleForwardResponseServerMetadata(w, mux, md)
135 handleForwardResponseTrailerHeader(w, md)
136
137 contentType := marshaler.ContentType()
khenaidood948f772021-08-11 17:49:24 -0400138 // Check marshaler on run time in order to keep backwards compatibility
khenaidooab1f7bd2019-11-14 14:00:27 -0500139 // An interface param needs to be added to the ContentType() function on
140 // the Marshal interface to be able to remove this check
khenaidood948f772021-08-11 17:49:24 -0400141 if typeMarshaler, ok := marshaler.(contentTypeMarshaler); ok {
142 contentType = typeMarshaler.ContentTypeFromMessage(resp)
khenaidooab1f7bd2019-11-14 14:00:27 -0500143 }
144 w.Header().Set("Content-Type", contentType)
145
146 if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
147 HTTPError(ctx, mux, marshaler, w, req, err)
148 return
149 }
150 var buf []byte
151 var err error
152 if rb, ok := resp.(responseBody); ok {
153 buf, err = marshaler.Marshal(rb.XXX_ResponseBody())
154 } else {
155 buf, err = marshaler.Marshal(resp)
156 }
157 if err != nil {
158 grpclog.Infof("Marshal error: %v", err)
159 HTTPError(ctx, mux, marshaler, w, req, err)
160 return
161 }
162
163 if _, err = w.Write(buf); err != nil {
164 grpclog.Infof("Failed to write response: %v", err)
165 }
166
167 handleForwardResponseTrailer(w, md)
168}
169
170func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
171 if len(opts) == 0 {
172 return nil
173 }
174 for _, opt := range opts {
175 if err := opt(ctx, w, resp); err != nil {
176 grpclog.Infof("Error handling ForwardResponseOptions: %v", err)
177 return err
178 }
179 }
180 return nil
181}
182
183func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error) {
184 serr := streamError(ctx, mux.streamErrorHandler, err)
185 if !wroteHeader {
186 w.WriteHeader(int(serr.HttpCode))
187 }
188 buf, merr := marshaler.Marshal(errorChunk(serr))
189 if merr != nil {
190 grpclog.Infof("Failed to marshal an error: %v", merr)
191 return
192 }
193 if _, werr := w.Write(buf); werr != nil {
194 grpclog.Infof("Failed to notify error to client: %v", werr)
195 return
196 }
197}
198
khenaidooab1f7bd2019-11-14 14:00:27 -0500199// streamError returns the payload for the final message in a response stream
200// that represents the given err.
201func streamError(ctx context.Context, errHandler StreamErrorHandlerFunc, err error) *StreamError {
202 serr := errHandler(ctx, err)
203 if serr != nil {
204 return serr
205 }
206 // TODO: log about misbehaving stream error handler?
207 return DefaultHTTPStreamErrorHandler(ctx, err)
208}
209
210func errorChunk(err *StreamError) map[string]proto.Message {
211 return map[string]proto.Message{"error": (*internal.StreamError)(err)}
212}