blob: 5357f48ab12728cc499883fd9dde2176fa3b94ae [file] [log] [blame]
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package auth
8
import (
	"context"
	"fmt"

	"github.com/mongodb/mongo-go-driver/bson"
	"github.com/mongodb/mongo-go-driver/x/bsonx"
	"github.com/mongodb/mongo-go-driver/x/network/command"
	"github.com/mongodb/mongo-go-driver/x/network/description"
	"github.com/mongodb/mongo-go-driver/x/network/wiremessage"
)
18
// SaslClient is the client piece of a sasl conversation. It produces the
// payloads that ConductSaslConversation sends to the server and consumes the
// payloads the server sends back.
type SaslClient interface {
	// Start opens the conversation. The returned mechanism and payload are
	// sent to the server in the saslStart command.
	Start() (mechanism string, payload []byte, err error)

	// Next takes the payload from the server's latest response and returns
	// the payload to send in the following saslContinue command.
	Next(challenge []byte) (payload []byte, err error)

	// Completed reports whether the client side of the conversation is
	// finished. ConductSaslConversation returns successfully once this is
	// true and the server has reported the conversation as done.
	Completed() bool
}
25
26// SaslClientCloser is a SaslClient that has resources to clean up.
27type SaslClientCloser interface {
28 SaslClient
29 Close()
30}
31
32// ConductSaslConversation handles running a sasl conversation with MongoDB.
33func ConductSaslConversation(ctx context.Context, desc description.Server, rw wiremessage.ReadWriter, db string, client SaslClient) error {
34 // Arbiters cannot be authenticated
35 if desc.Kind == description.RSArbiter {
36 return nil
37 }
38
39 if db == "" {
40 db = defaultAuthDB
41 }
42
43 if closer, ok := client.(SaslClientCloser); ok {
44 defer closer.Close()
45 }
46
47 mech, payload, err := client.Start()
48 if err != nil {
49 return newError(err, mech)
50 }
51
52 saslStartCmd := command.Read{
53 DB: db,
54 Command: bsonx.Doc{
55 {"saslStart", bsonx.Int32(1)},
56 {"mechanism", bsonx.String(mech)},
57 {"payload", bsonx.Binary(0x00, payload)},
58 },
59 }
60
61 type saslResponse struct {
62 ConversationID int `bson:"conversationId"`
63 Code int `bson:"code"`
64 Done bool `bson:"done"`
65 Payload []byte `bson:"payload"`
66 }
67
68 var saslResp saslResponse
69
70 ssdesc := description.SelectedServer{Server: desc}
71 rdr, err := saslStartCmd.RoundTrip(ctx, ssdesc, rw)
72 if err != nil {
73 return newError(err, mech)
74 }
75
76 err = bson.Unmarshal(rdr, &saslResp)
77 if err != nil {
78 return newAuthError("unmarshall error", err)
79 }
80
81 cid := saslResp.ConversationID
82
83 for {
84 if saslResp.Code != 0 {
85 return newError(err, mech)
86 }
87
88 if saslResp.Done && client.Completed() {
89 return nil
90 }
91
92 payload, err = client.Next(saslResp.Payload)
93 if err != nil {
94 return newError(err, mech)
95 }
96
97 if saslResp.Done && client.Completed() {
98 return nil
99 }
100
101 saslContinueCmd := command.Read{
102 DB: db,
103 Command: bsonx.Doc{
104 {"saslContinue", bsonx.Int32(1)},
105 {"conversationId", bsonx.Int32(int32(cid))},
106 {"payload", bsonx.Binary(0x00, payload)},
107 },
108 }
109
110 rdr, err = saslContinueCmd.RoundTrip(ctx, ssdesc, rw)
111 if err != nil {
112 return newError(err, mech)
113 }
114
115 err = bson.Unmarshal(rdr, &saslResp)
116 if err != nil {
117 return newAuthError("unmarshal error", err)
118 }
119 }
120}