blob: 07f35ab99416a2b28d72d8c75686e416fd92a0dc [file] [log] [blame]
Don Newton379ae252019-04-01 12:17:06 -04001// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package wiremessage
8
import (
	"errors"
	"fmt"

	"github.com/mongodb/mongo-go-driver/bson"
	"github.com/mongodb/mongo-go-driver/x/bsonx"
)
15
// Msg represents the OP_MSG message of the MongoDB wire protocol.
type Msg struct {
	// MsgHeader is the standard 16-byte message header. AppendWireMessage
	// fills in its MessageLength and OpCode when they are zero.
	MsgHeader Header
	// FlagBits is the OP_MSG flag word (see the MsgFlag constants).
	FlagBits  MsgFlag
	// Sections holds the message's sections in wire order.
	Sections  []Section
	// Checksum is the trailing 4-byte checksum; it is only read from the wire
	// when FlagBits has ChecksumPresent set.
	Checksum  uint32
}
23
24// MarshalWireMessage implements the Marshaler and WireMessage interfaces.
25func (m Msg) MarshalWireMessage() ([]byte, error) {
26 b := make([]byte, 0, m.Len())
27 return m.AppendWireMessage(b)
28}
29
30// ValidateWireMessage implements the Validator and WireMessage interfaces.
31func (m Msg) ValidateWireMessage() error {
32 if int(m.MsgHeader.MessageLength) != m.Len() {
33 return errors.New("incorrect header: message length is not correct")
34 }
35 if m.MsgHeader.OpCode != OpMsg {
36 return errors.New("incorrect header: opcode is not OpMsg")
37 }
38
39 return nil
40}
41
42// AppendWireMessage implements the Appender and WireMessage interfaces.
43//
44// AppendWireMesssage will set the MessageLength property of the MsgHeader if it is zero. It will also set the Opcode
45// to OP_MSG if it is zero. If either of these properties are non-zero and not correct, this method will return both the
46// []byte with the wire message appended to it and an invalid header error.
47func (m Msg) AppendWireMessage(b []byte) ([]byte, error) {
48 var err error
49 err = m.MsgHeader.SetDefaults(m.Len(), OpMsg)
50
51 b = m.MsgHeader.AppendHeader(b)
52 b = appendInt32(b, int32(m.FlagBits))
53
54 for _, section := range m.Sections {
55 newB := make([]byte, 0)
56 newB = section.AppendSection(newB)
57
58 b = section.AppendSection(b)
59 }
60
61 return b, err
62}
63
64// String implements the fmt.Stringer interface.
65func (m Msg) String() string {
66 panic("not implemented")
67}
68
69// Len implements the WireMessage interface.
70func (m Msg) Len() int {
71 // Header + Flags + len of each section + optional checksum
72 totalLen := 16 + 4 // header and flag
73
74 for _, section := range m.Sections {
75 totalLen += section.Len()
76 }
77
78 if m.FlagBits&ChecksumPresent > 0 {
79 totalLen += 4
80 }
81
82 return totalLen
83}
84
85// UnmarshalWireMessage implements the Unmarshaler interface.
86func (m *Msg) UnmarshalWireMessage(b []byte) error {
87 var err error
88
89 m.MsgHeader, err = ReadHeader(b, 0)
90 if err != nil {
91 return err
92 }
93 if len(b) < int(m.MsgHeader.MessageLength) {
94 return Error{
95 Type: ErrOpMsg,
96 Message: "[]byte too small",
97 }
98 }
99
100 m.FlagBits = MsgFlag(readInt32(b, 16))
101
102 // read each section
103 sectionBytes := m.MsgHeader.MessageLength - 16 - 4 // number of bytes taken up by sections
104 hasChecksum := m.FlagBits&ChecksumPresent > 0
105 if hasChecksum {
106 sectionBytes -= 4 // 4 bytes at end for checksum
107 }
108
109 m.Sections = make([]Section, 0)
110 position := 20 // position to read from
111 for sectionBytes > 0 {
112 sectionType := SectionType(b[position])
113 position++
114
115 switch sectionType {
116 case SingleDocument:
117 rdr, size, err := readDocument(b, int32(position))
118 if err.Message != "" {
119 err.Type = ErrOpMsg
120 return err
121 }
122
123 position += size
124 sb := SectionBody{
125 Document: rdr,
126 }
127 sb.PayloadType = sb.Kind()
128
129 sectionBytes -= int32(sb.Len())
130 m.Sections = append(m.Sections, sb)
131 case DocumentSequence:
132 sds := SectionDocumentSequence{}
133 sds.Size = readInt32(b, int32(position))
134 position += 4
135
136 identifier, err := readCString(b, int32(position))
137 if err != nil {
138 return err
139 }
140
141 sds.Identifier = identifier
142 position += len(identifier) + 1 // +1 for \0
143 sds.PayloadType = sds.Kind()
144
145 // length of documents to read
146 // sequenceLen - 4 bytes for size field - identifierLength (including \0)
147 docsLen := int(sds.Size) - 4 - len(identifier) - 1
148 for docsLen > 0 {
149 rdr, size, err := readDocument(b, int32(position))
150 if err.Message != "" {
151 err.Type = ErrOpMsg
152 return err
153 }
154
155 position += size
156 sds.Documents = append(sds.Documents, rdr)
157 docsLen -= size
158 }
159
160 sectionBytes -= int32(sds.Len())
161 m.Sections = append(m.Sections, sds)
162 }
163 }
164
165 if hasChecksum {
166 m.Checksum = uint32(readInt32(b, int32(position)))
167 }
168
169 return nil
170}
171
172// GetMainDocument returns the document containing the message to send.
173func (m *Msg) GetMainDocument() (bsonx.Doc, error) {
174 return bsonx.ReadDoc(m.Sections[0].(SectionBody).Document)
175}
176
177// GetSequenceArray returns this message's document sequence as a BSON array along with the array identifier.
178// If this message has no associated document sequence, a nil array is returned.
179func (m *Msg) GetSequenceArray() (bsonx.Arr, string, error) {
180 if len(m.Sections) == 1 {
181 return nil, "", nil
182 }
183
184 arr := bsonx.Arr{}
185 sds := m.Sections[1].(SectionDocumentSequence)
186
187 for _, rdr := range sds.Documents {
188 doc, err := bsonx.ReadDoc([]byte(rdr))
189 if err != nil {
190 return nil, "", err
191 }
192
193 arr = append(arr, bsonx.Document(doc))
194 }
195
196 return arr, sds.Identifier, nil
197}
198
199// AcknowledgedWrite returns true if this msg represents an acknowledged write command.
200func (m *Msg) AcknowledgedWrite() bool {
201 return m.FlagBits&MoreToCome == 0
202}
203
// MsgFlag represents the flags on an OP_MSG message.
type MsgFlag uint32

