VOL-2112 move to voltha-lib-go

Change-Id: I3435b8acb982deeab6b6ac28e798d7722ad01d0a
diff --git a/vendor/github.com/Shopify/sarama/broker.go b/vendor/github.com/Shopify/sarama/broker.go
index 9129089..9c3e5a0 100644
--- a/vendor/github.com/Shopify/sarama/broker.go
+++ b/vendor/github.com/Shopify/sarama/broker.go
@@ -13,24 +13,25 @@
 	"sync/atomic"
 	"time"
 
-	"github.com/rcrowley/go-metrics"
+	metrics "github.com/rcrowley/go-metrics"
 )
 
 // Broker represents a single Kafka broker connection. All operations on this object are entirely concurrency-safe.
 type Broker struct {
-	id   int32
-	addr string
+	conf *Config
 	rack *string
 
-	conf          *Config
+	id            int32
+	addr          string
 	correlationID int32
 	conn          net.Conn
 	connErr       error
 	lock          sync.Mutex
 	opened        int32
+	responses     chan responsePromise
+	done          chan bool
 
-	responses chan responsePromise
-	done      chan bool
+	registeredMetrics []string
 
 	incomingByteRate       metrics.Meter
 	requestRate            metrics.Meter
@@ -46,6 +47,8 @@
 	brokerOutgoingByteRate metrics.Meter
 	brokerResponseRate     metrics.Meter
 	brokerResponseSize     metrics.Histogram
+
+	kerberosAuthenticator GSSAPIKerberosAuth
 }
 
 // SASLMechanism specifies the SASL mechanism the client uses to authenticate with the broker
@@ -56,6 +59,11 @@
 	SASLTypeOAuth = "OAUTHBEARER"
 	// SASLTypePlaintext represents the SASL/PLAIN mechanism
 	SASLTypePlaintext = "PLAIN"
+	// SASLTypeSCRAMSHA256 represents the SCRAM-SHA-256 mechanism.
+	SASLTypeSCRAMSHA256 = "SCRAM-SHA-256"
+	// SASLTypeSCRAMSHA512 represents the SCRAM-SHA-512 mechanism.
+	SASLTypeSCRAMSHA512 = "SCRAM-SHA-512"
+	SASLTypeGSSAPI      = "GSSAPI"
 	// SASLHandshakeV0 is v0 of the Kafka SASL handshake protocol. Client and
 	// server negotiate SASL auth using opaque packets.
 	SASLHandshakeV0 = int16(0)
@@ -92,6 +100,20 @@
 	Token() (*AccessToken, error)
 }
 
