Don Newton | 379ae25 | 2019-04-01 12:17:06 -0400 | [diff] [blame^] | 1 | // Copyright (C) MongoDB, Inc. 2017-present. |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may |
| 4 | // not use this file except in compliance with the License. You may obtain |
| 5 | // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | |
| 7 | package wiremessage |
| 8 | |
import (
	"errors"
	"fmt"

	"github.com/mongodb/mongo-go-driver/bson"
	"github.com/mongodb/mongo-go-driver/x/bsonx"
)
| 15 | |
| 16 | // Msg represents the OP_MSG message of the MongoDB wire protocol. |
| 17 | type Msg struct { |
| 18 | MsgHeader Header |
| 19 | FlagBits MsgFlag |
| 20 | Sections []Section |
| 21 | Checksum uint32 |
| 22 | } |
| 23 | |
| 24 | // MarshalWireMessage implements the Marshaler and WireMessage interfaces. |
| 25 | func (m Msg) MarshalWireMessage() ([]byte, error) { |
| 26 | b := make([]byte, 0, m.Len()) |
| 27 | return m.AppendWireMessage(b) |
| 28 | } |
| 29 | |
| 30 | // ValidateWireMessage implements the Validator and WireMessage interfaces. |
| 31 | func (m Msg) ValidateWireMessage() error { |
| 32 | if int(m.MsgHeader.MessageLength) != m.Len() { |
| 33 | return errors.New("incorrect header: message length is not correct") |
| 34 | } |
| 35 | if m.MsgHeader.OpCode != OpMsg { |
| 36 | return errors.New("incorrect header: opcode is not OpMsg") |
| 37 | } |
| 38 | |
| 39 | return nil |
| 40 | } |
| 41 | |
| 42 | // AppendWireMessage implements the Appender and WireMessage interfaces. |
| 43 | // |
| 44 | // AppendWireMesssage will set the MessageLength property of the MsgHeader if it is zero. It will also set the Opcode |
| 45 | // to OP_MSG if it is zero. If either of these properties are non-zero and not correct, this method will return both the |
| 46 | // []byte with the wire message appended to it and an invalid header error. |
| 47 | func (m Msg) AppendWireMessage(b []byte) ([]byte, error) { |
| 48 | var err error |
| 49 | err = m.MsgHeader.SetDefaults(m.Len(), OpMsg) |
| 50 | |
| 51 | b = m.MsgHeader.AppendHeader(b) |
| 52 | b = appendInt32(b, int32(m.FlagBits)) |
| 53 | |
| 54 | for _, section := range m.Sections { |
| 55 | newB := make([]byte, 0) |
| 56 | newB = section.AppendSection(newB) |
| 57 | |
| 58 | b = section.AppendSection(b) |
| 59 | } |
| 60 | |
| 61 | return b, err |
| 62 | } |
| 63 | |
| 64 | // String implements the fmt.Stringer interface. |
| 65 | func (m Msg) String() string { |
| 66 | panic("not implemented") |
| 67 | } |
| 68 | |
| 69 | // Len implements the WireMessage interface. |
| 70 | func (m Msg) Len() int { |
| 71 | // Header + Flags + len of each section + optional checksum |
| 72 | totalLen := 16 + 4 // header and flag |
| 73 | |
| 74 | for _, section := range m.Sections { |
| 75 | totalLen += section.Len() |
| 76 | } |
| 77 | |
| 78 | if m.FlagBits&ChecksumPresent > 0 { |
| 79 | totalLen += 4 |
| 80 | } |
| 81 | |
| 82 | return totalLen |
| 83 | } |
| 84 | |
| 85 | // UnmarshalWireMessage implements the Unmarshaler interface. |
| 86 | func (m *Msg) UnmarshalWireMessage(b []byte) error { |
| 87 | var err error |
| 88 | |
| 89 | m.MsgHeader, err = ReadHeader(b, 0) |
| 90 | if err != nil { |
| 91 | return err |
| 92 | } |
| 93 | if len(b) < int(m.MsgHeader.MessageLength) { |
| 94 | return Error{ |
| 95 | Type: ErrOpMsg, |
| 96 | Message: "[]byte too small", |
| 97 | } |
| 98 | } |
| 99 | |
| 100 | m.FlagBits = MsgFlag(readInt32(b, 16)) |
| 101 | |
| 102 | // read each section |
| 103 | sectionBytes := m.MsgHeader.MessageLength - 16 - 4 // number of bytes taken up by sections |
| 104 | hasChecksum := m.FlagBits&ChecksumPresent > 0 |
| 105 | if hasChecksum { |
| 106 | sectionBytes -= 4 // 4 bytes at end for checksum |
| 107 | } |
| 108 | |
| 109 | m.Sections = make([]Section, 0) |
| 110 | position := 20 // position to read from |
| 111 | for sectionBytes > 0 { |
| 112 | sectionType := SectionType(b[position]) |
| 113 | position++ |
| 114 | |
| 115 | switch sectionType { |
| 116 | case SingleDocument: |
| 117 | rdr, size, err := readDocument(b, int32(position)) |
| 118 | if err.Message != "" { |
| 119 | err.Type = ErrOpMsg |
| 120 | return err |
| 121 | } |
| 122 | |
| 123 | position += size |
| 124 | sb := SectionBody{ |
| 125 | Document: rdr, |
| 126 | } |
| 127 | sb.PayloadType = sb.Kind() |
| 128 | |
| 129 | sectionBytes -= int32(sb.Len()) |
| 130 | m.Sections = append(m.Sections, sb) |
| 131 | case DocumentSequence: |
| 132 | sds := SectionDocumentSequence{} |
| 133 | sds.Size = readInt32(b, int32(position)) |
| 134 | position += 4 |
| 135 | |
| 136 | identifier, err := readCString(b, int32(position)) |
| 137 | if err != nil { |
| 138 | return err |
| 139 | } |
| 140 | |
| 141 | sds.Identifier = identifier |
| 142 | position += len(identifier) + 1 // +1 for \0 |
| 143 | sds.PayloadType = sds.Kind() |
| 144 | |
| 145 | // length of documents to read |
| 146 | // sequenceLen - 4 bytes for size field - identifierLength (including \0) |
| 147 | docsLen := int(sds.Size) - 4 - len(identifier) - 1 |
| 148 | for docsLen > 0 { |
| 149 | rdr, size, err := readDocument(b, int32(position)) |
| 150 | if err.Message != "" { |
| 151 | err.Type = ErrOpMsg |
| 152 | return err |
| 153 | } |
| 154 | |
| 155 | position += size |
| 156 | sds.Documents = append(sds.Documents, rdr) |
| 157 | docsLen -= size |
| 158 | } |
| 159 | |
| 160 | sectionBytes -= int32(sds.Len()) |
| 161 | m.Sections = append(m.Sections, sds) |
| 162 | } |
| 163 | } |
| 164 | |
| 165 | if hasChecksum { |
| 166 | m.Checksum = uint32(readInt32(b, int32(position))) |
| 167 | } |
| 168 | |
| 169 | return nil |
| 170 | } |
| 171 | |
| 172 | // GetMainDocument returns the document containing the message to send. |
| 173 | func (m *Msg) GetMainDocument() (bsonx.Doc, error) { |
| 174 | return bsonx.ReadDoc(m.Sections[0].(SectionBody).Document) |
| 175 | } |
| 176 | |
| 177 | // GetSequenceArray returns this message's document sequence as a BSON array along with the array identifier. |
| 178 | // If this message has no associated document sequence, a nil array is returned. |
| 179 | func (m *Msg) GetSequenceArray() (bsonx.Arr, string, error) { |
| 180 | if len(m.Sections) == 1 { |
| 181 | return nil, "", nil |
| 182 | } |
| 183 | |
| 184 | arr := bsonx.Arr{} |
| 185 | sds := m.Sections[1].(SectionDocumentSequence) |
| 186 | |
| 187 | for _, rdr := range sds.Documents { |
| 188 | doc, err := bsonx.ReadDoc([]byte(rdr)) |
| 189 | if err != nil { |
| 190 | return nil, "", err |
| 191 | } |
| 192 | |
| 193 | arr = append(arr, bsonx.Document(doc)) |
| 194 | } |
| 195 | |
| 196 | return arr, sds.Identifier, nil |
| 197 | } |
| 198 | |
| 199 | // AcknowledgedWrite returns true if this msg represents an acknowledged write command. |
| 200 | func (m *Msg) AcknowledgedWrite() bool { |
| 201 | return m.FlagBits&MoreToCome == 0 |
| 202 | } |
| 203 | |
// MsgFlag represents the flags on an OP_MSG message. It is a 32-bit bit set: flags are
// combined with bitwise OR and tested with bitwise AND.
type MsgFlag uint32

