blob: 2af900650dcd28c538dad43d65dc9c2a98da2955 [file] [log] [blame]
khenaidoo59ce9dd2019-11-11 13:05:32 -05001package runtime
2
3import (
4 "errors"
5 "fmt"
6 "io"
7 "net/http"
8 "net/textproto"
9
10 "context"
11 "github.com/golang/protobuf/proto"
12 "github.com/grpc-ecosystem/grpc-gateway/internal"
13 "google.golang.org/grpc/grpclog"
14)
15
// errEmptyResponse is the error rendered into an error chunk when a response
// stream yields a nil message (see streamChunk).
var (
	errEmptyResponse = errors.New("empty response")
)
17
18// ForwardResponseStream forwards the stream from gRPC server to REST client.
19func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
20 f, ok := w.(http.Flusher)
21 if !ok {
22 grpclog.Infof("Flush not supported in %T", w)
23 http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
24 return
25 }
26
27 md, ok := ServerMetadataFromContext(ctx)
28 if !ok {
29 grpclog.Infof("Failed to extract ServerMetadata from context")
30 http.Error(w, "unexpected error", http.StatusInternalServerError)
31 return
32 }
33 handleForwardResponseServerMetadata(w, mux, md)
34
35 w.Header().Set("Transfer-Encoding", "chunked")
36 w.Header().Set("Content-Type", marshaler.ContentType())
37 if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
38 HTTPError(ctx, mux, marshaler, w, req, err)
39 return
40 }
41
42 var delimiter []byte
43 if d, ok := marshaler.(Delimited); ok {
44 delimiter = d.Delimiter()
45 } else {
46 delimiter = []byte("\n")
47 }
48
49 var wroteHeader bool
50 for {
51 resp, err := recv()
52 if err == io.EOF {
53 return
54 }
55 if err != nil {
56 handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
57 return
58 }
59 if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
60 handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
61 return
62 }
63
64 buf, err := marshaler.Marshal(streamChunk(ctx, resp, mux.streamErrorHandler))
65 if err != nil {
66 grpclog.Infof("Failed to marshal response chunk: %v", err)
67 handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
68 return
69 }
70 if _, err = w.Write(buf); err != nil {
71 grpclog.Infof("Failed to send response chunk: %v", err)
72 return
73 }
74 wroteHeader = true
75 if _, err = w.Write(delimiter); err != nil {
76 grpclog.Infof("Failed to send delimiter chunk: %v", err)
77 return
78 }
79 f.Flush()
80 }
81}
82
83func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
84 for k, vs := range md.HeaderMD {
85 if h, ok := mux.outgoingHeaderMatcher(k); ok {
86 for _, v := range vs {
87 w.Header().Add(h, v)
88 }
89 }
90 }
91}
92
93func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) {
94 for k := range md.TrailerMD {
95 tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k))
96 w.Header().Add("Trailer", tKey)
97 }
98}
99
100func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) {
101 for k, vs := range md.TrailerMD {
102 tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)
103 for _, v := range vs {
104 w.Header().Add(tKey, v)
105 }
106 }
107}
108
// responseBody is implemented by generated response messages whose
// google.api.HttpRule sets `response_body`. XXX_ResponseBody returns the value
// that ForwardResponseMessage marshals as the HTTP response body in place of
// the whole message.
type responseBody interface {
	XXX_ResponseBody() interface{}
}
114
115// ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
116func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
117 md, ok := ServerMetadataFromContext(ctx)
118 if !ok {
119 grpclog.Infof("Failed to extract ServerMetadata from context")
120 }
121
122 handleForwardResponseServerMetadata(w, mux, md)
123 handleForwardResponseTrailerHeader(w, md)
124
125 contentType := marshaler.ContentType()
126 // Check marshaler on run time in order to keep backwards compatability
127 // An interface param needs to be added to the ContentType() function on
128 // the Marshal interface to be able to remove this check
129 if httpBodyMarshaler, ok := marshaler.(*HTTPBodyMarshaler); ok {
130 contentType = httpBodyMarshaler.ContentTypeFromMessage(resp)
131 }
132 w.Header().Set("Content-Type", contentType)
133
134 if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
135 HTTPError(ctx, mux, marshaler, w, req, err)
136 return
137 }
138 var buf []byte
139 var err error
140 if rb, ok := resp.(responseBody); ok {
141 buf, err = marshaler.Marshal(rb.XXX_ResponseBody())
142 } else {
143 buf, err = marshaler.Marshal(resp)
144 }
145 if err != nil {
146 grpclog.Infof("Marshal error: %v", err)
147 HTTPError(ctx, mux, marshaler, w, req, err)
148 return
149 }
150
151 if _, err = w.Write(buf); err != nil {
152 grpclog.Infof("Failed to write response: %v", err)
153 }
154
155 handleForwardResponseTrailer(w, md)
156}
157
158func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
159 if len(opts) == 0 {
160 return nil
161 }
162 for _, opt := range opts {
163 if err := opt(ctx, w, resp); err != nil {
164 grpclog.Infof("Error handling ForwardResponseOptions: %v", err)
165 return err
166 }
167 }
168 return nil
169}
170
171func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error) {
172 serr := streamError(ctx, mux.streamErrorHandler, err)
173 if !wroteHeader {
174 w.WriteHeader(int(serr.HttpCode))
175 }
176 buf, merr := marshaler.Marshal(errorChunk(serr))
177 if merr != nil {
178 grpclog.Infof("Failed to marshal an error: %v", merr)
179 return
180 }
181 if _, werr := w.Write(buf); werr != nil {
182 grpclog.Infof("Failed to notify error to client: %v", werr)
183 return
184 }
185}
186
187// streamChunk returns a chunk in a response stream for the given result. The
188// given errHandler is used to render an error chunk if result is nil.
189func streamChunk(ctx context.Context, result proto.Message, errHandler StreamErrorHandlerFunc) map[string]proto.Message {
190 if result == nil {
191 return errorChunk(streamError(ctx, errHandler, errEmptyResponse))
192 }
193 return map[string]proto.Message{"result": result}
194}
195
196// streamError returns the payload for the final message in a response stream
197// that represents the given err.
198func streamError(ctx context.Context, errHandler StreamErrorHandlerFunc, err error) *StreamError {
199 serr := errHandler(ctx, err)
200 if serr != nil {
201 return serr
202 }
203 // TODO: log about misbehaving stream error handler?
204 return DefaultHTTPStreamErrorHandler(ctx, err)
205}
206
207func errorChunk(err *StreamError) map[string]proto.Message {
208 return map[string]proto.Message{"error": (*internal.StreamError)(err)}
209}