+// SCRAMClient is an interface to a SCRAM
+// client implementation.
+type SCRAMClient interface {
+	// Begin prepares the client for the SCRAM exchange
+	// with the server with a user name and a password
+	Begin(userName, password, authzID string) error
+	// Step steps client through the SCRAM exchange. It is
+	// called repeatedly until it errors or `Done` returns true.
+	Step(challenge string) (response string, err error)
+	// Done should return true when the SCRAM conversation
+	// is over.
+	Done() bool
+}
+
 type responsePromise struct {
 	requestTime   time.Time
 	correlationID int32
@@ -137,6 +159,8 @@
 
 		if conf.Net.TLS.Enable {
 			b.conn, b.connErr = tls.DialWithDialer(&dialer, "tcp", b.addr, conf.Net.TLS.Config)
+		} else if conf.Net.Proxy.Enable {
+			b.conn, b.connErr = conf.Net.Proxy.Dialer.Dial("tcp", b.addr)
 		} else {
 			b.conn, b.connErr = dialer.Dial("tcp", b.addr)
 		}
@@ -161,13 +185,7 @@
 		// Do not gather metrics for seeded broker (only used during bootstrap) because they share
 		// the same id (-1) and are already exposed through the global metrics above
 		if b.id >= 0 {
-			b.brokerIncomingByteRate = getOrRegisterBrokerMeter("incoming-byte-rate", b, conf.MetricRegistry)
-			b.brokerRequestRate = getOrRegisterBrokerMeter("request-rate", b, conf.MetricRegistry)
-			b.brokerRequestSize = getOrRegisterBrokerHistogram("request-size", b, conf.MetricRegistry)
-			b.brokerRequestLatency = getOrRegisterBrokerHistogram("request-latency-in-ms", b, conf.MetricRegistry)
-			b.brokerOutgoingByteRate = getOrRegisterBrokerMeter("outgoing-byte-rate", b, conf.MetricRegistry)
-			b.brokerResponseRate = getOrRegisterBrokerMeter("response-rate", b, conf.MetricRegistry)
-			b.brokerResponseSize = getOrRegisterBrokerHistogram("response-size", b, conf.MetricRegistry)
+			b.registerMetrics()
 		}
 
 		if conf.Net.SASL.Enable {
@@ -210,6 +228,7 @@
 	return b.conn != nil, b.connErr
 }
 
+//Close closes the broker resources
 func (b *Broker) Close() error {
 	b.lock.Lock()
 	defer b.lock.Unlock()
@@ -228,12 +247,7 @@
 	b.done = nil
 	b.responses = nil
 
-	if b.id >= 0 {
-		b.conf.MetricRegistry.Unregister(getMetricNameForBroker("incoming-byte-rate", b))
-		b.conf.MetricRegistry.Unregister(getMetricNameForBroker("request-rate", b))
-		b.conf.MetricRegistry.Unregister(getMetricNameForBroker("outgoing-byte-rate", b))
-		b.conf.MetricRegistry.Unregister(getMetricNameForBroker("response-rate", b))
-	}
+	b.unregisterMetrics()
 
 	if err == nil {
 		Logger.Printf("Closed connection to broker %s\n", b.addr)
@@ -267,6 +281,7 @@
 	return *b.rack
 }
 
+//GetMetadata sends a metadata request and returns a metadata response or error
 func (b *Broker) GetMetadata(request *MetadataRequest) (*MetadataResponse, error) {
 	response := new(MetadataResponse)
 
@@ -279,6 +294,7 @@
 	return response, nil
 }
 
+//GetConsumerMetadata sends a consumer metadata request and returns a consumer metadata response or error
 func (b *Broker) GetConsumerMetadata(request *ConsumerMetadataRequest) (*ConsumerMetadataResponse, error) {
 	response := new(ConsumerMetadataResponse)
 
@@ -291,6 +307,7 @@
 	return response, nil
 }
 
+//FindCoordinator sends a find coordinator request and returns a response or error
 func (b *Broker) FindCoordinator(request *FindCoordinatorRequest) (*FindCoordinatorResponse, error) {
 	response := new(FindCoordinatorResponse)
 
@@ -303,6 +320,7 @@
 	return response, nil
 }
 
+//GetAvailableOffsets returns an offset response or error
 func (b *Broker) GetAvailableOffsets(request *OffsetRequest) (*OffsetResponse, error) {
 	response := new(OffsetResponse)
 
@@ -315,9 +333,12 @@
 	return response, nil
 }
 
+//Produce returns a produce response or error
 func (b *Broker) Produce(request *ProduceRequest) (*ProduceResponse, error) {
-	var response *ProduceResponse
-	var err error
+	var (
+		response *ProduceResponse
+		err      error
+	)
 
 	if request.RequiredAcks == NoResponse {
 		err = b.sendAndReceive(request, nil)
@@ -333,11 +354,11 @@
 	return response, nil
 }
 
+//Fetch returns a FetchResponse or error
 func (b *Broker) Fetch(request *FetchRequest) (*FetchResponse, error) {
 	response := new(FetchResponse)
 
 	err := b.sendAndReceive(request, response)
-
 	if err != nil {
 		return nil, err
 	}
@@ -345,11 +366,11 @@
 	return response, nil
 }
 
+//CommitOffset returns an offset commit response or error
 func (b *Broker) CommitOffset(request *OffsetCommitRequest) (*OffsetCommitResponse, error) {
 	response := new(OffsetCommitResponse)
 
 	err := b.sendAndReceive(request, response)
-
 	if err != nil {
 		return nil, err
 	}
@@ -357,11 +378,11 @@
 	return response, nil
 }
 
+//FetchOffset returns an offset fetch response or error
 func (b *Broker) FetchOffset(request *OffsetFetchRequest) (*OffsetFetchResponse, error) {
 	response := new(OffsetFetchResponse)
 
 	err := b.sendAndReceive(request, response)
-
 	if err != nil {
 		return nil, err
 	}
@@ -369,6 +390,7 @@
 	return response, nil
 }
 
+//JoinGroup returns a join group response or error
 func (b *Broker) JoinGroup(request *JoinGroupRequest) (*JoinGroupResponse, error) {
 	response := new(JoinGroupResponse)
 
@@ -380,6 +402,7 @@
 	return response, nil
 }
 
+//SyncGroup returns a sync group response or error
 func (b *Broker) SyncGroup(request *SyncGroupRequest) (*SyncGroupResponse, error) {
 	response := new(SyncGroupResponse)
 
@@ -391,6 +414,7 @@
 	return response, nil
 }
 
+//LeaveGroup returns a leave group response or error
 func (b *Broker) LeaveGroup(request *LeaveGroupRequest) (*LeaveGroupResponse, error) {
 	response := new(LeaveGroupResponse)
 
@@ -402,6 +426,7 @@
 	return response, nil
 }
 
+//Heartbeat returns a heartbeat response or error
 func (b *Broker) Heartbeat(request *HeartbeatRequest) (*HeartbeatResponse, error) {
 	response := new(HeartbeatResponse)
 
@@ -413,6 +438,7 @@
 	return response, nil
 }
 
+//ListGroups returns a list groups response or error
 func (b *Broker) ListGroups(request *ListGroupsRequest) (*ListGroupsResponse, error) {
 	response := new(ListGroupsResponse)
 
@@ -424,6 +450,7 @@
 	return response, nil
 }
 
+//DescribeGroups returns a describe groups response or error
 func (b *Broker) DescribeGroups(request *DescribeGroupsRequest) (*DescribeGroupsResponse, error) {
 	response := new(DescribeGroupsResponse)
 
@@ -435,6 +462,7 @@
 	return response, nil
 }
 
+//ApiVersions returns an api versions response or error
 func (b *Broker) ApiVersions(request *ApiVersionsRequest) (*ApiVersionsResponse, error) {
 	response := new(ApiVersionsResponse)
 
@@ -446,6 +474,7 @@
 	return response, nil
 }
 
+//CreateTopics sends a create topics request and returns a create topics response or error
 func (b *Broker) CreateTopics(request *CreateTopicsRequest) (*CreateTopicsResponse, error) {
 	response := new(CreateTopicsResponse)
 
@@ -457,6 +486,7 @@
 	return response, nil
 }
 
+//DeleteTopics sends a delete topic request and returns delete topic response
 func (b *Broker) DeleteTopics(request *DeleteTopicsRequest) (*DeleteTopicsResponse, error) {
 	response := new(DeleteTopicsResponse)
 
@@ -468,6 +498,8 @@
 	return response, nil
 }
 
+//CreatePartitions sends a create partition request and returns create
+//partitions response or error
 func (b *Broker) CreatePartitions(request *CreatePartitionsRequest) (*CreatePartitionsResponse, error) {
 	response := new(CreatePartitionsResponse)
 
@@ -479,6 +511,8 @@
 	return response, nil
 }
 
+//DeleteRecords sends a request to delete records and returns a delete records
+//response or error
 func (b *Broker) DeleteRecords(request *DeleteRecordsRequest) (*DeleteRecordsResponse, error) {
 	response := new(DeleteRecordsResponse)
 
@@ -490,6 +524,7 @@
 	return response, nil
 }
 
+//DescribeAcls sends a describe acl request and returns a response or error
 func (b *Broker) DescribeAcls(request *DescribeAclsRequest) (*DescribeAclsResponse, error) {
 	response := new(DescribeAclsResponse)
 
@@ -501,6 +536,7 @@
 	return response, nil
 }
 
+//CreateAcls sends a create acl request and returns a response or error
 func (b *Broker) CreateAcls(request *CreateAclsRequest) (*CreateAclsResponse, error) {
 	response := new(CreateAclsResponse)
 
@@ -512,6 +548,7 @@
 	return response, nil
 }
 
+//DeleteAcls sends a delete acl request and returns a response or error
 func (b *Broker) DeleteAcls(request *DeleteAclsRequest) (*DeleteAclsResponse, error) {
 	response := new(DeleteAclsResponse)
 
@@ -523,6 +560,7 @@
 	return response, nil
 }
 
+//InitProducerID sends an init producer request and returns a response or error
 func (b *Broker) InitProducerID(request *InitProducerIDRequest) (*InitProducerIDResponse, error) {
 	response := new(InitProducerIDResponse)
 
@@ -534,6 +572,8 @@
 	return response, nil
 }
 
+//AddPartitionsToTxn sends a request to add partitions to txn and returns
+//a response or error
 func (b *Broker) AddPartitionsToTxn(request *AddPartitionsToTxnRequest) (*AddPartitionsToTxnResponse, error) {
 	response := new(AddPartitionsToTxnResponse)
 
@@ -545,6 +585,8 @@
 	return response, nil
 }
 
+//AddOffsetsToTxn sends a request to add offsets to txn and returns a response
+//or error
 func (b *Broker) AddOffsetsToTxn(request *AddOffsetsToTxnRequest) (*AddOffsetsToTxnResponse, error) {
 	response := new(AddOffsetsToTxnResponse)
 
@@ -556,6 +598,7 @@
 	return response, nil
 }
 
+//EndTxn sends a request to end txn and returns a response or error
 func (b *Broker) EndTxn(request *EndTxnRequest) (*EndTxnResponse, error) {
 	response := new(EndTxnResponse)
 
@@ -567,6 +610,8 @@
 	return response, nil
 }
 
+//TxnOffsetCommit sends a request to commit transaction offsets and returns
+//a response or error
 func (b *Broker) TxnOffsetCommit(request *TxnOffsetCommitRequest) (*TxnOffsetCommitResponse, error) {
 	response := new(TxnOffsetCommitResponse)
 
@@ -578,6 +623,8 @@
 	return response, nil
 }
 
+//DescribeConfigs sends a request to describe config and returns a response or
+//error
 func (b *Broker) DescribeConfigs(request *DescribeConfigsRequest) (*DescribeConfigsResponse, error) {
 	response := new(DescribeConfigsResponse)
 
@@ -589,6 +636,7 @@
 	return response, nil
 }
 
+//AlterConfigs sends a request to alter config and returns a response or error
 func (b *Broker) AlterConfigs(request *AlterConfigsRequest) (*AlterConfigsResponse, error) {
 	response := new(AlterConfigsResponse)
 
@@ -600,6 +648,7 @@
 	return response, nil
 }
 
+//DeleteGroups sends a request to delete groups and returns a response or error
 func (b *Broker) DeleteGroups(request *DeleteGroupsRequest) (*DeleteGroupsResponse, error) {
 	response := new(DeleteGroupsResponse)
 
@@ -638,7 +687,7 @@
 
 	requestTime := time.Now()
 	bytes, err := b.conn.Write(buf)
-	b.updateOutgoingCommunicationMetrics(bytes)
+	b.updateOutgoingCommunicationMetrics(bytes) //TODO: should it be after error check
 	if err != nil {
 		return nil, err
 	}
@@ -658,7 +707,6 @@
 
 func (b *Broker) sendAndReceive(req protocolBody, res versionedDecoder) error {
 	promise, err := b.send(req, res != nil)
-
 	if err != nil {
 		return err
 	}
@@ -707,11 +755,11 @@
 }
 
 func (b *Broker) encode(pe packetEncoder, version int16) (err error) {
-
 	host, portstr, err := net.SplitHostPort(b.addr)
 	if err != nil {
 		return err
 	}
+
 	port, err := strconv.Atoi(portstr)
 	if err != nil {
 		return err
@@ -739,6 +787,7 @@
 func (b *Broker) responseReceiver() {
 	var dead error
 	header := make([]byte, 8)
+
 	for response := range b.responses {
 		if dead != nil {
 			response.errors <- dead
@@ -793,14 +842,28 @@
 }
 
 func (b *Broker) authenticateViaSASL() error {
-	if b.conf.Net.SASL.Mechanism == SASLTypeOAuth {
+	switch b.conf.Net.SASL.Mechanism {
+	case SASLTypeOAuth:
 		return b.sendAndReceiveSASLOAuth(b.conf.Net.SASL.TokenProvider)
+	case SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512:
+		return b.sendAndReceiveSASLSCRAMv1()
+	case SASLTypeGSSAPI:
+		return b.sendAndReceiveKerberos()
+	default:
+		return b.sendAndReceiveSASLPlainAuth()
 	}
-	return b.sendAndReceiveSASLPlainAuth()
 }
 
-func (b *Broker) sendAndReceiveSASLHandshake(saslType string, version int16) error {
-	rb := &SaslHandshakeRequest{Mechanism: saslType, Version: version}
+func (b *Broker) sendAndReceiveKerberos() error {
+	b.kerberosAuthenticator.Config = &b.conf.Net.SASL.GSSAPI
+	if b.kerberosAuthenticator.NewKerberosClientFunc == nil {
+		b.kerberosAuthenticator.NewKerberosClientFunc = NewKerberosClient
+	}
+	return b.kerberosAuthenticator.Authorize(b)
+}
+
+func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int16) error {
+	rb := &SaslHandshakeRequest{Mechanism: string(saslType), Version: version}
 
 	req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
 	buf, err := encode(req, b.conf.MetricRegistry)
@@ -828,6 +891,7 @@
 		Logger.Printf("Failed to read SASL handshake header : %s\n", err.Error())
 		return err
 	}
+
 	length := binary.BigEndian.Uint32(header[:4])
 	payload := make([]byte, length-4)
 	n, err := io.ReadFull(b.conn, payload)
@@ -835,23 +899,29 @@
 		Logger.Printf("Failed to read SASL handshake payload : %s\n", err.Error())
 		return err
 	}
+
 	b.updateIncomingCommunicationMetrics(n+8, time.Since(requestTime))
 	res := &SaslHandshakeResponse{}
+
 	err = versionedDecode(payload, res, 0)
 	if err != nil {
 		Logger.Printf("Failed to parse SASL handshake : %s\n", err.Error())
 		return err
 	}
+
 	if res.Err != ErrNoError {
 		Logger.Printf("Invalid SASL Mechanism : %s\n", res.Err.Error())
 		return res.Err
 	}
-	Logger.Print("Successful SASL handshake")
+
+	Logger.Print("Successful SASL handshake. Available mechanisms: ", res.EnabledMechanisms)
 	return nil
 }
 
-// Kafka 0.10.0 plans to support SASL Plain and Kerberos as per PR #812 (KIP-43)/(JIRA KAFKA-3149)
-// Some hosted kafka services such as IBM Message Hub already offer SASL/PLAIN auth with Kafka 0.9
+// Kafka 0.10.x supported SASL PLAIN/Kerberos via KAFKA-3149 (KIP-43).
+// Kafka 1.x.x onward added a SaslAuthenticate request/response message which
+// wraps the SASL flow in the Kafka protocol, which allows for returning
+// meaningful errors on authentication failure.
 //
 // In SASL Plain, Kafka expects the auth header to be in the following format
 // Message format (from https://tools.ietf.org/html/rfc4616):
@@ -865,17 +935,34 @@
 //   SAFE      = UTF1 / UTF2 / UTF3 / UTF4
 //                  ;; any UTF-8 encoded Unicode character except NUL
 //
+// With SASL v0 handshake and auth then:
 // When credentials are valid, Kafka returns a 4 byte array of null characters.
-// When credentials are invalid, Kafka closes the connection. This does not seem to be the ideal way
-// of responding to bad credentials but thats how its being done today.
+// When credentials are invalid, Kafka closes the connection.
+//
+// With SASL v1 handshake and auth then:
+// When credentials are invalid, Kafka replies with a SaslAuthenticate response
+// containing an error code and message detailing the authentication failure.
 func (b *Broker) sendAndReceiveSASLPlainAuth() error {
+	// default to V0 to allow for backward compatibility when SASL is enabled
+	// but not the handshake
 	if b.conf.Net.SASL.Handshake {
-		handshakeErr := b.sendAndReceiveSASLHandshake(SASLTypePlaintext, SASLHandshakeV0)
+
+		handshakeErr := b.sendAndReceiveSASLHandshake(SASLTypePlaintext, b.conf.Net.SASL.Version)
 		if handshakeErr != nil {
 			Logger.Printf("Error while performing SASL handshake %s\n", b.addr)
 			return handshakeErr
 		}
 	}
+
+	if b.conf.Net.SASL.Version == SASLHandshakeV1 {
+		return b.sendAndReceiveV1SASLPlainAuth()
+	}
+	return b.sendAndReceiveV0SASLPlainAuth()
+}
+
+// sendAndReceiveV0SASLPlainAuth flows the v0 sasl auth NOT wrapped in the kafka protocol
+func (b *Broker) sendAndReceiveV0SASLPlainAuth() error {
+
 	length := 1 + len(b.conf.Net.SASL.User) + 1 + len(b.conf.Net.SASL.Password)
 	authBytes := make([]byte, length+4) //4 byte length header + auth data
 	binary.BigEndian.PutUint32(authBytes, uint32(length))
@@ -909,55 +996,197 @@
 	return nil
 }
 
+// sendAndReceiveV1SASLPlainAuth flows the v1 sasl authentication using the kafka protocol
+func (b *Broker) sendAndReceiveV1SASLPlainAuth() error {
+	correlationID := b.correlationID
+
+	requestTime := time.Now()
+
+	bytesWritten, err := b.sendSASLPlainAuthClientResponse(correlationID)
+
+	b.updateOutgoingCommunicationMetrics(bytesWritten)
+
+	if err != nil {
+		Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error())
+		return err
+	}
+
+	b.correlationID++
+
+	bytesRead, err := b.receiveSASLServerResponse(&SaslAuthenticateResponse{}, correlationID)
+	b.updateIncomingCommunicationMetrics(bytesRead, time.Since(requestTime))
+
+	// With v1 sasl we get an error message set in the response we can return
+	if err != nil {
+		Logger.Printf("Error returned from broker during SASL flow %s: %s\n", b.addr, err.Error())
+		return err
+	}
+
+	return nil
+}
+
 // sendAndReceiveSASLOAuth performs the authentication flow as described by KIP-255
 // https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=75968876
 func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
-
 	if err := b.sendAndReceiveSASLHandshake(SASLTypeOAuth, SASLHandshakeV1); err != nil {
 		return err
 	}
 
 	token, err := provider.Token()
-
 	if err != nil {
 		return err
 	}
 
-	requestTime := time.Now()
-
-	correlationID := b.correlationID
-
-	bytesWritten, err := b.sendSASLOAuthBearerClientResponse(token, correlationID)
-
+	message, err := buildClientFirstMessage(token)
 	if err != nil {
 		return err
 	}
 
+	challenged, err := b.sendClientMessage(message)
+	if err != nil {
+		return err
+	}
+
+	if challenged {
+		// Abort the token exchange. The broker returns the failure code.
+		_, err = b.sendClientMessage([]byte(`\x01`))
+	}
+
+	return err
+}
+
+// sendClientMessage sends a SASL/OAUTHBEARER client message and returns true
+// if the broker responds with a challenge, in which case the token is
+// rejected.
+func (b *Broker) sendClientMessage(message []byte) (bool, error) {
+
+	requestTime := time.Now()
+	correlationID := b.correlationID
+
+	bytesWritten, err := b.sendSASLOAuthBearerClientMessage(message, correlationID)
+	if err != nil {
+		return false, err
+	}
+
 	b.updateOutgoingCommunicationMetrics(bytesWritten)
-
 	b.correlationID++
 
-	bytesRead, err := b.receiveSASLOAuthBearerServerResponse(correlationID)
-
-	if err != nil {
-		return err
-	}
+	res := &SaslAuthenticateResponse{}
+	bytesRead, err := b.receiveSASLServerResponse(res, correlationID)
 
 	requestLatency := time.Since(requestTime)
 	b.updateIncomingCommunicationMetrics(bytesRead, requestLatency)
 
+	isChallenge := len(res.SaslAuthBytes) > 0
+
+	if isChallenge && err != nil {
+		Logger.Printf("Broker rejected authentication token: %s", res.SaslAuthBytes)
+	}
+
+	return isChallenge, err
+}
+
+func (b *Broker) sendAndReceiveSASLSCRAMv1() error {
+	if err := b.sendAndReceiveSASLHandshake(b.conf.Net.SASL.Mechanism, SASLHandshakeV1); err != nil {
+		return err
+	}
+
+	scramClient := b.conf.Net.SASL.SCRAMClientGeneratorFunc()
+	if err := scramClient.Begin(b.conf.Net.SASL.User, b.conf.Net.SASL.Password, b.conf.Net.SASL.SCRAMAuthzID); err != nil {
+		return fmt.Errorf("failed to start SCRAM exchange with the server: %s", err.Error())
+	}
+
+	msg, err := scramClient.Step("")
+	if err != nil {
+		return fmt.Errorf("failed to advance the SCRAM exchange: %s", err.Error())
+
+	}
+
+	for !scramClient.Done() {
+		requestTime := time.Now()
+		correlationID := b.correlationID
+		bytesWritten, err := b.sendSaslAuthenticateRequest(correlationID, []byte(msg))
+		if err != nil {
+			Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error())
+			return err
+		}
+
+		b.updateOutgoingCommunicationMetrics(bytesWritten)
+		b.correlationID++
+		challenge, err := b.receiveSaslAuthenticateResponse(correlationID)
+		if err != nil {
+			Logger.Printf("Failed to read response while authenticating with SASL to broker %s: %s\n", b.addr, err.Error())
+			return err
+		}
+
+		b.updateIncomingCommunicationMetrics(len(challenge), time.Since(requestTime))
+		msg, err = scramClient.Step(string(challenge))
+		if err != nil {
+			Logger.Println("SASL authentication failed", err)
+			return err
+		}
+	}
+
+	Logger.Println("SASL authentication succeeded")
 	return nil
 }
 
+func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (int, error) {
+	rb := &SaslAuthenticateRequest{msg}
+	req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
+	buf, err := encode(req, b.conf.MetricRegistry)
+	if err != nil {
+		return 0, err
+	}
+
+	if err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)); err != nil {
+		return 0, err
+	}
+
+	return b.conn.Write(buf)
+}
+
+func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, error) {
+	buf := make([]byte, responseLengthSize+correlationIDSize)
+	_, err := io.ReadFull(b.conn, buf)
+	if err != nil {
+		return nil, err
+	}
+
+	header := responseHeader{}
+	err = decode(buf, &header)
+	if err != nil {
+		return nil, err
+	}
+
+	if header.correlationID != correlationID {
+		return nil, fmt.Errorf("correlation ID didn't match, wanted %d, got %d", b.correlationID, header.correlationID)
+	}
+
+	buf = make([]byte, header.length-correlationIDSize)
+	_, err = io.ReadFull(b.conn, buf)
+	if err != nil {
+		return nil, err
+	}
+
+	res := &SaslAuthenticateResponse{}
+	if err := versionedDecode(buf, res, 0); err != nil {
+		return nil, err
+	}
+	if res.Err != ErrNoError {
+		return nil, res.Err
+	}
+	return res.SaslAuthBytes, nil
+}
+
 // Build SASL/OAUTHBEARER initial client response as described by RFC-7628
 // https://tools.ietf.org/html/rfc7628
-func buildClientInitialResponse(token *AccessToken) ([]byte, error) {
-
+func buildClientFirstMessage(token *AccessToken) ([]byte, error) {
 	var ext string
 
 	if token.Extensions != nil && len(token.Extensions) > 0 {
 		if _, ok := token.Extensions[SASLExtKeyAuth]; ok {
-			return []byte{}, fmt.Errorf("The extension `%s` is invalid", SASLExtKeyAuth)
+			return []byte{}, fmt.Errorf("the extension `%s` is invalid", SASLExtKeyAuth)
 		}
 		ext = "\x01" + mapToString(token.Extensions, "=", "\x01")
 	}
@@ -970,7 +1199,6 @@
 // mapToString returns a list of key-value pairs ordered by key.
 // keyValSep separates the key from the value. elemSep separates each pair.
 func mapToString(extensions map[string]string, keyValSep string, elemSep string) string {
-
 	buf := make([]string, 0, len(extensions))
 
 	for k, v := range extensions {
@@ -982,20 +1210,30 @@
 	return strings.Join(buf, elemSep)
 }
 
-func (b *Broker) sendSASLOAuthBearerClientResponse(token *AccessToken, correlationID int32) (int, error) {
-
-	initialResp, err := buildClientInitialResponse(token)
-
+func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, error) {
+	authBytes := []byte("\x00" + b.conf.Net.SASL.User + "\x00" + b.conf.Net.SASL.Password)
+	rb := &SaslAuthenticateRequest{authBytes}
+	req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
+	buf, err := encode(req, b.conf.MetricRegistry)
 	if err != nil {
 		return 0, err
 	}
 
+	err = b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout))
+	if err != nil {
+		Logger.Printf("Failed to set write deadline when doing SASL auth with broker %s: %s\n", b.addr, err.Error())
+		return 0, err
+	}
+	return b.conn.Write(buf)
+}
+
+func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlationID int32) (int, error) {
+
 	rb := &SaslAuthenticateRequest{initialResp}
 
 	req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
 
 	buf, err := encode(req, b.conf.MetricRegistry)
-
 	if err != nil {
 		return 0, err
 	}
@@ -1007,12 +1245,11 @@
 	return b.conn.Write(buf)
 }
 
-func (b *Broker) receiveSASLOAuthBearerServerResponse(correlationID int32) (int, error) {
+func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correlationID int32) (int, error) {
 
-	buf := make([]byte, 8)
+	buf := make([]byte, responseLengthSize+correlationIDSize)
 
 	bytesRead, err := io.ReadFull(b.conn, buf)
-
 	if err != nil {
 		return bytesRead, err
 	}
@@ -1020,7 +1257,6 @@
 	header := responseHeader{}
 
 	err = decode(buf, &header)
-
 	if err != nil {
 		return bytesRead, err
 	}
@@ -1029,48 +1265,39 @@
 		return bytesRead, fmt.Errorf("correlation ID didn't match, wanted %d, got %d", b.correlationID, header.correlationID)
 	}
 
-	buf = make([]byte, header.length-4)
+	buf = make([]byte, header.length-correlationIDSize)
 
 	c, err := io.ReadFull(b.conn, buf)
-
 	bytesRead += c
-
 	if err != nil {
 		return bytesRead, err
 	}
 
-	res := &SaslAuthenticateResponse{}
-
 	if err := versionedDecode(buf, res, 0); err != nil {
 		return bytesRead, err
 	}
 
-	if err != nil {
-		return bytesRead, err
-	}
-
 	if res.Err != ErrNoError {
 		return bytesRead, res.Err
 	}
 
-	if len(res.SaslAuthBytes) > 0 {
-		Logger.Printf("Received SASL auth response: %s", res.SaslAuthBytes)
-	}
-
 	return bytesRead, nil
 }
 
 func (b *Broker) updateIncomingCommunicationMetrics(bytes int, requestLatency time.Duration) {
 	b.updateRequestLatencyMetrics(requestLatency)
 	b.responseRate.Mark(1)
+
 	if b.brokerResponseRate != nil {
 		b.brokerResponseRate.Mark(1)
 	}
+
 	responseSize := int64(bytes)
 	b.incomingByteRate.Mark(responseSize)
 	if b.brokerIncomingByteRate != nil {
 		b.brokerIncomingByteRate.Mark(responseSize)
 	}
+
 	b.responseSize.Update(responseSize)
 	if b.brokerResponseSize != nil {
 		b.brokerResponseSize.Update(responseSize)
@@ -1080,9 +1307,11 @@
 func (b *Broker) updateRequestLatencyMetrics(requestLatency time.Duration) {
 	requestLatencyInMs := int64(requestLatency / time.Millisecond)
 	b.requestLatency.Update(requestLatencyInMs)
+
 	if b.brokerRequestLatency != nil {
 		b.brokerRequestLatency.Update(requestLatencyInMs)
 	}
+
 }
 
 func (b *Broker) updateOutgoingCommunicationMetrics(bytes int) {
@@ -1090,13 +1319,44 @@
 	if b.brokerRequestRate != nil {
 		b.brokerRequestRate.Mark(1)
 	}
+
 	requestSize := int64(bytes)
 	b.outgoingByteRate.Mark(requestSize)
 	if b.brokerOutgoingByteRate != nil {
 		b.brokerOutgoingByteRate.Mark(requestSize)
 	}
+
 	b.requestSize.Update(requestSize)
 	if b.brokerRequestSize != nil {
 		b.brokerRequestSize.Update(requestSize)
 	}
+
+}
+
+func (b *Broker) registerMetrics() {
+	b.brokerIncomingByteRate = b.registerMeter("incoming-byte-rate")
+	b.brokerRequestRate = b.registerMeter("request-rate")
+	b.brokerRequestSize = b.registerHistogram("request-size")
+	b.brokerRequestLatency = b.registerHistogram("request-latency-in-ms")
+	b.brokerOutgoingByteRate = b.registerMeter("outgoing-byte-rate")
+	b.brokerResponseRate = b.registerMeter("response-rate")
+	b.brokerResponseSize = b.registerHistogram("response-size")
+}
+
+func (b *Broker) unregisterMetrics() {
+	for _, name := range b.registeredMetrics {
+		b.conf.MetricRegistry.Unregister(name)
+	}
+}
+
+func (b *Broker) registerMeter(name string) metrics.Meter {
+	nameForBroker := getMetricNameForBroker(name, b)
+	b.registeredMetrics = append(b.registeredMetrics, nameForBroker)
+	return metrics.GetOrRegisterMeter(nameForBroker, b.conf.MetricRegistry)
+}
+
+func (b *Broker) registerHistogram(name string) metrics.Histogram {
+	nameForBroker := getMetricNameForBroker(name, b)
+	b.registeredMetrics = append(b.registeredMetrics, nameForBroker)
+	return getOrRegisterHistogram(nameForBroker, b.conf.MetricRegistry)
 }