Matteo Scandolo | a6a3aee | 2019-11-26 13:30:14 -0700 | [diff] [blame^] | 1 | package genswagger |
| 2 | |
| 3 | import ( |
| 4 | "bytes" |
| 5 | "encoding/json" |
| 6 | "errors" |
| 7 | "fmt" |
| 8 | "path/filepath" |
| 9 | "reflect" |
| 10 | "strings" |
| 11 | |
| 12 | "github.com/golang/glog" |
| 13 | pbdescriptor "github.com/golang/protobuf/descriptor" |
| 14 | "github.com/golang/protobuf/proto" |
| 15 | protocdescriptor "github.com/golang/protobuf/protoc-gen-go/descriptor" |
| 16 | plugin "github.com/golang/protobuf/protoc-gen-go/plugin" |
| 17 | "github.com/golang/protobuf/ptypes/any" |
| 18 | "github.com/grpc-ecosystem/grpc-gateway/internal" |
| 19 | "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor" |
| 20 | gen "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/generator" |
| 21 | swagger_options "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-swagger/options" |
| 22 | ) |
| 23 | |
// errNoTargetService is the sentinel Generate compares applyTemplate's error
// against; a file reporting it is logged and skipped rather than failing the run.
var errNoTargetService = errors.New("no target service defined in the file")
| 27 | |
| 28 | type generator struct { |
| 29 | reg *descriptor.Registry |
| 30 | } |
| 31 | |
| 32 | type wrapper struct { |
| 33 | fileName string |
| 34 | swagger *swaggerObject |
| 35 | } |
| 36 | |
| 37 | // New returns a new generator which generates grpc gateway files. |
| 38 | func New(reg *descriptor.Registry) gen.Generator { |
| 39 | return &generator{reg: reg} |
| 40 | } |
| 41 | |
| 42 | // Merge a lot of swagger file (wrapper) to single one swagger file |
| 43 | func mergeTargetFile(targets []*wrapper, mergeFileName string) *wrapper { |
| 44 | var mergedTarget *wrapper |
| 45 | for _, f := range targets { |
| 46 | if mergedTarget == nil { |
| 47 | mergedTarget = &wrapper{ |
| 48 | fileName: mergeFileName, |
| 49 | swagger: f.swagger, |
| 50 | } |
| 51 | } else { |
| 52 | for k, v := range f.swagger.Definitions { |
| 53 | mergedTarget.swagger.Definitions[k] = v |
| 54 | } |
| 55 | for k, v := range f.swagger.StreamDefinitions { |
| 56 | mergedTarget.swagger.StreamDefinitions[k] = v |
| 57 | } |
| 58 | for k, v := range f.swagger.Paths { |
| 59 | mergedTarget.swagger.Paths[k] = v |
| 60 | } |
| 61 | for k, v := range f.swagger.SecurityDefinitions { |
| 62 | mergedTarget.swagger.SecurityDefinitions[k] = v |
| 63 | } |
| 64 | mergedTarget.swagger.Security = append(mergedTarget.swagger.Security, f.swagger.Security...) |
| 65 | } |
| 66 | } |
| 67 | return mergedTarget |
| 68 | } |
| 69 | |
// fieldName turns a dash-separated key such as "x-grpc-gateway-foo" into
// "X_Grpc_Gateway_Foo". strings.Title upper-cases the first letter of each
// word, and dashes then become underscores. Other punctuation is left
// untouched, so the result is not guaranteed to be a valid Go identifier.
func fieldName(k string) string {
	titled := strings.Title(k)
	return strings.Replace(titled, "-", "_", -1)
}
| 73 | |
// Q: What's up with the alias types here?
// A: We don't want to completely override how these structs are marshaled into
// JSON; we only want to add fields (see extensionMarshalJSON below).
// An infinite recursion would happen if we called json.Marshal on a struct
// that has, say, swaggerObject as an embedded field. The embedded type's
// MarshalJSON method would be promoted to the outer struct, and json.Marshal
// would call straight back into it. To avoid that, each method declares a
// local type `alias` defined from the original struct. A defined type shares
// the fields but not the methods, so it has no MarshalJSON. Strictly speaking
// these are not Go type aliases: `type alias = swaggerObject` would keep the
// methods and still recurse. See http://choly.ca/post/go-json-marshalling/
// (or, if it ever goes away,
// https://web.archive.org/web/20190806073003/http://choly.ca/post/go-json-marshalling/).
| 83 | func (so swaggerObject) MarshalJSON() ([]byte, error) { |
| 84 | type alias swaggerObject |
| 85 | return extensionMarshalJSON(alias(so), so.extensions) |
| 86 | } |
| 87 | |
| 88 | func (so swaggerInfoObject) MarshalJSON() ([]byte, error) { |
| 89 | type alias swaggerInfoObject |
| 90 | return extensionMarshalJSON(alias(so), so.extensions) |
| 91 | } |
| 92 | |
| 93 | func (so swaggerSecuritySchemeObject) MarshalJSON() ([]byte, error) { |
| 94 | type alias swaggerSecuritySchemeObject |
| 95 | return extensionMarshalJSON(alias(so), so.extensions) |
| 96 | } |
| 97 | |
| 98 | func (so swaggerOperationObject) MarshalJSON() ([]byte, error) { |
| 99 | type alias swaggerOperationObject |
| 100 | return extensionMarshalJSON(alias(so), so.extensions) |
| 101 | } |
| 102 | |
| 103 | func (so swaggerResponseObject) MarshalJSON() ([]byte, error) { |
| 104 | type alias swaggerResponseObject |
| 105 | return extensionMarshalJSON(alias(so), so.extensions) |
| 106 | } |
| 107 | |
| 108 | func extensionMarshalJSON(so interface{}, extensions []extension) ([]byte, error) { |
| 109 | // To append arbitrary keys to the struct we'll render into json, |
| 110 | // we're creating another struct that embeds the original one, and |
| 111 | // its extra fields: |
| 112 | // |
| 113 | // The struct will look like |
| 114 | // struct { |
| 115 | // *swaggerCore |
| 116 | // XGrpcGatewayFoo json.RawMessage `json:"x-grpc-gateway-foo"` |
| 117 | // XGrpcGatewayBar json.RawMessage `json:"x-grpc-gateway-bar"` |
| 118 | // } |
| 119 | // and thus render into what we want -- the JSON of swaggerCore with the |
| 120 | // extensions appended. |
| 121 | fields := []reflect.StructField{ |
| 122 | reflect.StructField{ // embedded |
| 123 | Name: "Embedded", |
| 124 | Type: reflect.TypeOf(so), |
| 125 | Anonymous: true, |
| 126 | }, |
| 127 | } |
| 128 | for _, ext := range extensions { |
| 129 | fields = append(fields, reflect.StructField{ |
| 130 | Name: fieldName(ext.key), |
| 131 | Type: reflect.TypeOf(ext.value), |
| 132 | Tag: reflect.StructTag(fmt.Sprintf("json:\"%s\"", ext.key)), |
| 133 | }) |
| 134 | } |
| 135 | |
| 136 | t := reflect.StructOf(fields) |
| 137 | s := reflect.New(t).Elem() |
| 138 | s.Field(0).Set(reflect.ValueOf(so)) |
| 139 | for _, ext := range extensions { |
| 140 | s.FieldByName(fieldName(ext.key)).Set(reflect.ValueOf(ext.value)) |
| 141 | } |
| 142 | return json.Marshal(s.Interface()) |
| 143 | } |
| 144 | |
| 145 | // encodeSwagger converts swagger file obj to plugin.CodeGeneratorResponse_File |
| 146 | func encodeSwagger(file *wrapper) (*plugin.CodeGeneratorResponse_File, error) { |
| 147 | var formatted bytes.Buffer |
| 148 | enc := json.NewEncoder(&formatted) |
| 149 | enc.SetIndent("", " ") |
| 150 | if err := enc.Encode(*file.swagger); err != nil { |
| 151 | return nil, err |
| 152 | } |
| 153 | name := file.fileName |
| 154 | ext := filepath.Ext(name) |
| 155 | base := strings.TrimSuffix(name, ext) |
| 156 | output := fmt.Sprintf("%s.swagger.json", base) |
| 157 | return &plugin.CodeGeneratorResponse_File{ |
| 158 | Name: proto.String(output), |
| 159 | Content: proto.String(formatted.String()), |
| 160 | }, nil |
| 161 | } |
| 162 | |
| 163 | func (g *generator) Generate(targets []*descriptor.File) ([]*plugin.CodeGeneratorResponse_File, error) { |
| 164 | var files []*plugin.CodeGeneratorResponse_File |
| 165 | if g.reg.IsAllowMerge() { |
| 166 | var mergedTarget *descriptor.File |
| 167 | // try to find proto leader |
| 168 | for _, f := range targets { |
| 169 | if proto.HasExtension(f.Options, swagger_options.E_Openapiv2Swagger) { |
| 170 | mergedTarget = f |
| 171 | break |
| 172 | } |
| 173 | } |
| 174 | // merge protos to leader |
| 175 | for _, f := range targets { |
| 176 | if mergedTarget == nil { |
| 177 | mergedTarget = f |
| 178 | } else { |
| 179 | mergedTarget.Enums = append(mergedTarget.Enums, f.Enums...) |
| 180 | mergedTarget.Messages = append(mergedTarget.Messages, f.Messages...) |
| 181 | mergedTarget.Services = append(mergedTarget.Services, f.Services...) |
| 182 | } |
| 183 | } |
| 184 | |
| 185 | targets = nil |
| 186 | targets = append(targets, mergedTarget) |
| 187 | } |
| 188 | |
| 189 | var swaggers []*wrapper |
| 190 | for _, file := range targets { |
| 191 | glog.V(1).Infof("Processing %s", file.GetName()) |
| 192 | swagger, err := applyTemplate(param{File: file, reg: g.reg}) |
| 193 | if err == errNoTargetService { |
| 194 | glog.V(1).Infof("%s: %v", file.GetName(), err) |
| 195 | continue |
| 196 | } |
| 197 | if err != nil { |
| 198 | return nil, err |
| 199 | } |
| 200 | swaggers = append(swaggers, &wrapper{ |
| 201 | fileName: file.GetName(), |
| 202 | swagger: swagger, |
| 203 | }) |
| 204 | } |
| 205 | |
| 206 | if g.reg.IsAllowMerge() { |
| 207 | targetSwagger := mergeTargetFile(swaggers, g.reg.GetMergeFileName()) |
| 208 | f, err := encodeSwagger(targetSwagger) |
| 209 | if err != nil { |
| 210 | return nil, fmt.Errorf("failed to encode swagger for %s: %s", g.reg.GetMergeFileName(), err) |
| 211 | } |
| 212 | files = append(files, f) |
| 213 | glog.V(1).Infof("New swagger file will emit") |
| 214 | } else { |
| 215 | for _, file := range swaggers { |
| 216 | f, err := encodeSwagger(file) |
| 217 | if err != nil { |
| 218 | return nil, fmt.Errorf("failed to encode swagger for %s: %s", file.fileName, err) |
| 219 | } |
| 220 | files = append(files, f) |
| 221 | glog.V(1).Infof("New swagger file will emit") |
| 222 | } |
| 223 | } |
| 224 | return files, nil |
| 225 | } |
| 226 | |
| 227 | //AddStreamError Adds grpc.gateway.runtime.StreamError and google.protobuf.Any to registry for stream responses |
| 228 | func AddStreamError(reg *descriptor.Registry) error { |
| 229 | //load internal protos |
| 230 | any := fileDescriptorProtoForMessage(&any.Any{}) |
| 231 | streamError := fileDescriptorProtoForMessage(&internal.StreamError{}) |
| 232 | if err := reg.Load(&plugin.CodeGeneratorRequest{ |
| 233 | ProtoFile: []*protocdescriptor.FileDescriptorProto{ |
| 234 | any, |
| 235 | streamError, |
| 236 | }, |
| 237 | }); err != nil { |
| 238 | return err |
| 239 | } |
| 240 | return nil |
| 241 | } |
| 242 | |
| 243 | func fileDescriptorProtoForMessage(msg pbdescriptor.Message) *protocdescriptor.FileDescriptorProto { |
| 244 | fdp, _ := pbdescriptor.ForMessage(msg) |
| 245 | fdp.SourceCodeInfo = &protocdescriptor.SourceCodeInfo{} |
| 246 | return fdp |
| 247 | } |