// These constants represent the individual flags on an OP_MSG message. Each
// one is a single bit of the 32-bit flag word.
const (
	// ChecksumPresent (bit 0) signals that a 4-byte checksum follows the
	// sections at the end of the message.
	ChecksumPresent MsgFlag = 1 << 0
	// MoreToCome (bit 1); AcknowledgedWrite reports false when it is set.
	MoreToCome MsgFlag = 1 << 1

	// ExhaustAllowed (bit 16); not otherwise referenced in this file.
	ExhaustAllowed MsgFlag = 1 << 16
)
214
215// Section represents a section on an OP_MSG message.
216type Section interface {
217 Kind() SectionType
218 Len() int
219 AppendSection([]byte) []byte
220}
221
// SectionBody represents the kind body of an OP_MSG message.
type SectionBody struct {
	// PayloadType is set to Kind() (SingleDocument) when a message is
	// unmarshaled; AppendSection writes SingleDocument regardless of it.
	PayloadType SectionType
	// Document is the section's raw BSON document, appended verbatim after
	// the kind byte.
	Document    bson.Raw
}
227
228// Kind implements the Section interface.
229func (sb SectionBody) Kind() SectionType {
230 return SingleDocument
231}
232
233// Len implements the Section interface
234func (sb SectionBody) Len() int {
235 return 1 + len(sb.Document) // 1 for PayloadType
236}
237
238// AppendSection implements the Section interface.
239func (sb SectionBody) AppendSection(dest []byte) []byte {
240 dest = append(dest, byte(SingleDocument))
241 dest = append(dest, sb.Document...)
242 return dest
243}
244
// SectionDocumentSequence represents the kind document sequence of an OP_MSG message.
type SectionDocumentSequence struct {
	// PayloadType is set to Kind() (DocumentSequence) when a message is
	// unmarshaled.
	PayloadType SectionType
	// Size is the section's int32 size field: it counts its own 4 bytes, the
	// NUL-terminated Identifier and all Documents. When unmarshaling it
	// determines how many document bytes are consumed; for a well-formed
	// section it equals PayloadLen().
	Size        int32
	// Identifier names the sequence; it is encoded as a NUL-terminated string.
	Identifier  string
	// Documents holds the sequence's raw BSON documents, in wire order.
	Documents   []bson.Raw
}
252
253// Kind implements the Section interface.
254func (sds SectionDocumentSequence) Kind() SectionType {
255 return DocumentSequence
256}
257
258// Len implements the Section interface
259func (sds SectionDocumentSequence) Len() int {
260 // PayloadType + Size + Identifier + 1 (null terminator) + totalDocLen
261 totalDocLen := 0
262 for _, doc := range sds.Documents {
263 totalDocLen += len(doc)
264 }
265
266 return 1 + 4 + len(sds.Identifier) + 1 + totalDocLen
267}
268
269// PayloadLen returns the length of the payload
270func (sds SectionDocumentSequence) PayloadLen() int {
271 // 4 bytes for size field, len identifier (including \0), and total docs len
272 return sds.Len() - 1
273}
274
275// AppendSection implements the Section interface
276func (sds SectionDocumentSequence) AppendSection(dest []byte) []byte {
277 dest = append(dest, byte(DocumentSequence))
278 dest = appendInt32(dest, sds.Size)
279 dest = appendCString(dest, sds.Identifier)
280
281 for _, doc := range sds.Documents {
282 dest = append(dest, doc...)
283 }
284
285 return dest
286}
287
// SectionType represents the type for 1 section in an OP_MSG.
type SectionType uint8

// These constants represent the individual section types for a section in an
// OP_MSG. The value is the kind byte written at the start of each section.
const (
	// SingleDocument (kind 0) is a section holding exactly one BSON document.
	SingleDocument SectionType = 0
	// DocumentSequence (kind 1) is a section holding a size field, an
	// identifier and zero or more BSON documents.
	DocumentSequence SectionType = 1
)

// OpmsgWireVersion is the minimum wire version needed to use OP_MSG.
const OpmsgWireVersion = 6