blob: f52eee9d26186551a731ec35c4ba011f9e36f875 [file] [log] [blame]
Scott Bakered4efab2020-01-13 19:12:25 -08001// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package asn1
6
7import (
8 "bytes"
9 "errors"
10 "fmt"
11 "io"
12 "math/big"
13 "reflect"
14 "time"
15 "unicode/utf8"
16)
17
// forkableWriter is an in-memory buffer that can be split once into two child
// writers whose contents logically follow its own. After
//
//	pre, post := w.fork()
//
// the byte sequence represented by w is w's own buffer, then everything in
// pre, then everything in post. Callers use this to reserve a slot (pre) for
// data, such as a header, that can only be produced after later data (post)
// has been written.
type forkableWriter struct {
	*bytes.Buffer
	pre, post *forkableWriter
}

// newForkableWriter returns an empty writer that has not been forked.
func newForkableWriter() *forkableWriter {
	return &forkableWriter{Buffer: new(bytes.Buffer)}
}

// fork creates, records and returns the two children of f. A writer may only
// be forked once; a second call panics.
func (f *forkableWriter) fork() (pre, post *forkableWriter) {
	if f.pre != nil || f.post != nil {
		panic("have already forked")
	}
	pre, post = newForkableWriter(), newForkableWriter()
	f.pre, f.post = pre, post
	return pre, post
}

// Len returns the number of bytes held by f plus, recursively, the bytes held
// by its children.
func (f *forkableWriter) Len() int {
	total := f.Buffer.Len()
	for _, child := range [...]*forkableWriter{f.pre, f.post} {
		if child != nil {
			total += child.Len()
		}
	}
	return total
}

