blob: 5e640373083e30869e9a73935b0bff9e53eb1f37 [file] [log] [blame]
Don Newton379ae252019-04-01 12:17:06 -04001// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7// Package connection contains the types for building and pooling connections that can speak the
8// MongoDB Wire Protocol. Since this low level library is meant to be used in the context of either
9// a driver or a server there are some extra identifiers on a connection so one can keep track of
10// what a connection is. This package purposefully hides the underlying network and abstracts the
// writing to and reading from a connection to wiremessage.WireMessage values. This package also provides types for
12// listening for and accepting Connections, as well as some types for handling connections and
13// proxying connections to another server.
14package connection
15
16import (
17 "context"
18 "crypto/tls"
19 "errors"
20 "fmt"
21 "io"
22 "net"
23 "strings"
24 "sync/atomic"
25 "time"
26
27 "github.com/mongodb/mongo-go-driver/bson"
28 "github.com/mongodb/mongo-go-driver/bson/bsontype"
29 "github.com/mongodb/mongo-go-driver/event"
30 "github.com/mongodb/mongo-go-driver/x/bsonx"
31 "github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore"
32 "github.com/mongodb/mongo-go-driver/x/network/address"
33 "github.com/mongodb/mongo-go-driver/x/network/compressor"
34 "github.com/mongodb/mongo-go-driver/x/network/description"
35 "github.com/mongodb/mongo-go-driver/x/network/wiremessage"
36)
37
38var globalClientConnectionID uint64
39var emptyDoc bson.Raw
40
41func nextClientConnectionID() uint64 {
42 return atomic.AddUint64(&globalClientConnectionID, 1)
43}
44
45// Connection is used to read and write wire protocol messages to a network.
46type Connection interface {
47 WriteWireMessage(context.Context, wiremessage.WireMessage) error
48 ReadWireMessage(context.Context) (wiremessage.WireMessage, error)
49 Close() error
50 Expired() bool
51 Alive() bool
52 ID() string
53}
54
// Dialer opens network connections. Its single method has the same signature
// as (*net.Dialer).DialContext, so a *net.Dialer satisfies it.
type Dialer interface {
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
59
// DialerFunc adapts an ordinary function with the DialContext signature into a
// Dialer.
type DialerFunc func(ctx context.Context, network, address string) (net.Conn, error)

// DialContext calls df with the given arguments and returns its results,
// satisfying the Dialer interface.
func (df DialerFunc) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	conn, err := df(ctx, network, address)
	return conn, err
}
67
68// DefaultDialer is the Dialer implementation that is used by this package. Changing this
69// will also change the Dialer used for this package. This should only be changed why all
70// of the connections being made need to use a different Dialer. Most of the time, using a
71// WithDialer option is more appropriate than changing this variable.
72var DefaultDialer Dialer = &net.Dialer{}
73
74// Handshaker is the interface implemented by types that can perform a MongoDB
75// handshake over a provided ReadWriter. This is used during connection
76// initialization.
77type Handshaker interface {
78 Handshake(context.Context, address.Address, wiremessage.ReadWriter) (description.Server, error)
79}
80
81// HandshakerFunc is an adapter to allow the use of ordinary functions as
82// connection handshakers.
83type HandshakerFunc func(context.Context, address.Address, wiremessage.ReadWriter) (description.Server, error)
84
85// Handshake implements the Handshaker interface.
86func (hf HandshakerFunc) Handshake(ctx context.Context, addr address.Address, rw wiremessage.ReadWriter) (description.Server, error) {
87 return hf(ctx, addr, rw)
88}
89
90type connection struct {
91 addr address.Address
92 id string
93 conn net.Conn
94 compressBuf []byte // buffer to compress messages
95 compressor compressor.Compressor // use for compressing messages
96 // server can compress response with any compressor supported by driver
97 compressorMap map[wiremessage.CompressorID]compressor.Compressor
98 commandMap map[int64]*commandMetadata // map for monitoring commands sent to server
99 dead bool
100 idleTimeout time.Duration
101 idleDeadline time.Time
102 lifetimeDeadline time.Time
103 cmdMonitor *event.CommandMonitor
104 readTimeout time.Duration
105 uncompressBuf []byte // buffer to uncompress messages
106 writeTimeout time.Duration
107 readBuf []byte
108 writeBuf []byte
109 wireMessageBuf []byte // buffer to store uncompressed wire message before compressing
110}
111
112// New opens a connection to a given Addr
113//
114// The server description returned is nil if there was no handshaker provided.
115func New(ctx context.Context, addr address.Address, opts ...Option) (Connection, *description.Server, error) {
116 cfg, err := newConfig(opts...)
117 if err != nil {
118 return nil, nil, err
119 }
120
121 nc, err := cfg.dialer.DialContext(ctx, addr.Network(), addr.String())
122 if err != nil {
123 return nil, nil, err
124 }
125
126 if cfg.tlsConfig != nil {
127 tlsConfig := cfg.tlsConfig.Clone()
128 nc, err = configureTLS(ctx, nc, addr, tlsConfig)
129 if err != nil {
130 return nil, nil, err
131 }
132 }
133
134 var lifetimeDeadline time.Time
135 if cfg.lifeTimeout > 0 {
136 lifetimeDeadline = time.Now().Add(cfg.lifeTimeout)
137 }
138
139 id := fmt.Sprintf("%s[-%d]", addr, nextClientConnectionID())
140 compressorMap := make(map[wiremessage.CompressorID]compressor.Compressor)
141
142 for _, comp := range cfg.compressors {
143 compressorMap[comp.CompressorID()] = comp
144 }
145
146 c := &connection{
147 id: id,
148 conn: nc,
149 compressBuf: make([]byte, 256),
150 compressorMap: compressorMap,
151 commandMap: make(map[int64]*commandMetadata),
152 addr: addr,
153 idleTimeout: cfg.idleTimeout,
154 lifetimeDeadline: lifetimeDeadline,
155 readTimeout: cfg.readTimeout,
156 writeTimeout: cfg.writeTimeout,
157 readBuf: make([]byte, 256),
158 uncompressBuf: make([]byte, 256),
159 writeBuf: make([]byte, 0, 256),
160 wireMessageBuf: make([]byte, 256),
161 }
162
163 c.bumpIdleDeadline()
164
165 var desc *description.Server
166 if cfg.handshaker != nil {
167 d, err := cfg.handshaker.Handshake(ctx, c.addr, c)
168 if err != nil {
169 return nil, nil, err
170 }
171
172 if len(d.Compression) > 0 {
173 clientMethodLoop:
174 for _, comp := range cfg.compressors {
175 method := comp.Name()
176
177 for _, serverMethod := range d.Compression {
178 if method != serverMethod {
179 continue
180 }
181
182 c.compressor = comp // found matching compressor
183 break clientMethodLoop
184 }
185 }
186
187 }
188
189 desc = &d
190 }
191
192 c.cmdMonitor = cfg.cmdMonitor // attach the command monitor later to avoid monitoring auth
193 return c, desc, nil
194}
195
196func configureTLS(ctx context.Context, nc net.Conn, addr address.Address, config *TLSConfig) (net.Conn, error) {
197 if !config.InsecureSkipVerify {
198 hostname := addr.String()
199 colonPos := strings.LastIndex(hostname, ":")
200 if colonPos == -1 {
201 colonPos = len(hostname)
202 }
203
204 hostname = hostname[:colonPos]
205 config.ServerName = hostname
206 }
207
208 client := tls.Client(nc, config.Config)
209
210 errChan := make(chan error, 1)
211 go func() {
212 errChan <- client.Handshake()
213 }()
214
215 select {
216 case err := <-errChan:
217 if err != nil {
218 return nil, err
219 }
220 case <-ctx.Done():
221 return nil, errors.New("server connection cancelled/timeout during TLS handshake")
222 }
223 return client, nil
224}
225
226func (c *connection) Alive() bool {
227 return !c.dead
228}
229
230func (c *connection) Expired() bool {
231 now := time.Now()
232 if !c.idleDeadline.IsZero() && now.After(c.idleDeadline) {
233 return true
234 }
235
236 if !c.lifetimeDeadline.IsZero() && now.After(c.lifetimeDeadline) {
237 return true
238 }
239
240 return c.dead
241}
242
// canCompress reports whether a message carrying the named command may be sent
// compressed. Handshake and authentication-related commands are always sent
// uncompressed. The comparison is case-sensitive.
func canCompress(cmd string) bool {
	switch cmd {
	case "isMaster", "saslStart", "saslContinue", "getnonce", "authenticate",
		"createUser", "updateUser", "copydbSaslStart", "copydbgetnonce", "copydb":
		return false
	default:
		return true
	}
}
250
251func (c *connection) compressMessage(wm wiremessage.WireMessage) (wiremessage.WireMessage, error) {
252 var requestID int32
253 var responseTo int32
254 var origOpcode wiremessage.OpCode
255
256 switch converted := wm.(type) {
257 case wiremessage.Query:
258 firstElem, err := converted.Query.IndexErr(0)
259 if err != nil {
260 return wiremessage.Compressed{}, err
261 }
262
263 key := firstElem.Key()
264 if !canCompress(key) {
265 return wm, nil // return original message because this command can't be compressed
266 }
267 requestID = converted.MsgHeader.RequestID
268 origOpcode = wiremessage.OpQuery
269 responseTo = converted.MsgHeader.ResponseTo
270 case wiremessage.Msg:
271 firstElem, err := converted.Sections[0].(wiremessage.SectionBody).Document.IndexErr(0)
272 if err != nil {
273 return wiremessage.Compressed{}, err
274 }
275
276 key := firstElem.Key()
277 if !canCompress(key) {
278 return wm, nil
279 }
280
281 requestID = converted.MsgHeader.RequestID
282 origOpcode = wiremessage.OpMsg
283 responseTo = converted.MsgHeader.ResponseTo
284 }
285
286 // can compress
287 c.wireMessageBuf = c.wireMessageBuf[:0] // truncate
288 var err error
289 c.wireMessageBuf, err = wm.AppendWireMessage(c.wireMessageBuf)
290 if err != nil {
291 return wiremessage.Compressed{}, err
292 }
293
294 c.wireMessageBuf = c.wireMessageBuf[16:] // strip header
295 c.compressBuf = c.compressBuf[:0]
296 compressedBytes, err := c.compressor.CompressBytes(c.wireMessageBuf, c.compressBuf)
297 if err != nil {
298 return wiremessage.Compressed{}, err
299 }
300
301 compressedMessage := wiremessage.Compressed{
302 MsgHeader: wiremessage.Header{
303 // MessageLength and OpCode will be set when marshalling wire message by SetDefaults()
304 RequestID: requestID,
305 ResponseTo: responseTo,
306 },
307 OriginalOpCode: origOpcode,
308 UncompressedSize: int32(len(c.wireMessageBuf)), // length of uncompressed message excluding MsgHeader
309 CompressorID: wiremessage.CompressorID(c.compressor.CompressorID()),
310 CompressedMessage: compressedBytes,
311 }
312
313 return compressedMessage, nil
314}
315
316// returns []byte of uncompressed message with reconstructed header, original opcode, error
317func (c *connection) uncompressMessage(compressed wiremessage.Compressed) ([]byte, wiremessage.OpCode, error) {
318 // server doesn't guarantee the same compression method will be used each time so the CompressorID field must be
319 // used to find the correct method for uncompressing data
320 uncompressor := c.compressorMap[compressed.CompressorID]
321
322 // reset uncompressBuf
323 c.uncompressBuf = c.uncompressBuf[:0]
324 if int(compressed.UncompressedSize) > cap(c.uncompressBuf) {
325 c.uncompressBuf = make([]byte, 0, compressed.UncompressedSize)
326 }
327
328 uncompressedMessage, err := uncompressor.UncompressBytes(compressed.CompressedMessage, c.uncompressBuf)
329
330 if err != nil {
331 return nil, 0, err
332 }
333
334 origHeader := wiremessage.Header{
335 MessageLength: int32(len(uncompressedMessage)) + 16, // add 16 for original header
336 RequestID: compressed.MsgHeader.RequestID,
337 ResponseTo: compressed.MsgHeader.ResponseTo,
338 }
339
340 switch compressed.OriginalOpCode {
341 case wiremessage.OpReply:
342 origHeader.OpCode = wiremessage.OpReply
343 case wiremessage.OpMsg:
344 origHeader.OpCode = wiremessage.OpMsg
345 default:
346 return nil, 0, fmt.Errorf("opcode %s not implemented", compressed.OriginalOpCode)
347 }
348
349 var fullMessage []byte
350 fullMessage = origHeader.AppendHeader(fullMessage)
351 fullMessage = append(fullMessage, uncompressedMessage...)
352 return fullMessage, origHeader.OpCode, nil
353}
354
// canMonitor reports whether the contents of the named command (and of its
// reply) may be passed to command monitors. Security-sensitive commands are
// reported with an empty document instead. The comparison is case-sensitive.
func canMonitor(cmd string) bool {
	switch cmd {
	case "authenticate", "saslStart", "saslContinue", "getnonce", "createUser",
		"updateUser", "copydbgetnonce", "copydbsaslstart", "copydb":
		return false
	default:
		return true
	}
}
363
364func (c *connection) commandStartedEvent(ctx context.Context, wm wiremessage.WireMessage) error {
365 if c.cmdMonitor == nil || c.cmdMonitor.Started == nil {
366 return nil
367 }
368
369 startedEvent := &event.CommandStartedEvent{
370 ConnectionID: c.id,
371 }
372
373 var cmd bsonx.Doc
374 var err error
375 var legacy bool
376 var fullCollName string
377
378 var acknowledged bool
379 switch converted := wm.(type) {
380 case wiremessage.Query:
381 cmd, err = converted.CommandDocument()
382 if err != nil {
383 return err
384 }
385
386 acknowledged = converted.AcknowledgedWrite()
387 startedEvent.DatabaseName = converted.DatabaseName()
388 startedEvent.RequestID = int64(converted.MsgHeader.RequestID)
389 legacy = converted.Legacy()
390 fullCollName = converted.FullCollectionName
391 case wiremessage.Msg:
392 cmd, err = converted.GetMainDocument()
393 if err != nil {
394 return err
395 }
396
397 acknowledged = converted.AcknowledgedWrite()
398 arr, identifier, err := converted.GetSequenceArray()
399 if err != nil {
400 return err
401 }
402 if arr != nil {
403 cmd = cmd.Copy() // make copy to avoid changing original command
404 cmd = append(cmd, bsonx.Elem{identifier, bsonx.Array(arr)})
405 }
406
407 dbVal, err := cmd.LookupErr("$db")
408 if err != nil {
409 return err
410 }
411
412 startedEvent.DatabaseName = dbVal.StringValue()
413 startedEvent.RequestID = int64(converted.MsgHeader.RequestID)
414 case wiremessage.GetMore:
415 cmd = converted.CommandDocument()
416 startedEvent.DatabaseName = converted.DatabaseName()
417 startedEvent.RequestID = int64(converted.MsgHeader.RequestID)
418 acknowledged = true
419 legacy = true
420 fullCollName = converted.FullCollectionName
421 case wiremessage.KillCursors:
422 cmd = converted.CommandDocument()
423 startedEvent.DatabaseName = converted.DatabaseName
424 startedEvent.RequestID = int64(converted.MsgHeader.RequestID)
425 legacy = true
426 }
427
428 rawcmd, _ := cmd.MarshalBSON()
429 startedEvent.Command = rawcmd
430 startedEvent.CommandName = cmd[0].Key
431 if !canMonitor(startedEvent.CommandName) {
432 startedEvent.Command = emptyDoc
433 }
434
435 c.cmdMonitor.Started(ctx, startedEvent)
436
437 if !acknowledged {
438 if c.cmdMonitor.Succeeded == nil {
439 return nil
440 }
441
442 // unack writes must provide a CommandSucceededEvent with an { ok: 1 } reply
443 finishedEvent := event.CommandFinishedEvent{
444 DurationNanos: 0,
445 CommandName: startedEvent.CommandName,
446 RequestID: startedEvent.RequestID,
447 ConnectionID: c.id,
448 }
449
450 c.cmdMonitor.Succeeded(ctx, &event.CommandSucceededEvent{
451 CommandFinishedEvent: finishedEvent,
452 Reply: bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "ok", 1)),
453 })
454
455 return nil
456 }
457
458 c.commandMap[startedEvent.RequestID] = createMetadata(startedEvent.CommandName, legacy, fullCollName)
459 return nil
460}
461
462func processReply(reply bsonx.Doc) (bool, string) {
463 var success bool
464 var errmsg string
465 var errCode int32
466
467 for _, elem := range reply {
468 switch elem.Key {
469 case "ok":
470 switch elem.Value.Type() {
471 case bsontype.Int32:
472 if elem.Value.Int32() == 1 {
473 success = true
474 }
475 case bsontype.Int64:
476 if elem.Value.Int64() == 1 {
477 success = true
478 }
479 case bsontype.Double:
480 if elem.Value.Double() == 1 {
481 success = true
482 }
483 }
484 case "errmsg":
485 if str, ok := elem.Value.StringValueOK(); ok {
486 errmsg = str
487 }
488 case "code":
489 if c, ok := elem.Value.Int32OK(); ok {
490 errCode = c
491 }
492 }
493 }
494
495 if success {
496 return true, ""
497 }
498
499 fullErrMsg := fmt.Sprintf("Error code %d: %s", errCode, errmsg)
500 return false, fullErrMsg
501}
502
503func (c *connection) commandFinishedEvent(ctx context.Context, wm wiremessage.WireMessage) error {
504 if c.cmdMonitor == nil {
505 return nil
506 }
507
508 var reply bsonx.Doc
509 var requestID int64
510 var err error
511
512 switch converted := wm.(type) {
513 case wiremessage.Reply:
514 requestID = int64(converted.MsgHeader.ResponseTo)
515 case wiremessage.Msg:
516 requestID = int64(converted.MsgHeader.ResponseTo)
517 }
518 cmdMetadata := c.commandMap[requestID]
519 delete(c.commandMap, requestID)
520
521 switch converted := wm.(type) {
522 case wiremessage.Reply:
523 if cmdMetadata.Legacy {
524 reply, err = converted.GetMainLegacyDocument(cmdMetadata.FullCollectionName)
525 } else {
526 reply, err = converted.GetMainDocument()
527 }
528 case wiremessage.Msg:
529 reply, err = converted.GetMainDocument()
530 }
531 if err != nil {
532 return err
533 }
534
535 success, errmsg := processReply(reply)
536
537 if (success && c.cmdMonitor.Succeeded == nil) || (!success && c.cmdMonitor.Failed == nil) {
538 return nil
539 }
540
541 finishedEvent := event.CommandFinishedEvent{
542 DurationNanos: cmdMetadata.TimeDifference(),
543 CommandName: cmdMetadata.Name,
544 RequestID: requestID,
545 ConnectionID: c.id,
546 }
547
548 if success {
549 if !canMonitor(finishedEvent.CommandName) {
550 successEvent := &event.CommandSucceededEvent{
551 Reply: emptyDoc,
552 CommandFinishedEvent: finishedEvent,
553 }
554 c.cmdMonitor.Succeeded(ctx, successEvent)
555 return nil
556 }
557
558 // if response has type 1 document sequence, the sequence must be included as a BSON array in the event's reply.
559 if opmsg, ok := wm.(wiremessage.Msg); ok {
560 arr, identifier, err := opmsg.GetSequenceArray()
561 if err != nil {
562 return err
563 }
564 if arr != nil {
565 reply = reply.Copy() // make copy to avoid changing original command
566 reply = append(reply, bsonx.Elem{identifier, bsonx.Array(arr)})
567 }
568 }
569
570 replyraw, _ := reply.MarshalBSON()
571 successEvent := &event.CommandSucceededEvent{
572 Reply: replyraw,
573 CommandFinishedEvent: finishedEvent,
574 }
575
576 c.cmdMonitor.Succeeded(ctx, successEvent)
577 return nil
578 }
579
580 failureEvent := &event.CommandFailedEvent{
581 Failure: errmsg,
582 CommandFinishedEvent: finishedEvent,
583 }
584
585 c.cmdMonitor.Failed(ctx, failureEvent)
586 return nil
587}
588
589func (c *connection) WriteWireMessage(ctx context.Context, wm wiremessage.WireMessage) error {
590 var err error
591 if c.dead {
592 return Error{
593 ConnectionID: c.id,
594 message: "connection is dead",
595 }
596 }
597
598 select {
599 case <-ctx.Done():
600 return Error{
601 ConnectionID: c.id,
602 Wrapped: ctx.Err(),
603 message: "failed to write",
604 }
605 default:
606 }
607
608 deadline := time.Time{}
609 if c.writeTimeout != 0 {
610 deadline = time.Now().Add(c.writeTimeout)
611 }
612
613 if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) {
614 deadline = dl
615 }
616
617 if err := c.conn.SetWriteDeadline(deadline); err != nil {
618 return Error{
619 ConnectionID: c.id,
620 Wrapped: err,
621 message: "failed to set write deadline",
622 }
623 }
624
625 // Truncate the write buffer
626 c.writeBuf = c.writeBuf[:0]
627
628 messageToWrite := wm
629 // Compress if possible
630 if c.compressor != nil {
631 compressed, err := c.compressMessage(wm)
632 if err != nil {
633 return Error{
634 ConnectionID: c.id,
635 Wrapped: err,
636 message: "unable to compress wire message",
637 }
638 }
639 messageToWrite = compressed
640 }
641
642 c.writeBuf, err = messageToWrite.AppendWireMessage(c.writeBuf)
643 if err != nil {
644 return Error{
645 ConnectionID: c.id,
646 Wrapped: err,
647 message: "unable to encode wire message",
648 }
649 }
650
651 _, err = c.conn.Write(c.writeBuf)
652 if err != nil {
653 c.Close()
654 return Error{
655 ConnectionID: c.id,
656 Wrapped: err,
657 message: "unable to write wire message to network",
658 }
659 }
660
661 c.bumpIdleDeadline()
662 err = c.commandStartedEvent(ctx, wm)
663 if err != nil {
664 return err
665 }
666 return nil
667}
668
669func (c *connection) ReadWireMessage(ctx context.Context) (wiremessage.WireMessage, error) {
670 if c.dead {
671 return nil, Error{
672 ConnectionID: c.id,
673 message: "connection is dead",
674 }
675 }
676
677 select {
678 case <-ctx.Done():
679 // We close the connection because we don't know if there
680 // is an unread message on the wire.
681 c.Close()
682 return nil, Error{
683 ConnectionID: c.id,
684 Wrapped: ctx.Err(),
685 message: "failed to read",
686 }
687 default:
688 }
689
690 deadline := time.Time{}
691 if c.readTimeout != 0 {
692 deadline = time.Now().Add(c.readTimeout)
693 }
694
695 if ctxDL, ok := ctx.Deadline(); ok && (deadline.IsZero() || ctxDL.Before(deadline)) {
696 deadline = ctxDL
697 }
698
699 if err := c.conn.SetReadDeadline(deadline); err != nil {
700 return nil, Error{
701 ConnectionID: c.id,
702 Wrapped: ctx.Err(),
703 message: "failed to set read deadline",
704 }
705 }
706
707 var sizeBuf [4]byte
708 _, err := io.ReadFull(c.conn, sizeBuf[:])
709 if err != nil {
710 c.Close()
711 return nil, Error{
712 ConnectionID: c.id,
713 Wrapped: err,
714 message: "unable to decode message length",
715 }
716 }
717
718 size := readInt32(sizeBuf[:], 0)
719
720 // Isn't the best reuse, but resizing a []byte to be larger
721 // is difficult.
722 if cap(c.readBuf) > int(size) {
723 c.readBuf = c.readBuf[:size]
724 } else {
725 c.readBuf = make([]byte, size)
726 }
727
728 c.readBuf[0], c.readBuf[1], c.readBuf[2], c.readBuf[3] = sizeBuf[0], sizeBuf[1], sizeBuf[2], sizeBuf[3]
729
730 _, err = io.ReadFull(c.conn, c.readBuf[4:])
731 if err != nil {
732 c.Close()
733 return nil, Error{
734 ConnectionID: c.id,
735 Wrapped: err,
736 message: "unable to read full message",
737 }
738 }
739
740 hdr, err := wiremessage.ReadHeader(c.readBuf, 0)
741 if err != nil {
742 c.Close()
743 return nil, Error{
744 ConnectionID: c.id,
745 Wrapped: err,
746 message: "unable to decode header",
747 }
748 }
749
750 messageToDecode := c.readBuf
751 opcodeToCheck := hdr.OpCode
752
753 if hdr.OpCode == wiremessage.OpCompressed {
754 var compressed wiremessage.Compressed
755 err := compressed.UnmarshalWireMessage(c.readBuf)
756 if err != nil {
757 defer c.Close()
758 return nil, Error{
759 ConnectionID: c.id,
760 Wrapped: err,
761 message: "unable to decode OP_COMPRESSED",
762 }
763 }
764
765 uncompressed, origOpcode, err := c.uncompressMessage(compressed)
766 if err != nil {
767 defer c.Close()
768 return nil, Error{
769 ConnectionID: c.id,
770 Wrapped: err,
771 message: "unable to uncompress message",
772 }
773 }
774 messageToDecode = uncompressed
775 opcodeToCheck = origOpcode
776 }
777
778 var wm wiremessage.WireMessage
779 switch opcodeToCheck {
780 case wiremessage.OpReply:
781 var r wiremessage.Reply
782 err := r.UnmarshalWireMessage(messageToDecode)
783 if err != nil {
784 c.Close()
785 return nil, Error{
786 ConnectionID: c.id,
787 Wrapped: err,
788 message: "unable to decode OP_REPLY",
789 }
790 }
791 wm = r
792 case wiremessage.OpMsg:
793 var reply wiremessage.Msg
794 err := reply.UnmarshalWireMessage(messageToDecode)
795 if err != nil {
796 c.Close()
797 return nil, Error{
798 ConnectionID: c.id,
799 Wrapped: err,
800 message: "unable to decode OP_MSG",
801 }
802 }
803 wm = reply
804 default:
805 c.Close()
806 return nil, Error{
807 ConnectionID: c.id,
808 message: fmt.Sprintf("opcode %s not implemented", hdr.OpCode),
809 }
810 }
811
812 c.bumpIdleDeadline()
813 err = c.commandFinishedEvent(ctx, wm)
814 if err != nil {
815 return nil, err // TODO: do we care if monitoring fails?
816 }
817
818 return wm, nil
819}
820
821func (c *connection) bumpIdleDeadline() {
822 if c.idleTimeout > 0 {
823 c.idleDeadline = time.Now().Add(c.idleTimeout)
824 }
825}
826
827func (c *connection) Close() error {
828 c.dead = true
829 err := c.conn.Close()
830 if err != nil {
831 return Error{
832 ConnectionID: c.id,
833 Wrapped: err,
834 message: "failed to close net.Conn",
835 }
836 }
837
838 return nil
839}
840
841func (c *connection) ID() string {
842 return c.id
843}
844
845func (c *connection) initialize(ctx context.Context, appName string) error {
846 return nil
847}
848
// readInt32 decodes a little-endian, two's-complement int32 from the four bytes
// of b starting at pos. It panics if b has fewer than pos+4 bytes.
func readInt32(b []byte, pos int32) int32 {
	u := uint32(b[pos]) | uint32(b[pos+1])<<8 | uint32(b[pos+2])<<16 | uint32(b[pos+3])<<24
	return int32(u)
}