David K. Bainbridge | bd6b288 | 2021-08-26 13:31:02 +0000 | [diff] [blame] | 1 | // Copyright 2021 Google Inc. All rights reserved. |
| 2 | // Use of this source code is governed by a BSD-style |
| 3 | // license that can be found in the LICENSE file. |
| 4 | |
| 5 | package uuid |
| 6 | |
| 7 | import ( |
| 8 | "bytes" |
| 9 | "database/sql/driver" |
| 10 | "encoding/json" |
| 11 | "fmt" |
| 12 | ) |
| 13 | |
// jsonNull is the JSON encoding of a null value. UnmarshalJSON compares
// incoming data against it to detect an explicit JSON null.
var jsonNull = []byte(`null`)
| 15 | |
// NullUUID represents a UUID that may be null.
// NullUUID implements the database/sql Scanner interface so
// it can be used as a scan destination:
//
//	var u uuid.NullUUID
//	err := db.QueryRow("SELECT name FROM foo WHERE id=?", id).Scan(&u)
//	...
//	if u.Valid {
//	   // use u.UUID
//	} else {
//	   // NULL value
//	}
//
// It also implements driver.Valuer, so it can be passed as a query argument;
// a null value is sent as SQL NULL (a nil driver.Value). It further implements
// the encoding binary/text marshalers and json.Marshaler/json.Unmarshaler. When
// marshaling, a null value encodes as an empty byte slice (binary), the text
// "null" (text), or JSON null (JSON).
type NullUUID struct {
	UUID  UUID // the UUID value; Value and the Marshal methods ignore it when Valid is false
	Valid bool // Valid is true if UUID is not NULL
}
| 33 | |
| 34 | // Scan implements the SQL driver.Scanner interface. |
| 35 | func (nu *NullUUID) Scan(value interface{}) error { |
| 36 | if value == nil { |
| 37 | nu.UUID, nu.Valid = Nil, false |
| 38 | return nil |
| 39 | } |
| 40 | |
| 41 | err := nu.UUID.Scan(value) |
| 42 | if err != nil { |
| 43 | nu.Valid = false |
| 44 | return err |
| 45 | } |
| 46 | |
| 47 | nu.Valid = true |
| 48 | return nil |
| 49 | } |
| 50 | |
| 51 | // Value implements the driver Valuer interface. |
| 52 | func (nu NullUUID) Value() (driver.Value, error) { |
| 53 | if !nu.Valid { |
| 54 | return nil, nil |
| 55 | } |
| 56 | // Delegate to UUID Value function |
| 57 | return nu.UUID.Value() |
| 58 | } |
| 59 | |
| 60 | // MarshalBinary implements encoding.BinaryMarshaler. |
| 61 | func (nu NullUUID) MarshalBinary() ([]byte, error) { |
| 62 | if nu.Valid { |
| 63 | return nu.UUID[:], nil |
| 64 | } |
| 65 | |
| 66 | return []byte(nil), nil |
| 67 | } |
| 68 | |
| 69 | // UnmarshalBinary implements encoding.BinaryUnmarshaler. |
| 70 | func (nu *NullUUID) UnmarshalBinary(data []byte) error { |
| 71 | if len(data) != 16 { |
| 72 | return fmt.Errorf("invalid UUID (got %d bytes)", len(data)) |
| 73 | } |
| 74 | copy(nu.UUID[:], data) |
| 75 | nu.Valid = true |
| 76 | return nil |
| 77 | } |
| 78 | |
| 79 | // MarshalText implements encoding.TextMarshaler. |
| 80 | func (nu NullUUID) MarshalText() ([]byte, error) { |
| 81 | if nu.Valid { |
| 82 | return nu.UUID.MarshalText() |
| 83 | } |
| 84 | |
| 85 | return jsonNull, nil |
| 86 | } |
| 87 | |
| 88 | // UnmarshalText implements encoding.TextUnmarshaler. |
| 89 | func (nu *NullUUID) UnmarshalText(data []byte) error { |
| 90 | id, err := ParseBytes(data) |
| 91 | if err != nil { |
| 92 | nu.Valid = false |
| 93 | return err |
| 94 | } |
| 95 | nu.UUID = id |
| 96 | nu.Valid = true |
| 97 | return nil |
| 98 | } |
| 99 | |
| 100 | // MarshalJSON implements json.Marshaler. |
| 101 | func (nu NullUUID) MarshalJSON() ([]byte, error) { |
| 102 | if nu.Valid { |
| 103 | return json.Marshal(nu.UUID) |
| 104 | } |
| 105 | |
| 106 | return jsonNull, nil |
| 107 | } |
| 108 | |
| 109 | // UnmarshalJSON implements json.Unmarshaler. |
| 110 | func (nu *NullUUID) UnmarshalJSON(data []byte) error { |
| 111 | if bytes.Equal(data, jsonNull) { |
| 112 | *nu = NullUUID{} |
| 113 | return nil // valid null UUID |
| 114 | } |
| 115 | err := json.Unmarshal(data, &nu.UUID) |
| 116 | nu.Valid = err == nil |
| 117 | return err |
| 118 | } |