blob: 6157b4ef9e2997bea80b2930408a2b73fee59afc [file] [log] [blame]
// Package ndr provides the ability to unmarshal NDR encoded byte streams into Go data structures.
2package ndr
3
4import (
5 "bufio"
6 "fmt"
7 "io"
8 "reflect"
9 "strings"
10)
11
// Struct tag values that select how the decoder reads a field.
const (
	// TagConformant marks a string or slice as conformant: its maximum element count is
	// moved to the beginning of the enclosing structure.
	TagConformant = "conformant"
	// TagVarying marks a slice as varying.
	TagVarying = "varying"
	// TagPointer marks a field as an embedded pointer whose referent is decoded after the
	// enclosing structure.
	TagPointer = "pointer"
	// TagPipe marks a slice as an NDR pipe.
	TagPipe = "pipe"
)
19
20// Decoder unmarshals NDR byte stream data into a Go struct representation
21type Decoder struct {
22 r *bufio.Reader // source of the data
23 size int // initial size of bytes in buffer
24 ch CommonHeader // NDR common header
25 ph PrivateHeader // NDR private header
26 conformantMax []uint32 // conformant max values that were moved to the beginning of the structure
27 s interface{} // pointer to the structure being populated
28 current []string // keeps track of the current field being populated
29}
30
// deferedPtr records a pointer-tagged value whose referent is decoded after the
// enclosing structure, together with the tag to decode the referent with.
type deferedPtr struct {
	v   reflect.Value     // value to populate once the referent is reached
	tag reflect.StructTag // field tag with the pointer marker already removed
}
35
36// NewDecoder creates a new instance of a NDR Decoder.
37func NewDecoder(r io.Reader) *Decoder {
38 dec := new(Decoder)
39 dec.r = bufio.NewReader(r)
40 dec.r.Peek(int(commonHeaderBytes)) // For some reason an operation is needed on the buffer to initialise it so Buffered() != 0
41 dec.size = dec.r.Buffered()
42 return dec
43}
44
45// Decode unmarshals the NDR encoded bytes into the pointer of a struct provided.
46func (dec *Decoder) Decode(s interface{}) error {
47 dec.s = s
48 err := dec.readCommonHeader()
49 if err != nil {
50 return err
51 }
52 err = dec.readPrivateHeader()
53 if err != nil {
54 return err
55 }
56 _, err = dec.r.Discard(4) //The next 4 bytes are an RPC unique pointer referent. We just skip these.
57 if err != nil {
58 return Errorf("unable to process byte stream: %v", err)
59 }
60
61 return dec.process(s, reflect.StructTag(""))
62}
63
64func (dec *Decoder) process(s interface{}, tag reflect.StructTag) error {
65 // Scan for conformant fields as their max counts are moved to the beginning
66 // http://pubs.opengroup.org/onlinepubs/9629399/chap14.htm#tagfcjh_37
67 err := dec.scanConformantArrays(s, tag)
68 if err != nil {
69 return err
70 }
71 // Recursively fill the struct fields
72 var localDef []deferedPtr
73 err = dec.fill(s, tag, &localDef)
74 if err != nil {
75 return Errorf("could not decode: %v", err)
76 }
77 // Read any deferred referents associated with pointers
78 for _, p := range localDef {
79 err = dec.process(p.v, p.tag)
80 if err != nil {
81 return fmt.Errorf("could not decode deferred referent: %v", err)
82 }
83 }
84 return nil
85}
86
87// scanConformantArrays scans the structure for embedded conformant fields and captures the maximum element counts for
88// dimensions of the array that are moved to the beginning of the structure.
89func (dec *Decoder) scanConformantArrays(s interface{}, tag reflect.StructTag) error {
90 err := dec.conformantScan(s, tag)
91 if err != nil {
92 return fmt.Errorf("failed to scan for embedded conformant arrays: %v", err)
93 }
94 for i := range dec.conformantMax {
95 dec.conformantMax[i], err = dec.readUint32()
96 if err != nil {
97 return fmt.Errorf("could not read preceding conformant max count index %d: %v", i, err)
98 }
99 }
100 return nil
101}
102
103// conformantScan inspects the structure's fields for whether they are conformant.
104func (dec *Decoder) conformantScan(s interface{}, tag reflect.StructTag) error {
105 ndrTag := parseTags(tag)
106 if ndrTag.HasValue(TagPointer) {
107 return nil
108 }
109 v := getReflectValue(s)
110 switch v.Kind() {
111 case reflect.Struct:
112 for i := 0; i < v.NumField(); i++ {
113 err := dec.conformantScan(v.Field(i), v.Type().Field(i).Tag)
114 if err != nil {
115 return err
116 }
117 }
118 case reflect.String:
119 if !ndrTag.HasValue(TagConformant) {
120 break
121 }
122 dec.conformantMax = append(dec.conformantMax, uint32(0))
123 case reflect.Slice:
124 if !ndrTag.HasValue(TagConformant) {
125 break
126 }
127 d, t := sliceDimensions(v.Type())
128 for i := 0; i < d; i++ {
129 dec.conformantMax = append(dec.conformantMax, uint32(0))
130 }
131 // For string arrays there is a common max for the strings within the array.
132 if t.Kind() == reflect.String {
133 dec.conformantMax = append(dec.conformantMax, uint32(0))
134 }
135 }
136 return nil
137}
138
139func (dec *Decoder) isPointer(v reflect.Value, tag reflect.StructTag, def *[]deferedPtr) (bool, error) {
140 // Pointer so defer filling the referent
141 ndrTag := parseTags(tag)
142 if ndrTag.HasValue(TagPointer) {
143 p, err := dec.readUint32()
144 if err != nil {
145 return true, fmt.Errorf("could not read pointer: %v", err)
146 }
147 ndrTag.delete(TagPointer)
148 if p != 0 {
149 // if pointer is not zero add to the deferred items at end of stream
150 *def = append(*def, deferedPtr{v, ndrTag.StructTag()})
151 }
152 return true, nil
153 }
154 return false, nil
155}
156
// getReflectValue returns the reflect.Value to populate for s.
//
// A reflect.Value is returned unchanged; a pointer is dereferenced so the pointed-to value
// can be set. Anything else (including nil) yields the zero reflect.Value.
func getReflectValue(s interface{}) reflect.Value {
	switch x := s.(type) {
	case reflect.Value:
		return x
	default:
		if rv := reflect.ValueOf(x); rv.Kind() == reflect.Ptr {
			return rv.Elem()
		}
		return reflect.Value{}
	}
}
167
168// fill populates fields with values from the NDR byte stream.
169func (dec *Decoder) fill(s interface{}, tag reflect.StructTag, localDef *[]deferedPtr) error {
170 v := getReflectValue(s)
171
172 //// Pointer so defer filling the referent
173 ptr, err := dec.isPointer(v, tag, localDef)
174 if err != nil {
175 return fmt.Errorf("could not process struct field(%s): %v", strings.Join(dec.current, "/"), err)
176 }
177 if ptr {
178 return nil
179 }
180
181 // Populate the value from the byte stream
182 switch v.Kind() {
183 case reflect.Struct:
184 dec.current = append(dec.current, v.Type().Name()) //Track the current field being filled
185 // in case struct is a union, track this and the selected union field for efficiency
186 var unionTag reflect.Value
187 var unionField string // field to fill if struct is a union
188 // Go through each field in the struct and recursively fill
189 for i := 0; i < v.NumField(); i++ {
190 fieldName := v.Type().Field(i).Name
191 dec.current = append(dec.current, fieldName) //Track the current field being filled
192 //fmt.Fprintf(os.Stderr, "DEBUG Decoding: %s\n", strings.Join(dec.current, "/"))
193 structTag := v.Type().Field(i).Tag
194 ndrTag := parseTags(structTag)
195
196 // Union handling
197 if !unionTag.IsValid() {
198 // Is this field a union tag?
199 unionTag = dec.isUnion(v.Field(i), structTag)
200 } else {
201 // What is the selected field value of the union if we don't already know
202 if unionField == "" {
203 unionField, err = unionSelectedField(v, unionTag)
204 if err != nil {
205 return fmt.Errorf("could not determine selected union value field for %s with discriminat"+
206 " tag %s: %v", v.Type().Name(), unionTag, err)
207 }
208 }
209 if ndrTag.HasValue(TagUnionField) && fieldName != unionField {
210 // is a union and this field has not been selected so will skip it.
211 dec.current = dec.current[:len(dec.current)-1] //This field has been skipped so remove it from the current field tracker
212 continue
213 }
214 }
215
216 // Check if field is a pointer
217 if v.Field(i).Type().Implements(reflect.TypeOf(new(RawBytes)).Elem()) &&
218 v.Field(i).Type().Kind() == reflect.Slice && v.Field(i).Type().Elem().Kind() == reflect.Uint8 {
219 //field is for rawbytes
220 structTag, err = addSizeToTag(v, v.Field(i), structTag)
221 if err != nil {
222 return fmt.Errorf("could not get rawbytes field(%s) size: %v", strings.Join(dec.current, "/"), err)
223 }
224 ptr, err := dec.isPointer(v.Field(i), structTag, localDef)
225 if err != nil {
226 return fmt.Errorf("could not process struct field(%s): %v", strings.Join(dec.current, "/"), err)
227 }
228 if !ptr {
229 err := dec.readRawBytes(v.Field(i), structTag)
230 if err != nil {
231 return fmt.Errorf("could not fill raw bytes struct field(%s): %v", strings.Join(dec.current, "/"), err)
232 }
233 }
234 } else {
235 err := dec.fill(v.Field(i), structTag, localDef)
236 if err != nil {
237 return fmt.Errorf("could not fill struct field(%s): %v", strings.Join(dec.current, "/"), err)
238 }
239 }
240 dec.current = dec.current[:len(dec.current)-1] //This field has been filled so remove it from the current field tracker
241 }
242 dec.current = dec.current[:len(dec.current)-1] //This field has been filled so remove it from the current field tracker
243 case reflect.Bool:
244 i, err := dec.readBool()
245 if err != nil {
246 return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
247 }
248 v.Set(reflect.ValueOf(i))
249 case reflect.Uint8:
250 i, err := dec.readUint8()
251 if err != nil {
252 return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
253 }
254 v.Set(reflect.ValueOf(i))
255 case reflect.Uint16:
256 i, err := dec.readUint16()
257 if err != nil {
258 return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
259 }
260 v.Set(reflect.ValueOf(i))
261 case reflect.Uint32:
262 i, err := dec.readUint32()
263 if err != nil {
264 return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
265 }
266 v.Set(reflect.ValueOf(i))
267 case reflect.Uint64:
268 i, err := dec.readUint64()
269 if err != nil {
270 return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
271 }
272 v.Set(reflect.ValueOf(i))
273 case reflect.Int8:
274 i, err := dec.readInt8()
275 if err != nil {
276 return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
277 }
278 v.Set(reflect.ValueOf(i))
279 case reflect.Int16:
280 i, err := dec.readInt16()
281 if err != nil {
282 return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
283 }
284 v.Set(reflect.ValueOf(i))
285 case reflect.Int32:
286 i, err := dec.readInt32()
287 if err != nil {
288 return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
289 }
290 v.Set(reflect.ValueOf(i))
291 case reflect.Int64:
292 i, err := dec.readInt64()
293 if err != nil {
294 return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
295 }
296 v.Set(reflect.ValueOf(i))
297 case reflect.String:
298 ndrTag := parseTags(tag)
299 conformant := ndrTag.HasValue(TagConformant)
300 // strings are always varying so this is assumed without an explicit tag
301 var s string
302 var err error
303 if conformant {
304 s, err = dec.readConformantVaryingString(localDef)
305 if err != nil {
306 return fmt.Errorf("could not fill with conformant varying string: %v", err)
307 }
308 } else {
309 s, err = dec.readVaryingString(localDef)
310 if err != nil {
311 return fmt.Errorf("could not fill with varying string: %v", err)
312 }
313 }
314 v.Set(reflect.ValueOf(s))
315 case reflect.Float32:
316 i, err := dec.readFloat32()
317 if err != nil {
318 return fmt.Errorf("could not fill %v: %v", v.Type().Name(), err)
319 }
320 v.Set(reflect.ValueOf(i))
321 case reflect.Float64:
322 i, err := dec.readFloat64()
323 if err != nil {
324 return fmt.Errorf("could not fill %v: %v", v.Type().Name(), err)
325 }
326 v.Set(reflect.ValueOf(i))
327 case reflect.Array:
328 err := dec.fillFixedArray(v, tag, localDef)
329 if err != nil {
330 return err
331 }
332 case reflect.Slice:
333 if v.Type().Implements(reflect.TypeOf(new(RawBytes)).Elem()) && v.Type().Elem().Kind() == reflect.Uint8 {
334 //field is for rawbytes
335 err := dec.readRawBytes(v, tag)
336 if err != nil {
337 return fmt.Errorf("could not fill raw bytes struct field(%s): %v", strings.Join(dec.current, "/"), err)
338 }
339 break
340 }
341 ndrTag := parseTags(tag)
342 conformant := ndrTag.HasValue(TagConformant)
343 varying := ndrTag.HasValue(TagVarying)
344 if ndrTag.HasValue(TagPipe) {
345 err := dec.fillPipe(v, tag)
346 if err != nil {
347 return err
348 }
349 break
350 }
351 _, t := sliceDimensions(v.Type())
352 if t.Kind() == reflect.String && !ndrTag.HasValue(subStringArrayValue) {
353 // String array
354 err := dec.readStringsArray(v, tag, localDef)
355 if err != nil {
356 return err
357 }
358 break
359 }
360 // varying is assumed as fixed arrays use the Go array type rather than slice
361 if conformant && varying {
362 err := dec.fillConformantVaryingArray(v, tag, localDef)
363 if err != nil {
364 return err
365 }
366 } else if !conformant && varying {
367 err := dec.fillVaryingArray(v, tag, localDef)
368 if err != nil {
369 return err
370 }
371 } else {
372 //default to conformant and not varying
373 err := dec.fillConformantArray(v, tag, localDef)
374 if err != nil {
375 return err
376 }
377 }
378 default:
379 return fmt.Errorf("unsupported type")
380 }
381 return nil
382}
383
384// readBytes returns a number of bytes from the NDR byte stream.
385func (dec *Decoder) readBytes(n int) ([]byte, error) {
386 //TODO make this take an int64 as input to allow for larger values on all systems?
387 b := make([]byte, n, n)
388 m, err := dec.r.Read(b)
389 if err != nil || m != n {
390 return b, fmt.Errorf("error reading bytes from stream: %v", err)
391 }
392 return b, nil
393}