blob: 5357f48ab12728cc499883fd9dde2176fa3b94ae [file] [log] [blame]
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
	"context"
	"fmt"

	"github.com/mongodb/mongo-go-driver/bson"
	"github.com/mongodb/mongo-go-driver/x/bsonx"
	"github.com/mongodb/mongo-go-driver/x/network/command"
	"github.com/mongodb/mongo-go-driver/x/network/description"
	"github.com/mongodb/mongo-go-driver/x/network/wiremessage"
)
// SaslClient is the client half of a SASL conversation. Implementations
// produce the payloads sent to the server and consume the challenges it
// sends back.
type SaslClient interface {
	// Start opens the conversation, yielding the mechanism name and the
	// first payload to send.
	Start() (mechanism string, payload []byte, err error)

	// Next takes the latest server challenge and yields the payload to send
	// in reply.
	Next(serverChallenge []byte) (response []byte, err error)

	// Completed reports whether the client regards the conversation as
	// finished.
	Completed() bool
}
// SaslClientCloser is a SaslClient that also owns resources which must be
// released once the conversation is over.
type SaslClientCloser interface {
	SaslClient

	// Close releases the resources held by the client.
	Close()
}
// ConductSaslConversation runs a SASL conversation with the server described
// by desc over rw, using client to produce and consume the payloads.
//
// The conversation opens with a saslStart command carrying the mechanism and
// first payload from client.Start, and continues with saslContinue commands
// until the server reply is marked done and client.Completed returns true.
// Commands are run against db, or against defaultAuthDB when db is empty. When
// client also implements SaslClientCloser, its Close method is deferred for
// every path past the arbiter check.
//
// It returns nil on success. It also returns nil, without sending anything,
// when desc is an arbiter. A failure from the client, from a round trip, from
// decoding a reply, or a reply carrying a non-zero code is returned as an
// error.
func ConductSaslConversation(ctx context.Context, desc description.Server, rw wiremessage.ReadWriter, db string, client SaslClient) error {
	// Arbiters cannot be authenticated.
	if desc.Kind == description.RSArbiter {
		return nil
	}

	if db == "" {
		db = defaultAuthDB
	}

	if closer, ok := client.(SaslClientCloser); ok {
		defer closer.Close()
	}

	mech, payload, err := client.Start()
	if err != nil {
		return newError(err, mech)
	}

	saslStartCmd := command.Read{
		DB: db,
		Command: bsonx.Doc{
			{"saslStart", bsonx.Int32(1)},
			{"mechanism", bsonx.String(mech)},
			{"payload", bsonx.Binary(0x00, payload)},
		},
	}

	// saslResponse holds the reply fields this function reads.
	type saslResponse struct {
		ConversationID int    `bson:"conversationId"`
		Code           int    `bson:"code"`
		Done           bool   `bson:"done"`
		Payload        []byte `bson:"payload"`
	}

	var saslResp saslResponse

	ssdesc := description.SelectedServer{Server: desc}
	rdr, err := saslStartCmd.RoundTrip(ctx, ssdesc, rw)
	if err != nil {
		return newError(err, mech)
	}

	err = bson.Unmarshal(rdr, &saslResp)
	if err != nil {
		return newAuthError("unmarshal error", err)
	}

	// The conversation id from the saslStart reply is reused for every
	// saslContinue command.
	cid := saslResp.ConversationID

	for {
		if saslResp.Code != 0 {
			// The reply decoded successfully, so err is nil here; report the
			// server's code instead of wrapping a nil error.
			return newError(fmt.Errorf("server returned error code %d", saslResp.Code), mech)
		}

		if saslResp.Done && client.Completed() {
			return nil
		}

		payload, err = client.Next(saslResp.Payload)
		if err != nil {
			return newError(err, mech)
		}

		// Processing the last challenge may be what completes the client, so
		// check again before sending another command.
		if saslResp.Done && client.Completed() {
			return nil
		}

		saslContinueCmd := command.Read{
			DB: db,
			Command: bsonx.Doc{
				{"saslContinue", bsonx.Int32(1)},
				{"conversationId", bsonx.Int32(int32(cid))},
				{"payload", bsonx.Binary(0x00, payload)},
			},
		}

		rdr, err = saslContinueCmd.RoundTrip(ctx, ssdesc, rw)
		if err != nil {
			return newError(err, mech)
		}

		err = bson.Unmarshal(rdr, &saslResp)
		if err != nil {
			return newAuthError("unmarshal error", err)
		}
	}
}