// Package ndr provides the ability to unmarshal NDR encoded byte streams into Go data structures.
| 2 | package ndr |
| 3 | |
| 4 | import ( |
| 5 | "bufio" |
| 6 | "fmt" |
| 7 | "io" |
| 8 | "reflect" |
| 9 | "strings" |
| 10 | ) |
| 11 | |
// NDR struct tag values recognised by the decoder.
const (
	// TagConformant marks a field whose element count is carried in the stream.
	TagConformant = "conformant"
	// TagVarying marks a field whose offset and actual count are carried in the stream.
	TagVarying = "varying"
	// TagPointer marks a field that is encoded as a pointer with a deferred referent.
	TagPointer = "pointer"
	// TagPipe marks a slice that is encoded as an NDR pipe.
	TagPipe = "pipe"
)
| 19 | |
| 20 | // Decoder unmarshals NDR byte stream data into a Go struct representation |
| 21 | type Decoder struct { |
| 22 | r *bufio.Reader // source of the data |
| 23 | size int // initial size of bytes in buffer |
| 24 | ch CommonHeader // NDR common header |
| 25 | ph PrivateHeader // NDR private header |
| 26 | conformantMax []uint32 // conformant max values that were moved to the beginning of the structure |
| 27 | s interface{} // pointer to the structure being populated |
| 28 | current []string // keeps track of the current field being populated |
| 29 | } |
| 30 | |
// deferedPtr records a pointer field whose referent is decoded after the structure
// that contains it: v is the value to fill and tag is the field's NDR tag with the
// pointer marker already removed.
type deferedPtr struct {
	v   reflect.Value
	tag reflect.StructTag
}
| 35 | |
| 36 | // NewDecoder creates a new instance of a NDR Decoder. |
| 37 | func NewDecoder(r io.Reader) *Decoder { |
| 38 | dec := new(Decoder) |
| 39 | dec.r = bufio.NewReader(r) |
| 40 | dec.r.Peek(int(commonHeaderBytes)) // For some reason an operation is needed on the buffer to initialise it so Buffered() != 0 |
| 41 | dec.size = dec.r.Buffered() |
| 42 | return dec |
| 43 | } |
| 44 | |
| 45 | // Decode unmarshals the NDR encoded bytes into the pointer of a struct provided. |
| 46 | func (dec *Decoder) Decode(s interface{}) error { |
| 47 | dec.s = s |
| 48 | err := dec.readCommonHeader() |
| 49 | if err != nil { |
| 50 | return err |
| 51 | } |
| 52 | err = dec.readPrivateHeader() |
| 53 | if err != nil { |
| 54 | return err |
| 55 | } |
| 56 | _, err = dec.r.Discard(4) //The next 4 bytes are an RPC unique pointer referent. We just skip these. |
| 57 | if err != nil { |
| 58 | return Errorf("unable to process byte stream: %v", err) |
| 59 | } |
| 60 | |
| 61 | return dec.process(s, reflect.StructTag("")) |
| 62 | } |
| 63 | |
| 64 | func (dec *Decoder) process(s interface{}, tag reflect.StructTag) error { |
| 65 | // Scan for conformant fields as their max counts are moved to the beginning |
| 66 | // http://pubs.opengroup.org/onlinepubs/9629399/chap14.htm#tagfcjh_37 |
| 67 | err := dec.scanConformantArrays(s, tag) |
| 68 | if err != nil { |
| 69 | return err |
| 70 | } |
| 71 | // Recursively fill the struct fields |
| 72 | var localDef []deferedPtr |
| 73 | err = dec.fill(s, tag, &localDef) |
| 74 | if err != nil { |
| 75 | return Errorf("could not decode: %v", err) |
| 76 | } |
| 77 | // Read any deferred referents associated with pointers |
| 78 | for _, p := range localDef { |
| 79 | err = dec.process(p.v, p.tag) |
| 80 | if err != nil { |
| 81 | return fmt.Errorf("could not decode deferred referent: %v", err) |
| 82 | } |
| 83 | } |
| 84 | return nil |
| 85 | } |
| 86 | |
| 87 | // scanConformantArrays scans the structure for embedded conformant fields and captures the maximum element counts for |
| 88 | // dimensions of the array that are moved to the beginning of the structure. |
| 89 | func (dec *Decoder) scanConformantArrays(s interface{}, tag reflect.StructTag) error { |
| 90 | err := dec.conformantScan(s, tag) |
| 91 | if err != nil { |
| 92 | return fmt.Errorf("failed to scan for embedded conformant arrays: %v", err) |
| 93 | } |
| 94 | for i := range dec.conformantMax { |
| 95 | dec.conformantMax[i], err = dec.readUint32() |
| 96 | if err != nil { |
| 97 | return fmt.Errorf("could not read preceding conformant max count index %d: %v", i, err) |
| 98 | } |
| 99 | } |
| 100 | return nil |
| 101 | } |
| 102 | |
| 103 | // conformantScan inspects the structure's fields for whether they are conformant. |
| 104 | func (dec *Decoder) conformantScan(s interface{}, tag reflect.StructTag) error { |
| 105 | ndrTag := parseTags(tag) |
| 106 | if ndrTag.HasValue(TagPointer) { |
| 107 | return nil |
| 108 | } |
| 109 | v := getReflectValue(s) |
| 110 | switch v.Kind() { |
| 111 | case reflect.Struct: |
| 112 | for i := 0; i < v.NumField(); i++ { |
| 113 | err := dec.conformantScan(v.Field(i), v.Type().Field(i).Tag) |
| 114 | if err != nil { |
| 115 | return err |
| 116 | } |
| 117 | } |
| 118 | case reflect.String: |
| 119 | if !ndrTag.HasValue(TagConformant) { |
| 120 | break |
| 121 | } |
| 122 | dec.conformantMax = append(dec.conformantMax, uint32(0)) |
| 123 | case reflect.Slice: |
| 124 | if !ndrTag.HasValue(TagConformant) { |
| 125 | break |
| 126 | } |
| 127 | d, t := sliceDimensions(v.Type()) |
| 128 | for i := 0; i < d; i++ { |
| 129 | dec.conformantMax = append(dec.conformantMax, uint32(0)) |
| 130 | } |
| 131 | // For string arrays there is a common max for the strings within the array. |
| 132 | if t.Kind() == reflect.String { |
| 133 | dec.conformantMax = append(dec.conformantMax, uint32(0)) |
| 134 | } |
| 135 | } |
| 136 | return nil |
| 137 | } |
| 138 | |
| 139 | func (dec *Decoder) isPointer(v reflect.Value, tag reflect.StructTag, def *[]deferedPtr) (bool, error) { |
| 140 | // Pointer so defer filling the referent |
| 141 | ndrTag := parseTags(tag) |
| 142 | if ndrTag.HasValue(TagPointer) { |
| 143 | p, err := dec.readUint32() |
| 144 | if err != nil { |
| 145 | return true, fmt.Errorf("could not read pointer: %v", err) |
| 146 | } |
| 147 | ndrTag.delete(TagPointer) |
| 148 | if p != 0 { |
| 149 | // if pointer is not zero add to the deferred items at end of stream |
| 150 | *def = append(*def, deferedPtr{v, ndrTag.StructTag()}) |
| 151 | } |
| 152 | return true, nil |
| 153 | } |
| 154 | return false, nil |
| 155 | } |
| 156 | |
// getReflectValue normalises s into the reflect.Value that should be populated.
//
// A reflect.Value is returned unchanged, and a pointer is dereferenced to its
// element. A nil pointer yields the zero reflect.Value, as does anything else
// (including nil); the zero Value's Kind is Invalid.
func getReflectValue(s interface{}) (v reflect.Value) {
	if rv, isValue := s.(reflect.Value); isValue {
		return rv
	}
	if pv := reflect.ValueOf(s); pv.Kind() == reflect.Ptr {
		return pv.Elem()
	}
	return reflect.Value{}
}
| 167 | |
| 168 | // fill populates fields with values from the NDR byte stream. |
| 169 | func (dec *Decoder) fill(s interface{}, tag reflect.StructTag, localDef *[]deferedPtr) error { |
| 170 | v := getReflectValue(s) |
| 171 | |
| 172 | //// Pointer so defer filling the referent |
| 173 | ptr, err := dec.isPointer(v, tag, localDef) |
| 174 | if err != nil { |
| 175 | return fmt.Errorf("could not process struct field(%s): %v", strings.Join(dec.current, "/"), err) |
| 176 | } |
| 177 | if ptr { |
| 178 | return nil |
| 179 | } |
| 180 | |
| 181 | // Populate the value from the byte stream |
| 182 | switch v.Kind() { |
| 183 | case reflect.Struct: |
| 184 | dec.current = append(dec.current, v.Type().Name()) //Track the current field being filled |
| 185 | // in case struct is a union, track this and the selected union field for efficiency |
| 186 | var unionTag reflect.Value |
| 187 | var unionField string // field to fill if struct is a union |
| 188 | // Go through each field in the struct and recursively fill |
| 189 | for i := 0; i < v.NumField(); i++ { |
| 190 | fieldName := v.Type().Field(i).Name |
| 191 | dec.current = append(dec.current, fieldName) //Track the current field being filled |
| 192 | //fmt.Fprintf(os.Stderr, "DEBUG Decoding: %s\n", strings.Join(dec.current, "/")) |
| 193 | structTag := v.Type().Field(i).Tag |
| 194 | ndrTag := parseTags(structTag) |
| 195 | |
| 196 | // Union handling |
| 197 | if !unionTag.IsValid() { |
| 198 | // Is this field a union tag? |
| 199 | unionTag = dec.isUnion(v.Field(i), structTag) |
| 200 | } else { |
| 201 | // What is the selected field value of the union if we don't already know |
| 202 | if unionField == "" { |
| 203 | unionField, err = unionSelectedField(v, unionTag) |
| 204 | if err != nil { |
| 205 | return fmt.Errorf("could not determine selected union value field for %s with discriminat"+ |
| 206 | " tag %s: %v", v.Type().Name(), unionTag, err) |
| 207 | } |
| 208 | } |
| 209 | if ndrTag.HasValue(TagUnionField) && fieldName != unionField { |
| 210 | // is a union and this field has not been selected so will skip it. |
| 211 | dec.current = dec.current[:len(dec.current)-1] //This field has been skipped so remove it from the current field tracker |
| 212 | continue |
| 213 | } |
| 214 | } |
| 215 | |
| 216 | // Check if field is a pointer |
| 217 | if v.Field(i).Type().Implements(reflect.TypeOf(new(RawBytes)).Elem()) && |
| 218 | v.Field(i).Type().Kind() == reflect.Slice && v.Field(i).Type().Elem().Kind() == reflect.Uint8 { |
| 219 | //field is for rawbytes |
| 220 | structTag, err = addSizeToTag(v, v.Field(i), structTag) |
| 221 | if err != nil { |
| 222 | return fmt.Errorf("could not get rawbytes field(%s) size: %v", strings.Join(dec.current, "/"), err) |
| 223 | } |
| 224 | ptr, err := dec.isPointer(v.Field(i), structTag, localDef) |
| 225 | if err != nil { |
| 226 | return fmt.Errorf("could not process struct field(%s): %v", strings.Join(dec.current, "/"), err) |
| 227 | } |
| 228 | if !ptr { |
| 229 | err := dec.readRawBytes(v.Field(i), structTag) |
| 230 | if err != nil { |
| 231 | return fmt.Errorf("could not fill raw bytes struct field(%s): %v", strings.Join(dec.current, "/"), err) |
| 232 | } |
| 233 | } |
| 234 | } else { |
| 235 | err := dec.fill(v.Field(i), structTag, localDef) |
| 236 | if err != nil { |
| 237 | return fmt.Errorf("could not fill struct field(%s): %v", strings.Join(dec.current, "/"), err) |
| 238 | } |
| 239 | } |
| 240 | dec.current = dec.current[:len(dec.current)-1] //This field has been filled so remove it from the current field tracker |
| 241 | } |
| 242 | dec.current = dec.current[:len(dec.current)-1] //This field has been filled so remove it from the current field tracker |
| 243 | case reflect.Bool: |
| 244 | i, err := dec.readBool() |
| 245 | if err != nil { |
| 246 | return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err) |
| 247 | } |
| 248 | v.Set(reflect.ValueOf(i)) |
| 249 | case reflect.Uint8: |
| 250 | i, err := dec.readUint8() |
| 251 | if err != nil { |
| 252 | return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err) |
| 253 | } |
| 254 | v.Set(reflect.ValueOf(i)) |
| 255 | case reflect.Uint16: |
| 256 | i, err := dec.readUint16() |
| 257 | if err != nil { |
| 258 | return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err) |
| 259 | } |
| 260 | v.Set(reflect.ValueOf(i)) |
| 261 | case reflect.Uint32: |
| 262 | i, err := dec.readUint32() |
| 263 | if err != nil { |
| 264 | return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err) |
| 265 | } |
| 266 | v.Set(reflect.ValueOf(i)) |
| 267 | case reflect.Uint64: |
| 268 | i, err := dec.readUint64() |
| 269 | if err != nil { |
| 270 | return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err) |
| 271 | } |
| 272 | v.Set(reflect.ValueOf(i)) |
| 273 | case reflect.Int8: |
| 274 | i, err := dec.readInt8() |
| 275 | if err != nil { |
| 276 | return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err) |
| 277 | } |
| 278 | v.Set(reflect.ValueOf(i)) |
| 279 | case reflect.Int16: |
| 280 | i, err := dec.readInt16() |
| 281 | if err != nil { |
| 282 | return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err) |
| 283 | } |
| 284 | v.Set(reflect.ValueOf(i)) |
| 285 | case reflect.Int32: |
| 286 | i, err := dec.readInt32() |
| 287 | if err != nil { |
| 288 | return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err) |
| 289 | } |
| 290 | v.Set(reflect.ValueOf(i)) |
| 291 | case reflect.Int64: |
| 292 | i, err := dec.readInt64() |
| 293 | if err != nil { |
| 294 | return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err) |
| 295 | } |
| 296 | v.Set(reflect.ValueOf(i)) |
| 297 | case reflect.String: |
| 298 | ndrTag := parseTags(tag) |
| 299 | conformant := ndrTag.HasValue(TagConformant) |
| 300 | // strings are always varying so this is assumed without an explicit tag |
| 301 | var s string |
| 302 | var err error |
| 303 | if conformant { |
| 304 | s, err = dec.readConformantVaryingString(localDef) |
| 305 | if err != nil { |
| 306 | return fmt.Errorf("could not fill with conformant varying string: %v", err) |
| 307 | } |
| 308 | } else { |
| 309 | s, err = dec.readVaryingString(localDef) |
| 310 | if err != nil { |
| 311 | return fmt.Errorf("could not fill with varying string: %v", err) |
| 312 | } |
| 313 | } |
| 314 | v.Set(reflect.ValueOf(s)) |
| 315 | case reflect.Float32: |
| 316 | i, err := dec.readFloat32() |
| 317 | if err != nil { |
| 318 | return fmt.Errorf("could not fill %v: %v", v.Type().Name(), err) |
| 319 | } |
| 320 | v.Set(reflect.ValueOf(i)) |
| 321 | case reflect.Float64: |
| 322 | i, err := dec.readFloat64() |
| 323 | if err != nil { |
| 324 | return fmt.Errorf("could not fill %v: %v", v.Type().Name(), err) |
| 325 | } |
| 326 | v.Set(reflect.ValueOf(i)) |
| 327 | case reflect.Array: |
| 328 | err := dec.fillFixedArray(v, tag, localDef) |
| 329 | if err != nil { |
| 330 | return err |
| 331 | } |
| 332 | case reflect.Slice: |
| 333 | if v.Type().Implements(reflect.TypeOf(new(RawBytes)).Elem()) && v.Type().Elem().Kind() == reflect.Uint8 { |
| 334 | //field is for rawbytes |
| 335 | err := dec.readRawBytes(v, tag) |
| 336 | if err != nil { |
| 337 | return fmt.Errorf("could not fill raw bytes struct field(%s): %v", strings.Join(dec.current, "/"), err) |
| 338 | } |
| 339 | break |
| 340 | } |
| 341 | ndrTag := parseTags(tag) |
| 342 | conformant := ndrTag.HasValue(TagConformant) |
| 343 | varying := ndrTag.HasValue(TagVarying) |
| 344 | if ndrTag.HasValue(TagPipe) { |
| 345 | err := dec.fillPipe(v, tag) |
| 346 | if err != nil { |
| 347 | return err |
| 348 | } |
| 349 | break |
| 350 | } |
| 351 | _, t := sliceDimensions(v.Type()) |
| 352 | if t.Kind() == reflect.String && !ndrTag.HasValue(subStringArrayValue) { |
| 353 | // String array |
| 354 | err := dec.readStringsArray(v, tag, localDef) |
| 355 | if err != nil { |
| 356 | return err |
| 357 | } |
| 358 | break |
| 359 | } |
| 360 | // varying is assumed as fixed arrays use the Go array type rather than slice |
| 361 | if conformant && varying { |
| 362 | err := dec.fillConformantVaryingArray(v, tag, localDef) |
| 363 | if err != nil { |
| 364 | return err |
| 365 | } |
| 366 | } else if !conformant && varying { |
| 367 | err := dec.fillVaryingArray(v, tag, localDef) |
| 368 | if err != nil { |
| 369 | return err |
| 370 | } |
| 371 | } else { |
| 372 | //default to conformant and not varying |
| 373 | err := dec.fillConformantArray(v, tag, localDef) |
| 374 | if err != nil { |
| 375 | return err |
| 376 | } |
| 377 | } |
| 378 | default: |
| 379 | return fmt.Errorf("unsupported type") |
| 380 | } |
| 381 | return nil |
| 382 | } |
| 383 | |
| 384 | // readBytes returns a number of bytes from the NDR byte stream. |
| 385 | func (dec *Decoder) readBytes(n int) ([]byte, error) { |
| 386 | //TODO make this take an int64 as input to allow for larger values on all systems? |
| 387 | b := make([]byte, n, n) |
| 388 | m, err := dec.r.Read(b) |
| 389 | if err != nil || m != n { |
| 390 | return b, fmt.Errorf("error reading bytes from stream: %v", err) |
| 391 | } |
| 392 | return b, nil |
| 393 | } |