// These constants represent the individual flags on an OP_MSG message; each one occupies a
// single bit of the flag field.
const (
	ChecksumPresent MsgFlag = 1 << 0
	MoreToCome      MsgFlag = 1 << 1
	ExhaustAllowed  MsgFlag = 1 << 16
)
| 214 | |
| 215 | // Section represents a section on an OP_MSG message. |
| 216 | type Section interface { |
| 217 | Kind() SectionType |
| 218 | Len() int |
| 219 | AppendSection([]byte) []byte |
| 220 | } |
| 221 | |
| 222 | // SectionBody represents the kind body of an OP_MSG message. |
| 223 | type SectionBody struct { |
| 224 | PayloadType SectionType |
| 225 | Document bson.Raw |
| 226 | } |
| 227 | |
| 228 | // Kind implements the Section interface. |
| 229 | func (sb SectionBody) Kind() SectionType { |
| 230 | return SingleDocument |
| 231 | } |
| 232 | |
| 233 | // Len implements the Section interface |
| 234 | func (sb SectionBody) Len() int { |
| 235 | return 1 + len(sb.Document) // 1 for PayloadType |
| 236 | } |
| 237 | |
| 238 | // AppendSection implements the Section interface. |
| 239 | func (sb SectionBody) AppendSection(dest []byte) []byte { |
| 240 | dest = append(dest, byte(SingleDocument)) |
| 241 | dest = append(dest, sb.Document...) |
| 242 | return dest |
| 243 | } |
| 244 | |
| 245 | // SectionDocumentSequence represents the kind document sequence of an OP_MSG message. |
| 246 | type SectionDocumentSequence struct { |
| 247 | PayloadType SectionType |
| 248 | Size int32 |
| 249 | Identifier string |
| 250 | Documents []bson.Raw |
| 251 | } |
| 252 | |
| 253 | // Kind implements the Section interface. |
| 254 | func (sds SectionDocumentSequence) Kind() SectionType { |
| 255 | return DocumentSequence |
| 256 | } |
| 257 | |
| 258 | // Len implements the Section interface |
| 259 | func (sds SectionDocumentSequence) Len() int { |
| 260 | // PayloadType + Size + Identifier + 1 (null terminator) + totalDocLen |
| 261 | totalDocLen := 0 |
| 262 | for _, doc := range sds.Documents { |
| 263 | totalDocLen += len(doc) |
| 264 | } |
| 265 | |
| 266 | return 1 + 4 + len(sds.Identifier) + 1 + totalDocLen |
| 267 | } |
| 268 | |
| 269 | // PayloadLen returns the length of the payload |
| 270 | func (sds SectionDocumentSequence) PayloadLen() int { |
| 271 | // 4 bytes for size field, len identifier (including \0), and total docs len |
| 272 | return sds.Len() - 1 |
| 273 | } |
| 274 | |
| 275 | // AppendSection implements the Section interface |
| 276 | func (sds SectionDocumentSequence) AppendSection(dest []byte) []byte { |
| 277 | dest = append(dest, byte(DocumentSequence)) |
| 278 | dest = appendInt32(dest, sds.Size) |
| 279 | dest = appendCString(dest, sds.Identifier) |
| 280 | |
| 281 | for _, doc := range sds.Documents { |
| 282 | dest = append(dest, doc...) |
| 283 | } |
| 284 | |
| 285 | return dest |
| 286 | } |
| 287 | |
// SectionType represents the type for 1 section in an OP_MSG. It is encoded as the single
// byte that precedes each section's payload.
type SectionType uint8

// These constants represent the individual section types for a section in an OP_MSG.
const (
	// SingleDocument (kind 0) is a body section carrying exactly one BSON document.
	SingleDocument SectionType = 0
	// DocumentSequence (kind 1) carries a size, an identifier and a run of BSON documents.
	DocumentSequence SectionType = 1
)
| 296 | |
// OpmsgWireVersion is the minimum wire version needed to use OP_MSG. It is an untyped
// constant so it can be compared against wire versions of any integer type.
const (
	OpmsgWireVersion = 6
)