// writeTo flattens f into out: its own buffer first, then pre, then post
// (each recursively). It returns the total byte count reported by out and
// stops at the first write error.
func (f *forkableWriter) writeTo(out io.Writer) (n int, err error) {
	if n, err = out.Write(f.Bytes()); err != nil {
		return n, err
	}
	for _, child := range [...]*forkableWriter{f.pre, f.post} {
		if child == nil {
			continue
		}
		written, childErr := child.writeTo(out)
		n += written
		if childErr != nil {
			return n, childErr
		}
	}
	return n, nil
}
74
75func marshalBase128Int(out *forkableWriter, n int64) (err error) {
76 if n == 0 {
77 err = out.WriteByte(0)
78 return
79 }
80
81 l := 0
82 for i := n; i > 0; i >>= 7 {
83 l++
84 }
85
86 for i := l - 1; i >= 0; i-- {
87 o := byte(n >> uint(i*7))
88 o &= 0x7f
89 if i != 0 {
90 o |= 0x80
91 }
92 err = out.WriteByte(o)
93 if err != nil {
94 return
95 }
96 }
97
98 return nil
99}
100
101func marshalInt64(out *forkableWriter, i int64) (err error) {
102 n := int64Length(i)
103
104 for ; n > 0; n-- {
105 err = out.WriteByte(byte(i >> uint((n-1)*8)))
106 if err != nil {
107 return
108 }
109 }
110
111 return nil
112}
113
// int64Length returns how many bytes a minimal two's-complement, big-endian
// encoding of i needs. The result is always at least 1: values in
// [-128, 127] fit in one byte, and each further byte covers eight more bits.
func int64Length(i int64) (numBytes int) {
	// Arithmetic shifts preserve the sign, so a single combined condition
	// handles both positive and negative inputs.
	for numBytes = 1; i > 127 || i < -128; numBytes++ {
		i >>= 8
	}
	return numBytes
}
129
130func marshalBigInt(out *forkableWriter, n *big.Int) (err error) {
131 if n.Sign() < 0 {
132 // A negative number has to be converted to two's-complement
133 // form. So we'll subtract 1 and invert. If the
134 // most-significant-bit isn't set then we'll need to pad the
135 // beginning with 0xff in order to keep the number negative.
136 nMinus1 := new(big.Int).Neg(n)
137 nMinus1.Sub(nMinus1, bigOne)
138 bytes := nMinus1.Bytes()
139 for i := range bytes {
140 bytes[i] ^= 0xff
141 }
142 if len(bytes) == 0 || bytes[0]&0x80 == 0 {
143 err = out.WriteByte(0xff)
144 if err != nil {
145 return
146 }
147 }
148 _, err = out.Write(bytes)
149 } else if n.Sign() == 0 {
150 // Zero is written as a single 0 zero rather than no bytes.
151 err = out.WriteByte(0x00)
152 } else {
153 bytes := n.Bytes()
154 if len(bytes) > 0 && bytes[0]&0x80 != 0 {
155 // We'll have to pad this with 0x00 in order to stop it
156 // looking like a negative number.
157 err = out.WriteByte(0)
158 if err != nil {
159 return
160 }
161 }
162 _, err = out.Write(bytes)
163 }
164 return
165}
166
167func marshalLength(out *forkableWriter, i int) (err error) {
168 n := lengthLength(i)
169
170 for ; n > 0; n-- {
171 err = out.WriteByte(byte(i >> uint((n-1)*8)))
172 if err != nil {
173 return
174 }
175 }
176
177 return nil
178}
179
// lengthLength returns the number of bytes needed to write i as a big-endian
// unsigned integer, with a minimum of 1. Non-positive values report 1.
func lengthLength(i int) (numBytes int) {
	numBytes = 1
	// i needs another byte for every nonzero byte above its lowest one.
	for rest := i >> 8; rest > 0; rest >>= 8 {
		numBytes++
	}
	return numBytes
}
188
189func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err error) {
190 b := uint8(t.class) << 6
191 if t.isCompound {
192 b |= 0x20
193 }
194 if t.tag >= 31 {
195 b |= 0x1f
196 err = out.WriteByte(b)
197 if err != nil {
198 return
199 }
200 err = marshalBase128Int(out, int64(t.tag))
201 if err != nil {
202 return
203 }
204 } else {
205 b |= uint8(t.tag)
206 err = out.WriteByte(b)
207 if err != nil {
208 return
209 }
210 }
211
212 if t.length >= 128 {
213 l := lengthLength(t.length)
214 err = out.WriteByte(0x80 | byte(l))
215 if err != nil {
216 return
217 }
218 err = marshalLength(out, t.length)
219 if err != nil {
220 return
221 }
222 } else {
223 err = out.WriteByte(byte(t.length))
224 if err != nil {
225 return
226 }
227 }
228
229 return nil
230}
231
232func marshalBitString(out *forkableWriter, b BitString) (err error) {
233 paddingBits := byte((8 - b.BitLength%8) % 8)
234 err = out.WriteByte(paddingBits)
235 if err != nil {
236 return
237 }
238 _, err = out.Write(b.Bytes)
239 return
240}
241
242func marshalObjectIdentifier(out *forkableWriter, oid []int) (err error) {
243 if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
244 return StructuralError{"invalid object identifier"}
245 }
246
247 err = marshalBase128Int(out, int64(oid[0]*40+oid[1]))
248 if err != nil {
249 return
250 }
251 for i := 2; i < len(oid); i++ {
252 err = marshalBase128Int(out, int64(oid[i]))
253 if err != nil {
254 return
255 }
256 }
257
258 return
259}
260
261func marshalPrintableString(out *forkableWriter, s string) (err error) {
262 b := []byte(s)
263 for _, c := range b {
264 if !isPrintable(c) {
265 return StructuralError{"PrintableString contains invalid character"}
266 }
267 }
268
269 _, err = out.Write(b)
270 return
271}
272
273func marshalIA5String(out *forkableWriter, s string) (err error) {
274 b := []byte(s)
275 for _, c := range b {
276 if c > 127 {
277 return StructuralError{"IA5String contains invalid character"}
278 }
279 }
280
281 _, err = out.Write(b)
282 return
283}
284
285func marshalUTF8String(out *forkableWriter, s string) (err error) {
286 _, err = out.Write([]byte(s))
287 return
288}
289
290func marshalTwoDigits(out *forkableWriter, v int) (err error) {
291 err = out.WriteByte(byte('0' + (v/10)%10))
292 if err != nil {
293 return
294 }
295 return out.WriteByte(byte('0' + v%10))
296}
297
298func marshalFourDigits(out *forkableWriter, v int) (err error) {
299 var bytes [4]byte
300 for i := range bytes {
301 bytes[3-i] = '0' + byte(v%10)
302 v /= 10
303 }
304 _, err = out.Write(bytes[:])
305 return
306}
307
// outsideUTCRange reports whether t's year lies outside 1950-2049. Those are
// the only years that marshalUTCTime's two-digit year can express.
func outsideUTCRange(t time.Time) bool {
	year := t.Year()
	inRange := 1950 <= year && year < 2050
	return !inRange
}
312
313func marshalUTCTime(out *forkableWriter, t time.Time) (err error) {
314 year := t.Year()
315
316 switch {
317 case 1950 <= year && year < 2000:
318 err = marshalTwoDigits(out, year-1900)
319 case 2000 <= year && year < 2050:
320 err = marshalTwoDigits(out, year-2000)
321 default:
322 return StructuralError{"cannot represent time as UTCTime"}
323 }
324 if err != nil {
325 return
326 }
327
328 return marshalTimeCommon(out, t)
329}
330
331func marshalGeneralizedTime(out *forkableWriter, t time.Time) (err error) {
332 year := t.Year()
333 if year < 0 || year > 9999 {
334 return StructuralError{"cannot represent time as GeneralizedTime"}
335 }
336 if err = marshalFourDigits(out, year); err != nil {
337 return
338 }
339
340 return marshalTimeCommon(out, t)
341}
342
343func marshalTimeCommon(out *forkableWriter, t time.Time) (err error) {
344 _, month, day := t.Date()
345
346 err = marshalTwoDigits(out, int(month))
347 if err != nil {
348 return
349 }
350
351 err = marshalTwoDigits(out, day)
352 if err != nil {
353 return
354 }
355
356 hour, min, sec := t.Clock()
357
358 err = marshalTwoDigits(out, hour)
359 if err != nil {
360 return
361 }
362
363 err = marshalTwoDigits(out, min)
364 if err != nil {
365 return
366 }
367
368 err = marshalTwoDigits(out, sec)
369 if err != nil {
370 return
371 }
372
373 _, offset := t.Zone()
374
375 switch {
376 case offset/60 == 0:
377 err = out.WriteByte('Z')
378 return
379 case offset > 0:
380 err = out.WriteByte('+')
381 case offset < 0:
382 err = out.WriteByte('-')
383 }
384
385 if err != nil {
386 return
387 }
388
389 offsetMinutes := offset / 60
390 if offsetMinutes < 0 {
391 offsetMinutes = -offsetMinutes
392 }
393
394 err = marshalTwoDigits(out, offsetMinutes/60)
395 if err != nil {
396 return
397 }
398
399 err = marshalTwoDigits(out, offsetMinutes%60)
400 return
401}
402
403func stripTagAndLength(in []byte) []byte {
404 _, offset, err := parseTagAndLength(in, 0)
405 if err != nil {
406 return in
407 }
408 return in[offset:]
409}
410
411func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameters) (err error) {
412 switch value.Type() {
413 case flagType:
414 return nil
415 case timeType:
416 t := value.Interface().(time.Time)
417 if params.timeType == TagGeneralizedTime || outsideUTCRange(t) {
418 return marshalGeneralizedTime(out, t)
419 } else {
420 return marshalUTCTime(out, t)
421 }
422 case bitStringType:
423 return marshalBitString(out, value.Interface().(BitString))
424 case objectIdentifierType:
425 return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier))
426 case bigIntType:
427 return marshalBigInt(out, value.Interface().(*big.Int))
428 }
429
430 switch v := value; v.Kind() {
431 case reflect.Bool:
432 if v.Bool() {
433 return out.WriteByte(255)
434 } else {
435 return out.WriteByte(0)
436 }
437 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
438 return marshalInt64(out, v.Int())
439 case reflect.Struct:
440 t := v.Type()
441
442 startingField := 0
443
444 // If the first element of the structure is a non-empty
445 // RawContents, then we don't bother serializing the rest.
446 if t.NumField() > 0 && t.Field(0).Type == rawContentsType {
447 s := v.Field(0)
448 if s.Len() > 0 {
449 bytes := make([]byte, s.Len())
450 for i := 0; i < s.Len(); i++ {
451 bytes[i] = uint8(s.Index(i).Uint())
452 }
453 /* The RawContents will contain the tag and
454 * length fields but we'll also be writing
455 * those ourselves, so we strip them out of
456 * bytes */
457 _, err = out.Write(stripTagAndLength(bytes))
458 return
459 } else {
460 startingField = 1
461 }
462 }
463
464 for i := startingField; i < t.NumField(); i++ {
465 var pre *forkableWriter
466 pre, out = out.fork()
467 err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag.Get("asn1")))
468 if err != nil {
469 return
470 }
471 }
472 return
473 case reflect.Slice:
474 sliceType := v.Type()
475 if sliceType.Elem().Kind() == reflect.Uint8 {
476 bytes := make([]byte, v.Len())
477 for i := 0; i < v.Len(); i++ {
478 bytes[i] = uint8(v.Index(i).Uint())
479 }
480 _, err = out.Write(bytes)
481 return
482 }
483
484 // jtasn1 Pass on the tags to the members but need to unset explicit switch and implicit value
485 //var fp fieldParameters
486 params.explicit = false
487 params.tag = nil
488 for i := 0; i < v.Len(); i++ {
489 var pre *forkableWriter
490 pre, out = out.fork()
491 err = marshalField(pre, v.Index(i), params)
492 if err != nil {
493 return
494 }
495 }
496 return
497 case reflect.String:
498 switch params.stringType {
499 case TagIA5String:
500 return marshalIA5String(out, v.String())
501 case TagPrintableString:
502 return marshalPrintableString(out, v.String())
503 default:
504 return marshalUTF8String(out, v.String())
505 }
506 }
507
508 return StructuralError{"unknown Go type"}
509}
510
// marshalField writes the complete encoding of v (identifier, length and
// contents octets) to out, honouring the struct-tag options in params.
//
// It returns nil without writing anything in two cases:
//   - v is an empty slice tagged omitempty;
//   - v is optional and equals its default, meaning the explicit default if
//     one is given and the zero value otherwise.
//
// It returns an error for an invalid (nil) v, for a Go type with no ASN.1
// mapping, or for options that do not fit v's type.
func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) (err error) {
	if !v.IsValid() {
		return fmt.Errorf("asn1: cannot marshal nil value")
	}
	// If the field is an interface{} then recurse into it.
	if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 {
		return marshalField(out, v.Elem(), params)
	}

	// omitempty: empty slices are skipped entirely.
	if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty {
		return
	}

	// An optional field with an explicit default is omitted when it equals
	// that default. NOTE(review): SetInt requires an integer kind, so
	// canHaveDefaultValue presumably admits only integer kinds — confirm.
	if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) {
		defaultValue := reflect.New(v.Type()).Elem()
		defaultValue.SetInt(*params.defaultValue)

		if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) {
			return
		}
	}

	// If no default value is given then the zero value for the type is
	// assumed to be the default value. This isn't obviously the correct
	// behaviour, but it's what Go has traditionally done.
	if params.optional && params.defaultValue == nil {
		if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
			return
		}
	}

	// A RawValue is emitted verbatim. FullBytes is used when present;
	// otherwise a header is built from Class/Tag/IsCompound and written
	// before Bytes.
	if v.Type() == rawValueType {
		rv := v.Interface().(RawValue)
		if len(rv.FullBytes) != 0 {
			_, err = out.Write(rv.FullBytes)
		} else {
			err = marshalTagAndLength(out, tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound})
			if err != nil {
				return
			}
			_, err = out.Write(rv.Bytes)
		}
		return
	}

	tag, isCompound, ok := getUniversalType(v.Type())
	if !ok {
		err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
		return
	}
	class := ClassUniversal

	// A time type option is only meaningful on time values.
	if params.timeType != 0 && tag != TagUTCTime {
		return StructuralError{"explicit time type given to non-time member"}
	}

	// jtasn1 updated to allow slices of strings
	// A string type option is accepted on a string, or on a slice of strings
	// (universal tag 16, SEQUENCE). In the slice case, marshalBody passes the
	// string type on to each element.
	if params.stringType != 0 && !(tag == TagPrintableString || (v.Kind() == reflect.Slice && tag == 16 && v.Type().Elem().Kind() == reflect.String)) {
		return StructuralError{"explicit string type given to non-string member"}
	}

	switch tag {
	case TagPrintableString:
		if params.stringType == 0 {
			// This is a string without an explicit string type. We'll use
			// a PrintableString if the character set in the string is
			// sufficiently limited, otherwise we'll use a UTF8String.
			for _, r := range v.String() {
				if r >= utf8.RuneSelf || !isPrintable(byte(r)) {
					if !utf8.ValidString(v.String()) {
						return errors.New("asn1: string not valid UTF-8")
					}
					tag = TagUTF8String
					break
				}
			}
		} else {
			tag = params.stringType
		}
	case TagUTCTime:
		// Same rule marshalBody uses to pick the body encoding: the
		// generalized option or a year outside 1950-2049 selects
		// GeneralizedTime.
		if params.timeType == TagGeneralizedTime || outsideUTCRange(v.Interface().(time.Time)) {
			tag = TagGeneralizedTime
		}
	}

	// The set option re-tags a SEQUENCE as a SET; it is rejected on anything
	// else.
	if params.set {
		if tag != TagSequence {
			return StructuralError{"non sequence tagged as set"}
		}
		tag = TagSet
	}

	// Fork into a header writer (tags) and a contents writer (body). tags
	// precedes body in the output, so the header can be written after the
	// body's length is known.
	tags, body := out.fork()

	err = marshalBody(body, v, params)
	if err != nil {
		return
	}

	bodyLen := body.Len()

	// For an EXPLICIT tag, fork the header again so the outer
	// context-specific header (explicitTag) precedes the inner header
	// (tags).
	var explicitTag *forkableWriter
	if params.explicit {
		explicitTag, tags = tags.fork()
	}

	if !params.explicit && params.tag != nil {
		// implicit tag.
		tag = *params.tag
		class = ClassContextSpecific
	}

	err = marshalTagAndLength(tags, tagAndLength{class, tag, bodyLen, isCompound})
	if err != nil {
		return
	}

	if params.explicit {
		// The explicit wrapper's length covers the inner header plus body.
		// NOTE(review): this dereferences params.tag and would panic if
		// explicit were set without a tag number. Confirm that
		// parseFieldParameters always sets tag alongside explicit.
		err = marshalTagAndLength(explicitTag, tagAndLength{
			class:      ClassContextSpecific,
			tag:        *params.tag,
			length:     bodyLen + tags.Len(),
			isCompound: true,
		})
	}

	return err
}
639
640// Marshal returns the ASN.1 encoding of val.
641//
642// In addition to the struct tags recognised by Unmarshal, the following can be
643// used:
644//
645// ia5: causes strings to be marshaled as ASN.1, IA5 strings
646// omitempty: causes empty slices to be skipped
647// printable: causes strings to be marshaled as ASN.1, PrintableString strings.
648// utf8: causes strings to be marshaled as ASN.1, UTF8 strings
649func Marshal(val interface{}) ([]byte, error) {
650 var out bytes.Buffer
651 v := reflect.ValueOf(val)
652 f := newForkableWriter()
653 err := marshalField(f, v, fieldParameters{})
654 if err != nil {
655 return nil, err
656 }
657 _, err = f.writeTo(&out)
658 return out.Bytes(), err
659}