blob: d0a61c2ecaa399ddcbb9e16508a3eb00019f6510 [file] [log] [blame]
Zack Williamse940c7a2019-08-21 14:25:39 -07001package protoparse
2
3import (
4 "bytes"
5 "reflect"
6 "sort"
7 "strings"
8
9 "github.com/golang/protobuf/proto"
10 dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
11
12 "github.com/jhump/protoreflect/desc/internal"
13)
14
15func (r *parseResult) generateSourceCodeInfo() *dpb.SourceCodeInfo {
16 if r.nodes == nil {
17 // skip files that do not have AST info (these will be files
18 // that came from well-known descriptors, instead of from source)
19 return nil
20 }
21
22 sci := sourceCodeInfo{commentsUsed: map[*comment]struct{}{}}
23 path := make([]int32, 0, 10)
24
25 fn := r.getFileNode(r.fd).(*fileNode)
26 if fn.syntax != nil {
27 sci.newLoc(fn.syntax, append(path, internal.File_syntaxTag))
28 }
29 if fn.pkg != nil {
30 sci.newLoc(fn.pkg, append(path, internal.File_packageTag))
31 }
32 for i, imp := range fn.imports {
33 sci.newLoc(imp, append(path, internal.File_dependencyTag, int32(i)))
34 }
35
36 // file options
37 r.generateSourceCodeInfoForOptions(&sci, fn.decls, func(n interface{}) *optionNode {
38 return n.(*fileElement).option
39 }, r.fd.Options.GetUninterpretedOption(), append(path, internal.File_optionsTag))
40
41 // message types
42 for i, msg := range r.fd.GetMessageType() {
43 r.generateSourceCodeInfoForMessage(&sci, msg, append(path, internal.File_messagesTag, int32(i)))
44 }
45
46 // enum types
47 for i, enum := range r.fd.GetEnumType() {
48 r.generateSourceCodeInfoForEnum(&sci, enum, append(path, internal.File_enumsTag, int32(i)))
49 }
50
51 // extension fields
52 for i, ext := range r.fd.GetExtension() {
53 r.generateSourceCodeInfoForField(&sci, ext, append(path, internal.File_extensionsTag, int32(i)))
54 }
55
56 // services and methods
57 for i, svc := range r.fd.GetService() {
58 n := r.getServiceNode(svc).(*serviceNode)
59 svcPath := append(path, internal.File_servicesTag, int32(i))
60 sci.newLoc(n, svcPath)
61 sci.newLoc(n.name, append(svcPath, internal.Service_nameTag))
62
63 // service options
64 r.generateSourceCodeInfoForOptions(&sci, n.decls, func(n interface{}) *optionNode {
65 return n.(*serviceElement).option
66 }, svc.Options.GetUninterpretedOption(), append(svcPath, internal.Service_optionsTag))
67
68 // methods
69 for j, mtd := range svc.GetMethod() {
70 mn := r.getMethodNode(mtd).(*methodNode)
71 mtdPath := append(svcPath, internal.Service_methodsTag, int32(j))
72 sci.newLoc(mn, mtdPath)
73 sci.newLoc(mn.name, append(mtdPath, internal.Method_nameTag))
74
75 sci.newLoc(mn.input.msgType, append(mtdPath, internal.Method_inputTag))
76 if mn.input.streamKeyword != nil {
77 sci.newLoc(mn.input.streamKeyword, append(mtdPath, internal.Method_inputStreamTag))
78 }
79 sci.newLoc(mn.output.msgType, append(mtdPath, internal.Method_outputTag))
80 if mn.output.streamKeyword != nil {
81 sci.newLoc(mn.output.streamKeyword, append(mtdPath, internal.Method_outputStreamTag))
82 }
83
84 // method options
85 r.generateSourceCodeInfoForOptions(&sci, mn.options, func(n interface{}) *optionNode {
86 return n.(*optionNode)
87 }, mtd.Options.GetUninterpretedOption(), append(mtdPath, internal.Method_optionsTag))
88 }
89 }
90 return &dpb.SourceCodeInfo{Location: sci.generateLocs()}
91}
92
93func (r *parseResult) generateSourceCodeInfoForOptions(sci *sourceCodeInfo, elements interface{}, extractor func(interface{}) *optionNode, uninterp []*dpb.UninterpretedOption, path []int32) {
94 // Known options are option node elements that have a corresponding
95 // path in r.interpretedOptions. We'll do those first.
96 rv := reflect.ValueOf(elements)
97 for i := 0; i < rv.Len(); i++ {
98 on := extractor(rv.Index(i).Interface())
99 if on == nil {
100 continue
101 }
102 optPath := r.interpretedOptions[on]
103 if len(optPath) > 0 {
104 p := path
105 if optPath[0] == -1 {
106 // used by "default" and "json_name" field pseudo-options
107 // to attribute path to parent element (since those are
108 // stored directly on the descriptor, not its options)
109 p = make([]int32, len(path)-1)
110 copy(p, path)
111 optPath = optPath[1:]
112 }
113 sci.newLoc(on, append(p, optPath...))
114 }
115 }
116
117 // Now uninterpreted options
118 for i, uo := range uninterp {
119 optPath := append(path, internal.UninterpretedOptionsTag, int32(i))
120 on := r.getOptionNode(uo).(*optionNode)
121 sci.newLoc(on, optPath)
122
123 var valTag int32
124 switch {
125 case uo.IdentifierValue != nil:
126 valTag = internal.Uninterpreted_identTag
127 case uo.PositiveIntValue != nil:
128 valTag = internal.Uninterpreted_posIntTag
129 case uo.NegativeIntValue != nil:
130 valTag = internal.Uninterpreted_negIntTag
131 case uo.DoubleValue != nil:
132 valTag = internal.Uninterpreted_doubleTag
133 case uo.StringValue != nil:
134 valTag = internal.Uninterpreted_stringTag
135 case uo.AggregateValue != nil:
136 valTag = internal.Uninterpreted_aggregateTag
137 }
138 if valTag != 0 {
139 sci.newLoc(on.val, append(optPath, valTag))
140 }
141
142 for j, n := range uo.Name {
143 optNmPath := append(optPath, internal.Uninterpreted_nameTag, int32(j))
144 nn := r.getOptionNamePartNode(n).(*optionNamePartNode)
145 sci.newLoc(nn, optNmPath)
146 sci.newLoc(nn.text, append(optNmPath, internal.UninterpretedName_nameTag))
147 }
148 }
149}
150
151func (r *parseResult) generateSourceCodeInfoForMessage(sci *sourceCodeInfo, msg *dpb.DescriptorProto, path []int32) {
152 n := r.getMessageNode(msg)
153 sci.newLoc(n, path)
154
155 var decls []*messageElement
156 var resvdNames []*stringLiteralNode
157 switch n := n.(type) {
158 case *messageNode:
159 decls = n.decls
160 resvdNames = n.reserved
161 case *groupNode:
162 decls = n.decls
163 resvdNames = n.reserved
164 }
165 if decls == nil {
166 // map entry so nothing else to do
167 return
168 }
169
170 sci.newLoc(n.messageName(), append(path, internal.Message_nameTag))
171
172 // message options
173 r.generateSourceCodeInfoForOptions(sci, decls, func(n interface{}) *optionNode {
174 return n.(*messageElement).option
175 }, msg.Options.GetUninterpretedOption(), append(path, internal.Message_optionsTag))
176
177 // fields
178 for i, fld := range msg.GetField() {
179 r.generateSourceCodeInfoForField(sci, fld, append(path, internal.Message_fieldsTag, int32(i)))
180 }
181
182 // one-ofs
183 for i, ood := range msg.GetOneofDecl() {
184 oon := r.getOneOfNode(ood).(*oneOfNode)
185 ooPath := append(path, internal.Message_oneOfsTag, int32(i))
186 sci.newLoc(oon, ooPath)
187 sci.newLoc(oon.name, append(ooPath, internal.OneOf_nameTag))
188
189 // one-of options
190 r.generateSourceCodeInfoForOptions(sci, oon.decls, func(n interface{}) *optionNode {
191 return n.(*oneOfElement).option
192 }, ood.Options.GetUninterpretedOption(), append(ooPath, internal.OneOf_optionsTag))
193 }
194
195 // nested messages
196 for i, nm := range msg.GetNestedType() {
197 r.generateSourceCodeInfoForMessage(sci, nm, append(path, internal.Message_nestedMessagesTag, int32(i)))
198 }
199
200 // nested enums
201 for i, enum := range msg.GetEnumType() {
202 r.generateSourceCodeInfoForEnum(sci, enum, append(path, internal.Message_enumsTag, int32(i)))
203 }
204
205 // nested extensions
206 for i, ext := range msg.GetExtension() {
207 r.generateSourceCodeInfoForField(sci, ext, append(path, internal.Message_extensionsTag, int32(i)))
208 }
209
210 // extension ranges
211 for i, er := range msg.ExtensionRange {
212 rangePath := append(path, internal.Message_extensionRangeTag, int32(i))
213 rn := r.getExtensionRangeNode(er).(*rangeNode)
214 sci.newLoc(rn, rangePath)
215 sci.newLoc(rn.stNode, append(rangePath, internal.ExtensionRange_startTag))
216 if rn.stNode != rn.enNode {
217 sci.newLoc(rn.enNode, append(rangePath, internal.ExtensionRange_endTag))
218 }
219 // now we have to find the extension decl and options that correspond to this range :(
220 for _, d := range decls {
221 found := false
222 if d.extensionRange != nil {
223 for _, r := range d.extensionRange.ranges {
224 if rn == r {
225 found = true
226 break
227 }
228 }
229 }
230 if found {
231 r.generateSourceCodeInfoForOptions(sci, d.extensionRange.options, func(n interface{}) *optionNode {
232 return n.(*optionNode)
233 }, er.Options.GetUninterpretedOption(), append(rangePath, internal.ExtensionRange_optionsTag))
234 break
235 }
236 }
237 }
238
239 // reserved ranges
240 for i, rr := range msg.ReservedRange {
241 rangePath := append(path, internal.Message_reservedRangeTag, int32(i))
242 rn := r.getMessageReservedRangeNode(rr).(*rangeNode)
243 sci.newLoc(rn, rangePath)
244 sci.newLoc(rn.stNode, append(rangePath, internal.ReservedRange_startTag))
245 if rn.stNode != rn.enNode {
246 sci.newLoc(rn.enNode, append(rangePath, internal.ReservedRange_endTag))
247 }
248 }
249
250 // reserved names
251 for i, n := range resvdNames {
252 sci.newLoc(n, append(path, internal.Message_reservedNameTag, int32(i)))
253 }
254}
255
256func (r *parseResult) generateSourceCodeInfoForEnum(sci *sourceCodeInfo, enum *dpb.EnumDescriptorProto, path []int32) {
257 n := r.getEnumNode(enum).(*enumNode)
258 sci.newLoc(n, path)
259 sci.newLoc(n.name, append(path, internal.Enum_nameTag))
260
261 // enum options
262 r.generateSourceCodeInfoForOptions(sci, n.decls, func(n interface{}) *optionNode {
263 return n.(*enumElement).option
264 }, enum.Options.GetUninterpretedOption(), append(path, internal.Enum_optionsTag))
265
266 // enum values
267 for j, ev := range enum.GetValue() {
268 evn := r.getEnumValueNode(ev).(*enumValueNode)
269 evPath := append(path, internal.Enum_valuesTag, int32(j))
270 sci.newLoc(evn, evPath)
271 sci.newLoc(evn.name, append(evPath, internal.EnumVal_nameTag))
272 sci.newLoc(evn.getNumber(), append(evPath, internal.EnumVal_numberTag))
273
274 // enum value options
275 r.generateSourceCodeInfoForOptions(sci, evn.options, func(n interface{}) *optionNode {
276 return n.(*optionNode)
277 }, ev.Options.GetUninterpretedOption(), append(evPath, internal.EnumVal_optionsTag))
278 }
279
280 // reserved ranges
281 for i, rr := range enum.GetReservedRange() {
282 rangePath := append(path, internal.Enum_reservedRangeTag, int32(i))
283 rn := r.getEnumReservedRangeNode(rr).(*rangeNode)
284 sci.newLoc(rn, rangePath)
285 sci.newLoc(rn.stNode, append(rangePath, internal.ReservedRange_startTag))
286 if rn.stNode != rn.enNode {
287 sci.newLoc(rn.enNode, append(rangePath, internal.ReservedRange_endTag))
288 }
289 }
290
291 // reserved names
292 for i, rn := range n.reserved {
293 sci.newLoc(rn, append(path, internal.Enum_reservedNameTag, int32(i)))
294 }
295}
296
297func (r *parseResult) generateSourceCodeInfoForField(sci *sourceCodeInfo, fld *dpb.FieldDescriptorProto, path []int32) {
298 n := r.getFieldNode(fld)
299
300 isGroup := false
301 var opts []*optionNode
302 var extendee *extendNode
303 switch n := n.(type) {
304 case *fieldNode:
305 opts = n.options
306 extendee = n.extendee
307 case *mapFieldNode:
308 opts = n.options
309 case *groupNode:
310 isGroup = true
311 extendee = n.extendee
312 case *syntheticMapField:
313 // shouldn't get here since we don't recurse into fields from a mapNode
314 // in generateSourceCodeInfoForMessage... but just in case
315 return
316 }
317
318 sci.newLoc(n, path)
319 if !isGroup {
320 sci.newLoc(n.fieldName(), append(path, internal.Field_nameTag))
321 sci.newLoc(n.fieldType(), append(path, internal.Field_typeTag))
322 }
323 if n.fieldLabel() != nil {
324 sci.newLoc(n.fieldLabel(), append(path, internal.Field_labelTag))
325 }
326 sci.newLoc(n.fieldTag(), append(path, internal.Field_numberTag))
327 if extendee != nil {
328 sci.newLoc(extendee.extendee, append(path, internal.Field_extendeeTag))
329 }
330
331 r.generateSourceCodeInfoForOptions(sci, opts, func(n interface{}) *optionNode {
332 return n.(*optionNode)
333 }, fld.Options.GetUninterpretedOption(), append(path, internal.Field_optionsTag))
334}
335
336type sourceCodeInfo struct {
337 locs []*dpb.SourceCodeInfo_Location
338 commentsUsed map[*comment]struct{}
339}
340
341func (sci *sourceCodeInfo) newLoc(n node, path []int32) {
342 leadingComments := n.leadingComments()
343 trailingComments := n.trailingComments()
344 if sci.commentUsed(leadingComments) {
345 leadingComments = nil
346 }
347 if sci.commentUsed(trailingComments) {
348 trailingComments = nil
349 }
350 detached := groupComments(leadingComments)
351 trail := combineComments(trailingComments)
352 var lead *string
353 if len(leadingComments) > 0 && leadingComments[len(leadingComments)-1].end.Line >= n.start().Line-1 {
354 lead = proto.String(detached[len(detached)-1])
355 detached = detached[:len(detached)-1]
356 }
357 dup := make([]int32, len(path))
358 copy(dup, path)
359 var span []int32
360 if n.start().Line == n.end().Line {
361 span = []int32{int32(n.start().Line) - 1, int32(n.start().Col) - 1, int32(n.end().Col) - 1}
362 } else {
363 span = []int32{int32(n.start().Line) - 1, int32(n.start().Col) - 1, int32(n.end().Line) - 1, int32(n.end().Col) - 1}
364 }
365 sci.locs = append(sci.locs, &dpb.SourceCodeInfo_Location{
366 LeadingDetachedComments: detached,
367 LeadingComments: lead,
368 TrailingComments: trail,
369 Path: dup,
370 Span: span,
371 })
372}
373
374func (sci *sourceCodeInfo) commentUsed(c []*comment) bool {
375 if len(c) == 0 {
376 return false
377 }
378 if _, ok := sci.commentsUsed[c[0]]; ok {
379 return true
380 }
381
382 sci.commentsUsed[c[0]] = struct{}{}
383 return false
384}
385
386func groupComments(comments []*comment) []string {
387 if len(comments) == 0 {
388 return nil
389 }
390
391 var groups []string
392 singleLineStyle := comments[0].text[:2] == "//"
393 line := comments[0].end.Line
394 start := 0
395 for i := 1; i < len(comments); i++ {
396 c := comments[i]
397 prevSingleLine := singleLineStyle
398 singleLineStyle = strings.HasPrefix(comments[i].text, "//")
399 if !singleLineStyle || prevSingleLine != singleLineStyle || c.start.Line > line+1 {
400 // new group!
401 groups = append(groups, *combineComments(comments[start:i]))
402 start = i
403 }
404 line = c.end.Line
405 }
406 // don't forget last group
407 groups = append(groups, *combineComments(comments[start:]))
408
409 return groups
410}
411
412func combineComments(comments []*comment) *string {
413 if len(comments) == 0 {
414 return nil
415 }
416 first := true
417 var buf bytes.Buffer
418 for _, c := range comments {
419 if first {
420 first = false
421 } else {
422 buf.WriteByte('\n')
423 }
424 if c.text[:2] == "//" {
425 buf.WriteString(c.text[2:])
426 } else {
427 lines := strings.Split(c.text[2:len(c.text)-2], "\n")
428 first := true
429 for _, l := range lines {
430 if first {
431 first = false
432 } else {
433 buf.WriteByte('\n')
434 }
435
436 // strip a prefix of whitespace followed by '*'
437 j := 0
438 for j < len(l) {
439 if l[j] != ' ' && l[j] != '\t' {
440 break
441 }
442 j++
443 }
444 if j == len(l) {
445 l = ""
446 } else if l[j] == '*' {
447 l = l[j+1:]
448 } else if j > 0 {
449 l = " " + l[j:]
450 }
451
452 buf.WriteString(l)
453 }
454 }
455 }
456 return proto.String(buf.String())
457}
458
459func (sci *sourceCodeInfo) generateLocs() []*dpb.SourceCodeInfo_Location {
460 // generate intermediate locations: paths between root (inclusive) and the
461 // leaf locations already created, these will not have comments but will
462 // have aggregate span, than runs from min(start pos) to max(end pos) for
463 // all descendent paths.
464
465 if len(sci.locs) == 0 {
466 // nothing to generate
467 return nil
468 }
469
470 var root locTrie
471 for _, loc := range sci.locs {
472 root.add(loc.Path, loc)
473 }
474 root.fillIn()
475 locs := make([]*dpb.SourceCodeInfo_Location, 0, root.countLocs())
476 root.aggregate(&locs)
477 // finally, sort the resulting slice by location
478 sort.Slice(locs, func(i, j int) bool {
479 startI, endI := getSpanPositions(locs[i].Span)
480 startJ, endJ := getSpanPositions(locs[j].Span)
481 cmp := compareSlice(startI, startJ)
482 if cmp == 0 {
483 // if start position is the same, sort by end position _decreasing_
484 // (so enclosing locations will appear before leaves)
485 cmp = -compareSlice(endI, endJ)
486 if cmp == 0 {
487 // start and end position are the same? so break ties using path
488 cmp = compareSlice(locs[i].Path, locs[j].Path)
489 }
490 }
491 return cmp < 0
492 })
493 return locs
494}
495
496type locTrie struct {
497 children map[int32]*locTrie
498 loc *dpb.SourceCodeInfo_Location
499}
500
501func (t *locTrie) add(path []int32, loc *dpb.SourceCodeInfo_Location) {
502 if len(path) == 0 {
503 t.loc = loc
504 return
505 }
506 child := t.children[path[0]]
507 if child == nil {
508 if t.children == nil {
509 t.children = map[int32]*locTrie{}
510 }
511 child = &locTrie{}
512 t.children[path[0]] = child
513 }
514 child.add(path[1:], loc)
515}
516
517func (t *locTrie) fillIn() {
518 var path []int32
519 var start, end []int32
520 for _, child := range t.children {
521 // recurse
522 child.fillIn()
523 if t.loc == nil {
524 // maintain min(start) and max(end) so we can
525 // populate t.loc below
526 childStart, childEnd := getSpanPositions(child.loc.Span)
527
528 if start == nil {
529 if path == nil {
530 path = child.loc.Path[:len(child.loc.Path)-1]
531 }
532 start = childStart
533 end = childEnd
534 } else {
535 if compareSlice(childStart, start) < 0 {
536 start = childStart
537 }
538 if compareSlice(childEnd, end) > 0 {
539 end = childEnd
540 }
541 }
542 }
543 }
544
545 if t.loc == nil {
546 var span []int32
547 // we don't use append below because we want a new slice
548 // that doesn't share underlying buffer with spans from
549 // any other location
550 if start[0] == end[0] {
551 span = []int32{start[0], start[1], end[1]}
552 } else {
553 span = []int32{start[0], start[1], end[0], end[1]}
554 }
555 t.loc = &dpb.SourceCodeInfo_Location{
556 Path: path,
557 Span: span,
558 }
559 }
560}
561
562func (t *locTrie) countLocs() int {
563 count := 0
564 if t.loc != nil {
565 count = 1
566 }
567 for _, ch := range t.children {
568 count += ch.countLocs()
569 }
570 return count
571}
572
573func (t *locTrie) aggregate(dest *[]*dpb.SourceCodeInfo_Location) {
574 if t.loc != nil {
575 *dest = append(*dest, t.loc)
576 }
577 for _, child := range t.children {
578 child.aggregate(dest)
579 }
580}
581
// getSpanPositions splits a span into its start and end positions, each
// ordered as line then column.
//
// A three-element span is (line, start column, end column), so the end
// position reuses the start line. Otherwise the span is taken to be
// (start line, start column, end line, end column). span must have at least
// three elements.
//
// start aliases span but has its capacity capped at two, so appending to it
// cannot overwrite the rest of span.
func getSpanPositions(span []int32) (start, end []int32) {
	start = span[:2:2]
	if len(span) == 3 {
		end = []int32{span[0], span[2]}
	} else {
		end = span[2:]
	}
	return start, end
}
591
// compareSlice compares a and b lexicographically. It returns -1 if a sorts
// before b, 1 if a sorts after b, and 0 if they are equal. When one slice is
// a prefix of the other, the shorter one sorts first.
func compareSlice(a, b []int32) int {
	for i := 0; i < len(a) && i < len(b); i++ {
		switch {
		case a[i] < b[i]:
			return -1
		case a[i] > b[i]:
			return 1
		}
	}
	// all shared elements are equal, so length decides
	switch {
	case len(a) < len(b):
		return -1
	case len(a) > len(b):
		return 1
	}
	return 0
}