blob: 0b6bfbd2b93bb2a46dccd17796bd63f3fe186d2f [file] [log] [blame]
Matteo Scandoloa6a3aee2019-11-26 13:30:14 -07001package gengateway
2
3import (
4 "errors"
5 "fmt"
6 "go/format"
7 "path"
8 "path/filepath"
9 "strings"
10
11 "github.com/golang/glog"
12 "github.com/golang/protobuf/proto"
13 plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
14 "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
15 gen "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/generator"
16)
17
// errNoTargetService signals that a file has no service to generate a gateway
// for. Generate skips a file whose generation reports this error instead of
// failing the whole run.
var errNoTargetService = errors.New("no target service defined in the file")
21
// pathType selects how the name of each generated output file is derived
// (the protoc "paths" option).
type pathType int

const (
	// pathTypeImport (paths=import, the default) places the output under the
	// file's Go import path when one is known. It is the zero value, so an
	// unset pathType means paths=import.
	pathTypeImport pathType = iota

	// pathTypeSourceRelative (paths=source_relative) derives the output name
	// from the .proto file's own name.
	pathTypeSourceRelative
)
28
29type generator struct {
30 reg *descriptor.Registry
31 baseImports []descriptor.GoPackage
32 useRequestContext bool
33 registerFuncSuffix string
34 pathType pathType
35 allowPatchFeature bool
36}
37
38// New returns a new generator which generates grpc gateway files.
39func New(reg *descriptor.Registry, useRequestContext bool, registerFuncSuffix, pathTypeString string, allowPatchFeature bool) gen.Generator {
40 var imports []descriptor.GoPackage
41 for _, pkgpath := range []string{
42 "context",
43 "io",
44 "net/http",
45 "github.com/grpc-ecosystem/grpc-gateway/runtime",
46 "github.com/grpc-ecosystem/grpc-gateway/utilities",
47 "github.com/golang/protobuf/descriptor",
48 "github.com/golang/protobuf/proto",
49 "google.golang.org/grpc",
50 "google.golang.org/grpc/codes",
51 "google.golang.org/grpc/grpclog",
52 "google.golang.org/grpc/status",
53 } {
54 pkg := descriptor.GoPackage{
55 Path: pkgpath,
56 Name: path.Base(pkgpath),
57 }
58 if err := reg.ReserveGoPackageAlias(pkg.Name, pkg.Path); err != nil {
59 for i := 0; ; i++ {
60 alias := fmt.Sprintf("%s_%d", pkg.Name, i)
61 if err := reg.ReserveGoPackageAlias(alias, pkg.Path); err != nil {
62 continue
63 }
64 pkg.Alias = alias
65 break
66 }
67 }
68 imports = append(imports, pkg)
69 }
70
71 var pathType pathType
72 switch pathTypeString {
73 case "", "import":
74 // paths=import is default
75 case "source_relative":
76 pathType = pathTypeSourceRelative
77 default:
78 glog.Fatalf(`Unknown path type %q: want "import" or "source_relative".`, pathTypeString)
79 }
80
81 return &generator{
82 reg: reg,
83 baseImports: imports,
84 useRequestContext: useRequestContext,
85 registerFuncSuffix: registerFuncSuffix,
86 pathType: pathType,
87 allowPatchFeature: allowPatchFeature,
88 }
89}
90
91func (g *generator) Generate(targets []*descriptor.File) ([]*plugin.CodeGeneratorResponse_File, error) {
92 var files []*plugin.CodeGeneratorResponse_File
93 for _, file := range targets {
94 glog.V(1).Infof("Processing %s", file.GetName())
95 code, err := g.generate(file)
96 if err == errNoTargetService {
97 glog.V(1).Infof("%s: %v", file.GetName(), err)
98 continue
99 }
100 if err != nil {
101 return nil, err
102 }
103 formatted, err := format.Source([]byte(code))
104 if err != nil {
105 glog.Errorf("%v: %s", err, code)
106 return nil, err
107 }
108 name := file.GetName()
109 if g.pathType == pathTypeImport && file.GoPkg.Path != "" {
110 name = fmt.Sprintf("%s/%s", file.GoPkg.Path, filepath.Base(name))
111 }
112 ext := filepath.Ext(name)
113 base := strings.TrimSuffix(name, ext)
114 output := fmt.Sprintf("%s.pb.gw.go", base)
115 files = append(files, &plugin.CodeGeneratorResponse_File{
116 Name: proto.String(output),
117 Content: proto.String(string(formatted)),
118 })
119 glog.V(1).Infof("Will emit %s", output)
120 }
121 return files, nil
122}
123
124func (g *generator) generate(file *descriptor.File) (string, error) {
125 pkgSeen := make(map[string]bool)
126 var imports []descriptor.GoPackage
127 for _, pkg := range g.baseImports {
128 pkgSeen[pkg.Path] = true
129 imports = append(imports, pkg)
130 }
131 for _, svc := range file.Services {
132 for _, m := range svc.Methods {
133 imports = append(imports, g.addEnumPathParamImports(file, m, pkgSeen)...)
134 pkg := m.RequestType.File.GoPkg
135 if len(m.Bindings) == 0 ||
136 pkg == file.GoPkg || pkgSeen[pkg.Path] {
137 continue
138 }
139 pkgSeen[pkg.Path] = true
140 imports = append(imports, pkg)
141 }
142 }
143 params := param{
144 File: file,
145 Imports: imports,
146 UseRequestContext: g.useRequestContext,
147 RegisterFuncSuffix: g.registerFuncSuffix,
148 AllowPatchFeature: g.allowPatchFeature,
149 }
150 return applyTemplate(params, g.reg)
151}
152
153// addEnumPathParamImports handles adding import of enum path parameter go packages
154func (g *generator) addEnumPathParamImports(file *descriptor.File, m *descriptor.Method, pkgSeen map[string]bool) []descriptor.GoPackage {
155 var imports []descriptor.GoPackage
156 for _, b := range m.Bindings {
157 for _, p := range b.PathParams {
158 e, err := g.reg.LookupEnum("", p.Target.GetTypeName())
159 if err != nil {
160 continue
161 }
162 pkg := e.File.GoPkg
163 if pkg == file.GoPkg || pkgSeen[pkg.Path] {
164 continue
165 }
166 pkgSeen[pkg.Path] = true
167 imports = append(imports, pkg)
168 }
169 }
170 return imports
171}