blob: 31409ac4c193285589687f90fec8f6d18ab1a212 [file] [log] [blame]
Matteo Scandoloa6a3aee2019-11-26 13:30:14 -07001package genswagger
2
3import (
4 "bytes"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "path/filepath"
9 "reflect"
10 "strings"
11
12 "github.com/golang/glog"
13 pbdescriptor "github.com/golang/protobuf/descriptor"
14 "github.com/golang/protobuf/proto"
15 protocdescriptor "github.com/golang/protobuf/protoc-gen-go/descriptor"
16 plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
17 "github.com/golang/protobuf/ptypes/any"
18 "github.com/grpc-ecosystem/grpc-gateway/internal"
19 "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
20 gen "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/generator"
21 swagger_options "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-swagger/options"
22)
23
// errNoTargetService is the sentinel applyTemplate reports for a file that
// (per its message) defines no target service. Generate treats it as "skip
// this file" rather than as a failure.
var errNoTargetService = errors.New("no target service defined in the file")
27
28type generator struct {
29 reg *descriptor.Registry
30}
31
32type wrapper struct {
33 fileName string
34 swagger *swaggerObject
35}
36
37// New returns a new generator which generates grpc gateway files.
38func New(reg *descriptor.Registry) gen.Generator {
39 return &generator{reg: reg}
40}
41
42// Merge a lot of swagger file (wrapper) to single one swagger file
43func mergeTargetFile(targets []*wrapper, mergeFileName string) *wrapper {
44 var mergedTarget *wrapper
45 for _, f := range targets {
46 if mergedTarget == nil {
47 mergedTarget = &wrapper{
48 fileName: mergeFileName,
49 swagger: f.swagger,
50 }
51 } else {
52 for k, v := range f.swagger.Definitions {
53 mergedTarget.swagger.Definitions[k] = v
54 }
55 for k, v := range f.swagger.StreamDefinitions {
56 mergedTarget.swagger.StreamDefinitions[k] = v
57 }
58 for k, v := range f.swagger.Paths {
59 mergedTarget.swagger.Paths[k] = v
60 }
61 for k, v := range f.swagger.SecurityDefinitions {
62 mergedTarget.swagger.SecurityDefinitions[k] = v
63 }
64 mergedTarget.swagger.Security = append(mergedTarget.swagger.Security, f.swagger.Security...)
65 }
66 }
67 return mergedTarget
68}
69
// fieldName converts an extension key such as "x-foo-bar" into the exported
// Go struct field name "X_Foo_Bar". strings.Title upper-cases the first letter
// of each word, with dashes acting as word breaks. Each dash, which is not
// valid in an identifier, then becomes an underscore.
func fieldName(k string) string {
	titled := strings.Title(k)
	return strings.Replace(titled, "-", "_", -1)
}
73
74// Q: What's up with the alias types here?
75// A: We don't want to completely override how these structs are marshaled into
76// JSON, we only want to add fields (see below, extensionMarshalJSON).
77// An infinite recursion would happen if we'd call json.Marshal on the struct
78// that has swaggerObject as an embedded field. To avoid that, we'll create
79// type aliases, and those don't have the custom MarshalJSON methods defined
80// on them. See http://choly.ca/post/go-json-marshalling/ (or, if it ever
81// goes away, use
82// https://web.archive.org/web/20190806073003/http://choly.ca/post/go-json-marshalling/.
83func (so swaggerObject) MarshalJSON() ([]byte, error) {
84 type alias swaggerObject
85 return extensionMarshalJSON(alias(so), so.extensions)
86}
87
88func (so swaggerInfoObject) MarshalJSON() ([]byte, error) {
89 type alias swaggerInfoObject
90 return extensionMarshalJSON(alias(so), so.extensions)
91}
92
93func (so swaggerSecuritySchemeObject) MarshalJSON() ([]byte, error) {
94 type alias swaggerSecuritySchemeObject
95 return extensionMarshalJSON(alias(so), so.extensions)
96}
97
98func (so swaggerOperationObject) MarshalJSON() ([]byte, error) {
99 type alias swaggerOperationObject
100 return extensionMarshalJSON(alias(so), so.extensions)
101}
102
103func (so swaggerResponseObject) MarshalJSON() ([]byte, error) {
104 type alias swaggerResponseObject
105 return extensionMarshalJSON(alias(so), so.extensions)
106}
107
108func extensionMarshalJSON(so interface{}, extensions []extension) ([]byte, error) {
109 // To append arbitrary keys to the struct we'll render into json,
110 // we're creating another struct that embeds the original one, and
111 // its extra fields:
112 //
113 // The struct will look like
114 // struct {
115 // *swaggerCore
116 // XGrpcGatewayFoo json.RawMessage `json:"x-grpc-gateway-foo"`
117 // XGrpcGatewayBar json.RawMessage `json:"x-grpc-gateway-bar"`
118 // }
119 // and thus render into what we want -- the JSON of swaggerCore with the
120 // extensions appended.
121 fields := []reflect.StructField{
122 reflect.StructField{ // embedded
123 Name: "Embedded",
124 Type: reflect.TypeOf(so),
125 Anonymous: true,
126 },
127 }
128 for _, ext := range extensions {
129 fields = append(fields, reflect.StructField{
130 Name: fieldName(ext.key),
131 Type: reflect.TypeOf(ext.value),
132 Tag: reflect.StructTag(fmt.Sprintf("json:\"%s\"", ext.key)),
133 })
134 }
135
136 t := reflect.StructOf(fields)
137 s := reflect.New(t).Elem()
138 s.Field(0).Set(reflect.ValueOf(so))
139 for _, ext := range extensions {
140 s.FieldByName(fieldName(ext.key)).Set(reflect.ValueOf(ext.value))
141 }
142 return json.Marshal(s.Interface())
143}
144
145// encodeSwagger converts swagger file obj to plugin.CodeGeneratorResponse_File
146func encodeSwagger(file *wrapper) (*plugin.CodeGeneratorResponse_File, error) {
147 var formatted bytes.Buffer
148 enc := json.NewEncoder(&formatted)
149 enc.SetIndent("", " ")
150 if err := enc.Encode(*file.swagger); err != nil {
151 return nil, err
152 }
153 name := file.fileName
154 ext := filepath.Ext(name)
155 base := strings.TrimSuffix(name, ext)
156 output := fmt.Sprintf("%s.swagger.json", base)
157 return &plugin.CodeGeneratorResponse_File{
158 Name: proto.String(output),
159 Content: proto.String(formatted.String()),
160 }, nil
161}
162
163func (g *generator) Generate(targets []*descriptor.File) ([]*plugin.CodeGeneratorResponse_File, error) {
164 var files []*plugin.CodeGeneratorResponse_File
165 if g.reg.IsAllowMerge() {
166 var mergedTarget *descriptor.File
167 // try to find proto leader
168 for _, f := range targets {
169 if proto.HasExtension(f.Options, swagger_options.E_Openapiv2Swagger) {
170 mergedTarget = f
171 break
172 }
173 }
174 // merge protos to leader
175 for _, f := range targets {
176 if mergedTarget == nil {
177 mergedTarget = f
178 } else {
179 mergedTarget.Enums = append(mergedTarget.Enums, f.Enums...)
180 mergedTarget.Messages = append(mergedTarget.Messages, f.Messages...)
181 mergedTarget.Services = append(mergedTarget.Services, f.Services...)
182 }
183 }
184
185 targets = nil
186 targets = append(targets, mergedTarget)
187 }
188
189 var swaggers []*wrapper
190 for _, file := range targets {
191 glog.V(1).Infof("Processing %s", file.GetName())
192 swagger, err := applyTemplate(param{File: file, reg: g.reg})
193 if err == errNoTargetService {
194 glog.V(1).Infof("%s: %v", file.GetName(), err)
195 continue
196 }
197 if err != nil {
198 return nil, err
199 }
200 swaggers = append(swaggers, &wrapper{
201 fileName: file.GetName(),
202 swagger: swagger,
203 })
204 }
205
206 if g.reg.IsAllowMerge() {
207 targetSwagger := mergeTargetFile(swaggers, g.reg.GetMergeFileName())
208 f, err := encodeSwagger(targetSwagger)
209 if err != nil {
210 return nil, fmt.Errorf("failed to encode swagger for %s: %s", g.reg.GetMergeFileName(), err)
211 }
212 files = append(files, f)
213 glog.V(1).Infof("New swagger file will emit")
214 } else {
215 for _, file := range swaggers {
216 f, err := encodeSwagger(file)
217 if err != nil {
218 return nil, fmt.Errorf("failed to encode swagger for %s: %s", file.fileName, err)
219 }
220 files = append(files, f)
221 glog.V(1).Infof("New swagger file will emit")
222 }
223 }
224 return files, nil
225}
226
227//AddStreamError Adds grpc.gateway.runtime.StreamError and google.protobuf.Any to registry for stream responses
228func AddStreamError(reg *descriptor.Registry) error {
229 //load internal protos
230 any := fileDescriptorProtoForMessage(&any.Any{})
231 streamError := fileDescriptorProtoForMessage(&internal.StreamError{})
232 if err := reg.Load(&plugin.CodeGeneratorRequest{
233 ProtoFile: []*protocdescriptor.FileDescriptorProto{
234 any,
235 streamError,
236 },
237 }); err != nil {
238 return err
239 }
240 return nil
241}
242
243func fileDescriptorProtoForMessage(msg pbdescriptor.Message) *protocdescriptor.FileDescriptorProto {
244 fdp, _ := pbdescriptor.ForMessage(msg)
245 fdp.SourceCodeInfo = &protocdescriptor.SourceCodeInfo{}
246 return fdp
247}