blob: 9129089acbf0bb337e9f6ee1474aa1181eb3aa3d [file] [log] [blame]
William Kurkianea869482019-04-09 15:16:11 -04001package sarama
2
3import (
4 "crypto/tls"
5 "encoding/binary"
6 "fmt"
7 "io"
8 "net"
9 "sort"
10 "strconv"
11 "strings"
12 "sync"
13 "sync/atomic"
14 "time"
15
16 "github.com/rcrowley/go-metrics"
17)
18
19// Broker represents a single Kafka broker connection. All operations on this object are entirely concurrency-safe.
20type Broker struct {
21 id int32
22 addr string
23 rack *string
24
25 conf *Config
26 correlationID int32
27 conn net.Conn
28 connErr error
29 lock sync.Mutex
30 opened int32
31
32 responses chan responsePromise
33 done chan bool
34
35 incomingByteRate metrics.Meter
36 requestRate metrics.Meter
37 requestSize metrics.Histogram
38 requestLatency metrics.Histogram
39 outgoingByteRate metrics.Meter
40 responseRate metrics.Meter
41 responseSize metrics.Histogram
42 brokerIncomingByteRate metrics.Meter
43 brokerRequestRate metrics.Meter
44 brokerRequestSize metrics.Histogram
45 brokerRequestLatency metrics.Histogram
46 brokerOutgoingByteRate metrics.Meter
47 brokerResponseRate metrics.Meter
48 brokerResponseSize metrics.Histogram
49}
50
// SASLMechanism specifies the SASL mechanism the client uses to authenticate
// with the broker.
type SASLMechanism string

const (
	// SASLTypeOAuth represents the SASL/OAUTHBEARER mechanism (Kafka 2.0.0+).
	SASLTypeOAuth = "OAUTHBEARER"
	// SASLTypePlaintext represents the SASL/PLAIN mechanism.
	SASLTypePlaintext = "PLAIN"
	// SASLHandshakeV0 is v0 of the Kafka SASL handshake protocol. Client and
	// server negotiate SASL auth using opaque packets.
	SASLHandshakeV0 = int16(0)
	// SASLHandshakeV1 is v1 of the Kafka SASL handshake protocol. Client and
	// server negotiate SASL by wrapping tokens with Kafka protocol headers.
	SASLHandshakeV1 = int16(1)
	// SASLExtKeyAuth is the reserved extension key name sent as part of the
	// SASL/OAUTHBEARER initial client response; it may not be used as a
	// user-supplied extension key.
	SASLExtKeyAuth = "auth"
)
69
// AccessToken contains an access token used to authenticate a
// SASL/OAUTHBEARER client along with associated metadata.
type AccessToken struct {
	// Token is the access token payload.
	Token string
	// Extensions is an optional map of arbitrary key-value pairs that can be
	// sent with the SASL/OAUTHBEARER initial client response. These values are
	// ignored by the SASL server if they are unexpected. This feature is only
	// supported by Kafka >= 2.1.0.
	Extensions map[string]string
}
81
82// AccessTokenProvider is the interface that encapsulates how implementors
83// can generate access tokens for Kafka broker authentication.
84type AccessTokenProvider interface {
85 // Token returns an access token. The implementation should ensure token
86 // reuse so that multiple calls at connect time do not create multiple
87 // tokens. The implementation should also periodically refresh the token in
88 // order to guarantee that each call returns an unexpired token. This
89 // method should not block indefinitely--a timeout error should be returned
90 // after a short period of inactivity so that the broker connection logic
91 // can log debugging information and retry.
92 Token() (*AccessToken, error)
93}
94
// responsePromise ties an in-flight request to the channels on which its
// outcome is delivered. It is built with a positional composite literal, so
// the field order must not change.
type responsePromise struct {
	requestTime   time.Time   // when the request was written; used for latency metrics
	correlationID int32       // ID the broker is expected to echo in the response header
	packets       chan []byte // receives the raw response body on success
	errors        chan error  // receives the failure otherwise
}
101
102// NewBroker creates and returns a Broker targeting the given host:port address.
103// This does not attempt to actually connect, you have to call Open() for that.
104func NewBroker(addr string) *Broker {
105 return &Broker{id: -1, addr: addr}
106}
107
108// Open tries to connect to the Broker if it is not already connected or connecting, but does not block
109// waiting for the connection to complete. This means that any subsequent operations on the broker will
110// block waiting for the connection to succeed or fail. To get the effect of a fully synchronous Open call,
111// follow it by a call to Connected(). The only errors Open will return directly are ConfigurationError or
112// AlreadyConnected. If conf is nil, the result of NewConfig() is used.
113func (b *Broker) Open(conf *Config) error {
114 if !atomic.CompareAndSwapInt32(&b.opened, 0, 1) {
115 return ErrAlreadyConnected
116 }
117
118 if conf == nil {
119 conf = NewConfig()
120 }
121
122 err := conf.Validate()
123 if err != nil {
124 return err
125 }
126
127 b.lock.Lock()
128
129 go withRecover(func() {
130 defer b.lock.Unlock()
131
132 dialer := net.Dialer{
133 Timeout: conf.Net.DialTimeout,
134 KeepAlive: conf.Net.KeepAlive,
135 LocalAddr: conf.Net.LocalAddr,
136 }
137
138 if conf.Net.TLS.Enable {
139 b.conn, b.connErr = tls.DialWithDialer(&dialer, "tcp", b.addr, conf.Net.TLS.Config)
140 } else {
141 b.conn, b.connErr = dialer.Dial("tcp", b.addr)
142 }
143 if b.connErr != nil {
144 Logger.Printf("Failed to connect to broker %s: %s\n", b.addr, b.connErr)
145 b.conn = nil
146 atomic.StoreInt32(&b.opened, 0)
147 return
148 }
149 b.conn = newBufConn(b.conn)
150
151 b.conf = conf
152
153 // Create or reuse the global metrics shared between brokers
154 b.incomingByteRate = metrics.GetOrRegisterMeter("incoming-byte-rate", conf.MetricRegistry)
155 b.requestRate = metrics.GetOrRegisterMeter("request-rate", conf.MetricRegistry)
156 b.requestSize = getOrRegisterHistogram("request-size", conf.MetricRegistry)
157 b.requestLatency = getOrRegisterHistogram("request-latency-in-ms", conf.MetricRegistry)
158 b.outgoingByteRate = metrics.GetOrRegisterMeter("outgoing-byte-rate", conf.MetricRegistry)
159 b.responseRate = metrics.GetOrRegisterMeter("response-rate", conf.MetricRegistry)
160 b.responseSize = getOrRegisterHistogram("response-size", conf.MetricRegistry)
161 // Do not gather metrics for seeded broker (only used during bootstrap) because they share
162 // the same id (-1) and are already exposed through the global metrics above
163 if b.id >= 0 {
164 b.brokerIncomingByteRate = getOrRegisterBrokerMeter("incoming-byte-rate", b, conf.MetricRegistry)
165 b.brokerRequestRate = getOrRegisterBrokerMeter("request-rate", b, conf.MetricRegistry)
166 b.brokerRequestSize = getOrRegisterBrokerHistogram("request-size", b, conf.MetricRegistry)
167 b.brokerRequestLatency = getOrRegisterBrokerHistogram("request-latency-in-ms", b, conf.MetricRegistry)
168 b.brokerOutgoingByteRate = getOrRegisterBrokerMeter("outgoing-byte-rate", b, conf.MetricRegistry)
169 b.brokerResponseRate = getOrRegisterBrokerMeter("response-rate", b, conf.MetricRegistry)
170 b.brokerResponseSize = getOrRegisterBrokerHistogram("response-size", b, conf.MetricRegistry)
171 }
172
173 if conf.Net.SASL.Enable {
174
175 b.connErr = b.authenticateViaSASL()
176
177 if b.connErr != nil {
178 err = b.conn.Close()
179 if err == nil {
180 Logger.Printf("Closed connection to broker %s\n", b.addr)
181 } else {
182 Logger.Printf("Error while closing connection to broker %s: %s\n", b.addr, err)
183 }
184 b.conn = nil
185 atomic.StoreInt32(&b.opened, 0)
186 return
187 }
188 }
189
190 b.done = make(chan bool)
191 b.responses = make(chan responsePromise, b.conf.Net.MaxOpenRequests-1)
192
193 if b.id >= 0 {
194 Logger.Printf("Connected to broker at %s (registered as #%d)\n", b.addr, b.id)
195 } else {
196 Logger.Printf("Connected to broker at %s (unregistered)\n", b.addr)
197 }
198 go withRecover(b.responseReceiver)
199 })
200
201 return nil
202}
203
204// Connected returns true if the broker is connected and false otherwise. If the broker is not
205// connected but it had tried to connect, the error from that connection attempt is also returned.
206func (b *Broker) Connected() (bool, error) {
207 b.lock.Lock()
208 defer b.lock.Unlock()
209
210 return b.conn != nil, b.connErr
211}
212
213func (b *Broker) Close() error {
214 b.lock.Lock()
215 defer b.lock.Unlock()
216
217 if b.conn == nil {
218 return ErrNotConnected
219 }
220
221 close(b.responses)
222 <-b.done
223
224 err := b.conn.Close()
225
226 b.conn = nil
227 b.connErr = nil
228 b.done = nil
229 b.responses = nil
230
231 if b.id >= 0 {
232 b.conf.MetricRegistry.Unregister(getMetricNameForBroker("incoming-byte-rate", b))
233 b.conf.MetricRegistry.Unregister(getMetricNameForBroker("request-rate", b))
234 b.conf.MetricRegistry.Unregister(getMetricNameForBroker("outgoing-byte-rate", b))
235 b.conf.MetricRegistry.Unregister(getMetricNameForBroker("response-rate", b))
236 }
237
238 if err == nil {
239 Logger.Printf("Closed connection to broker %s\n", b.addr)
240 } else {
241 Logger.Printf("Error while closing connection to broker %s: %s\n", b.addr, err)
242 }
243
244 atomic.StoreInt32(&b.opened, 0)
245
246 return err
247}
248
249// ID returns the broker ID retrieved from Kafka's metadata, or -1 if that is not known.
250func (b *Broker) ID() int32 {
251 return b.id
252}
253
254// Addr returns the broker address as either retrieved from Kafka's metadata or passed to NewBroker.
255func (b *Broker) Addr() string {
256 return b.addr
257}
258
259// Rack returns the broker's rack as retrieved from Kafka's metadata or the
260// empty string if it is not known. The returned value corresponds to the
261// broker's broker.rack configuration setting. Requires protocol version to be
262// at least v0.10.0.0.
263func (b *Broker) Rack() string {
264 if b.rack == nil {
265 return ""
266 }
267 return *b.rack
268}
269
270func (b *Broker) GetMetadata(request *MetadataRequest) (*MetadataResponse, error) {
271 response := new(MetadataResponse)
272
273 err := b.sendAndReceive(request, response)
274
275 if err != nil {
276 return nil, err
277 }
278
279 return response, nil
280}
281
282func (b *Broker) GetConsumerMetadata(request *ConsumerMetadataRequest) (*ConsumerMetadataResponse, error) {
283 response := new(ConsumerMetadataResponse)
284
285 err := b.sendAndReceive(request, response)
286
287 if err != nil {
288 return nil, err
289 }
290
291 return response, nil
292}
293
294func (b *Broker) FindCoordinator(request *FindCoordinatorRequest) (*FindCoordinatorResponse, error) {
295 response := new(FindCoordinatorResponse)
296
297 err := b.sendAndReceive(request, response)
298
299 if err != nil {
300 return nil, err
301 }
302
303 return response, nil
304}
305
306func (b *Broker) GetAvailableOffsets(request *OffsetRequest) (*OffsetResponse, error) {
307 response := new(OffsetResponse)
308
309 err := b.sendAndReceive(request, response)
310
311 if err != nil {
312 return nil, err
313 }
314
315 return response, nil
316}
317
318func (b *Broker) Produce(request *ProduceRequest) (*ProduceResponse, error) {
319 var response *ProduceResponse
320 var err error
321
322 if request.RequiredAcks == NoResponse {
323 err = b.sendAndReceive(request, nil)
324 } else {
325 response = new(ProduceResponse)
326 err = b.sendAndReceive(request, response)
327 }
328
329 if err != nil {
330 return nil, err
331 }
332
333 return response, nil
334}
335
336func (b *Broker) Fetch(request *FetchRequest) (*FetchResponse, error) {
337 response := new(FetchResponse)
338
339 err := b.sendAndReceive(request, response)
340
341 if err != nil {
342 return nil, err
343 }
344
345 return response, nil
346}
347
348func (b *Broker) CommitOffset(request *OffsetCommitRequest) (*OffsetCommitResponse, error) {
349 response := new(OffsetCommitResponse)
350
351 err := b.sendAndReceive(request, response)
352
353 if err != nil {
354 return nil, err
355 }
356
357 return response, nil
358}
359
360func (b *Broker) FetchOffset(request *OffsetFetchRequest) (*OffsetFetchResponse, error) {
361 response := new(OffsetFetchResponse)
362
363 err := b.sendAndReceive(request, response)
364
365 if err != nil {
366 return nil, err
367 }
368
369 return response, nil
370}
371
372func (b *Broker) JoinGroup(request *JoinGroupRequest) (*JoinGroupResponse, error) {
373 response := new(JoinGroupResponse)
374
375 err := b.sendAndReceive(request, response)
376 if err != nil {
377 return nil, err
378 }
379
380 return response, nil
381}
382
383func (b *Broker) SyncGroup(request *SyncGroupRequest) (*SyncGroupResponse, error) {
384 response := new(SyncGroupResponse)
385
386 err := b.sendAndReceive(request, response)
387 if err != nil {
388 return nil, err
389 }
390
391 return response, nil
392}
393
394func (b *Broker) LeaveGroup(request *LeaveGroupRequest) (*LeaveGroupResponse, error) {
395 response := new(LeaveGroupResponse)
396
397 err := b.sendAndReceive(request, response)
398 if err != nil {
399 return nil, err
400 }
401
402 return response, nil
403}
404
405func (b *Broker) Heartbeat(request *HeartbeatRequest) (*HeartbeatResponse, error) {
406 response := new(HeartbeatResponse)
407
408 err := b.sendAndReceive(request, response)
409 if err != nil {
410 return nil, err
411 }
412
413 return response, nil
414}
415
416func (b *Broker) ListGroups(request *ListGroupsRequest) (*ListGroupsResponse, error) {
417 response := new(ListGroupsResponse)
418
419 err := b.sendAndReceive(request, response)
420 if err != nil {
421 return nil, err
422 }
423
424 return response, nil
425}
426
427func (b *Broker) DescribeGroups(request *DescribeGroupsRequest) (*DescribeGroupsResponse, error) {
428 response := new(DescribeGroupsResponse)
429
430 err := b.sendAndReceive(request, response)
431 if err != nil {
432 return nil, err
433 }
434
435 return response, nil
436}
437
438func (b *Broker) ApiVersions(request *ApiVersionsRequest) (*ApiVersionsResponse, error) {
439 response := new(ApiVersionsResponse)
440
441 err := b.sendAndReceive(request, response)
442 if err != nil {
443 return nil, err
444 }
445
446 return response, nil
447}
448
449func (b *Broker) CreateTopics(request *CreateTopicsRequest) (*CreateTopicsResponse, error) {
450 response := new(CreateTopicsResponse)
451
452 err := b.sendAndReceive(request, response)
453 if err != nil {
454 return nil, err
455 }
456
457 return response, nil
458}
459
460func (b *Broker) DeleteTopics(request *DeleteTopicsRequest) (*DeleteTopicsResponse, error) {
461 response := new(DeleteTopicsResponse)
462
463 err := b.sendAndReceive(request, response)
464 if err != nil {
465 return nil, err
466 }
467
468 return response, nil
469}
470
471func (b *Broker) CreatePartitions(request *CreatePartitionsRequest) (*CreatePartitionsResponse, error) {
472 response := new(CreatePartitionsResponse)
473
474 err := b.sendAndReceive(request, response)
475 if err != nil {
476 return nil, err
477 }
478
479 return response, nil
480}
481
482func (b *Broker) DeleteRecords(request *DeleteRecordsRequest) (*DeleteRecordsResponse, error) {
483 response := new(DeleteRecordsResponse)
484
485 err := b.sendAndReceive(request, response)
486 if err != nil {
487 return nil, err
488 }
489
490 return response, nil
491}
492
493func (b *Broker) DescribeAcls(request *DescribeAclsRequest) (*DescribeAclsResponse, error) {
494 response := new(DescribeAclsResponse)
495
496 err := b.sendAndReceive(request, response)
497 if err != nil {
498 return nil, err
499 }
500
501 return response, nil
502}
503
504func (b *Broker) CreateAcls(request *CreateAclsRequest) (*CreateAclsResponse, error) {
505 response := new(CreateAclsResponse)
506
507 err := b.sendAndReceive(request, response)
508 if err != nil {
509 return nil, err
510 }
511
512 return response, nil
513}
514
515func (b *Broker) DeleteAcls(request *DeleteAclsRequest) (*DeleteAclsResponse, error) {
516 response := new(DeleteAclsResponse)
517
518 err := b.sendAndReceive(request, response)
519 if err != nil {
520 return nil, err
521 }
522
523 return response, nil
524}
525
526func (b *Broker) InitProducerID(request *InitProducerIDRequest) (*InitProducerIDResponse, error) {
527 response := new(InitProducerIDResponse)
528
529 err := b.sendAndReceive(request, response)
530 if err != nil {
531 return nil, err
532 }
533
534 return response, nil
535}
536
537func (b *Broker) AddPartitionsToTxn(request *AddPartitionsToTxnRequest) (*AddPartitionsToTxnResponse, error) {
538 response := new(AddPartitionsToTxnResponse)
539
540 err := b.sendAndReceive(request, response)
541 if err != nil {
542 return nil, err
543 }
544
545 return response, nil
546}
547
548func (b *Broker) AddOffsetsToTxn(request *AddOffsetsToTxnRequest) (*AddOffsetsToTxnResponse, error) {
549 response := new(AddOffsetsToTxnResponse)
550
551 err := b.sendAndReceive(request, response)
552 if err != nil {
553 return nil, err
554 }
555
556 return response, nil
557}
558
559func (b *Broker) EndTxn(request *EndTxnRequest) (*EndTxnResponse, error) {
560 response := new(EndTxnResponse)
561
562 err := b.sendAndReceive(request, response)
563 if err != nil {
564 return nil, err
565 }
566
567 return response, nil
568}
569
570func (b *Broker) TxnOffsetCommit(request *TxnOffsetCommitRequest) (*TxnOffsetCommitResponse, error) {
571 response := new(TxnOffsetCommitResponse)
572
573 err := b.sendAndReceive(request, response)
574 if err != nil {
575 return nil, err
576 }
577
578 return response, nil
579}
580
581func (b *Broker) DescribeConfigs(request *DescribeConfigsRequest) (*DescribeConfigsResponse, error) {
582 response := new(DescribeConfigsResponse)
583
584 err := b.sendAndReceive(request, response)
585 if err != nil {
586 return nil, err
587 }
588
589 return response, nil
590}
591
592func (b *Broker) AlterConfigs(request *AlterConfigsRequest) (*AlterConfigsResponse, error) {
593 response := new(AlterConfigsResponse)
594
595 err := b.sendAndReceive(request, response)
596 if err != nil {
597 return nil, err
598 }
599
600 return response, nil
601}
602
603func (b *Broker) DeleteGroups(request *DeleteGroupsRequest) (*DeleteGroupsResponse, error) {
604 response := new(DeleteGroupsResponse)
605
606 if err := b.sendAndReceive(request, response); err != nil {
607 return nil, err
608 }
609
610 return response, nil
611}
612
613func (b *Broker) send(rb protocolBody, promiseResponse bool) (*responsePromise, error) {
614 b.lock.Lock()
615 defer b.lock.Unlock()
616
617 if b.conn == nil {
618 if b.connErr != nil {
619 return nil, b.connErr
620 }
621 return nil, ErrNotConnected
622 }
623
624 if !b.conf.Version.IsAtLeast(rb.requiredVersion()) {
625 return nil, ErrUnsupportedVersion
626 }
627
628 req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
629 buf, err := encode(req, b.conf.MetricRegistry)
630 if err != nil {
631 return nil, err
632 }
633
634 err = b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout))
635 if err != nil {
636 return nil, err
637 }
638
639 requestTime := time.Now()
640 bytes, err := b.conn.Write(buf)
641 b.updateOutgoingCommunicationMetrics(bytes)
642 if err != nil {
643 return nil, err
644 }
645 b.correlationID++
646
647 if !promiseResponse {
648 // Record request latency without the response
649 b.updateRequestLatencyMetrics(time.Since(requestTime))
650 return nil, nil
651 }
652
653 promise := responsePromise{requestTime, req.correlationID, make(chan []byte), make(chan error)}
654 b.responses <- promise
655
656 return &promise, nil
657}
658
659func (b *Broker) sendAndReceive(req protocolBody, res versionedDecoder) error {
660 promise, err := b.send(req, res != nil)
661
662 if err != nil {
663 return err
664 }
665
666 if promise == nil {
667 return nil
668 }
669
670 select {
671 case buf := <-promise.packets:
672 return versionedDecode(buf, res, req.version())
673 case err = <-promise.errors:
674 return err
675 }
676}
677
678func (b *Broker) decode(pd packetDecoder, version int16) (err error) {
679 b.id, err = pd.getInt32()
680 if err != nil {
681 return err
682 }
683
684 host, err := pd.getString()
685 if err != nil {
686 return err
687 }
688
689 port, err := pd.getInt32()
690 if err != nil {
691 return err
692 }
693
694 if version >= 1 {
695 b.rack, err = pd.getNullableString()
696 if err != nil {
697 return err
698 }
699 }
700
701 b.addr = net.JoinHostPort(host, fmt.Sprint(port))
702 if _, _, err := net.SplitHostPort(b.addr); err != nil {
703 return err
704 }
705
706 return nil
707}
708
709func (b *Broker) encode(pe packetEncoder, version int16) (err error) {
710
711 host, portstr, err := net.SplitHostPort(b.addr)
712 if err != nil {
713 return err
714 }
715 port, err := strconv.Atoi(portstr)
716 if err != nil {
717 return err
718 }
719
720 pe.putInt32(b.id)
721
722 err = pe.putString(host)
723 if err != nil {
724 return err
725 }
726
727 pe.putInt32(int32(port))
728
729 if version >= 1 {
730 err = pe.putNullableString(b.rack)
731 if err != nil {
732 return err
733 }
734 }
735
736 return nil
737}
738
739func (b *Broker) responseReceiver() {
740 var dead error
741 header := make([]byte, 8)
742 for response := range b.responses {
743 if dead != nil {
744 response.errors <- dead
745 continue
746 }
747
748 err := b.conn.SetReadDeadline(time.Now().Add(b.conf.Net.ReadTimeout))
749 if err != nil {
750 dead = err
751 response.errors <- err
752 continue
753 }
754
755 bytesReadHeader, err := io.ReadFull(b.conn, header)
756 requestLatency := time.Since(response.requestTime)
757 if err != nil {
758 b.updateIncomingCommunicationMetrics(bytesReadHeader, requestLatency)
759 dead = err
760 response.errors <- err
761 continue
762 }
763
764 decodedHeader := responseHeader{}
765 err = decode(header, &decodedHeader)
766 if err != nil {
767 b.updateIncomingCommunicationMetrics(bytesReadHeader, requestLatency)
768 dead = err
769 response.errors <- err
770 continue
771 }
772 if decodedHeader.correlationID != response.correlationID {
773 b.updateIncomingCommunicationMetrics(bytesReadHeader, requestLatency)
774 // TODO if decoded ID < cur ID, discard until we catch up
775 // TODO if decoded ID > cur ID, save it so when cur ID catches up we have a response
776 dead = PacketDecodingError{fmt.Sprintf("correlation ID didn't match, wanted %d, got %d", response.correlationID, decodedHeader.correlationID)}
777 response.errors <- dead
778 continue
779 }
780
781 buf := make([]byte, decodedHeader.length-4)
782 bytesReadBody, err := io.ReadFull(b.conn, buf)
783 b.updateIncomingCommunicationMetrics(bytesReadHeader+bytesReadBody, requestLatency)
784 if err != nil {
785 dead = err
786 response.errors <- err
787 continue
788 }
789
790 response.packets <- buf
791 }
792 close(b.done)
793}
794
795func (b *Broker) authenticateViaSASL() error {
796 if b.conf.Net.SASL.Mechanism == SASLTypeOAuth {
797 return b.sendAndReceiveSASLOAuth(b.conf.Net.SASL.TokenProvider)
798 }
799 return b.sendAndReceiveSASLPlainAuth()
800}
801
802func (b *Broker) sendAndReceiveSASLHandshake(saslType string, version int16) error {
803 rb := &SaslHandshakeRequest{Mechanism: saslType, Version: version}
804
805 req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
806 buf, err := encode(req, b.conf.MetricRegistry)
807 if err != nil {
808 return err
809 }
810
811 err = b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout))
812 if err != nil {
813 return err
814 }
815
816 requestTime := time.Now()
817 bytes, err := b.conn.Write(buf)
818 b.updateOutgoingCommunicationMetrics(bytes)
819 if err != nil {
820 Logger.Printf("Failed to send SASL handshake %s: %s\n", b.addr, err.Error())
821 return err
822 }
823 b.correlationID++
824 //wait for the response
825 header := make([]byte, 8) // response header
826 _, err = io.ReadFull(b.conn, header)
827 if err != nil {
828 Logger.Printf("Failed to read SASL handshake header : %s\n", err.Error())
829 return err
830 }
831 length := binary.BigEndian.Uint32(header[:4])
832 payload := make([]byte, length-4)
833 n, err := io.ReadFull(b.conn, payload)
834 if err != nil {
835 Logger.Printf("Failed to read SASL handshake payload : %s\n", err.Error())
836 return err
837 }
838 b.updateIncomingCommunicationMetrics(n+8, time.Since(requestTime))
839 res := &SaslHandshakeResponse{}
840 err = versionedDecode(payload, res, 0)
841 if err != nil {
842 Logger.Printf("Failed to parse SASL handshake : %s\n", err.Error())
843 return err
844 }
845 if res.Err != ErrNoError {
846 Logger.Printf("Invalid SASL Mechanism : %s\n", res.Err.Error())
847 return res.Err
848 }
849 Logger.Print("Successful SASL handshake")
850 return nil
851}
852
853// Kafka 0.10.0 plans to support SASL Plain and Kerberos as per PR #812 (KIP-43)/(JIRA KAFKA-3149)
854// Some hosted kafka services such as IBM Message Hub already offer SASL/PLAIN auth with Kafka 0.9
855//
856// In SASL Plain, Kafka expects the auth header to be in the following format
857// Message format (from https://tools.ietf.org/html/rfc4616):
858//
859// message = [authzid] UTF8NUL authcid UTF8NUL passwd
860// authcid = 1*SAFE ; MUST accept up to 255 octets
861// authzid = 1*SAFE ; MUST accept up to 255 octets
862// passwd = 1*SAFE ; MUST accept up to 255 octets
863// UTF8NUL = %x00 ; UTF-8 encoded NUL character
864//
865// SAFE = UTF1 / UTF2 / UTF3 / UTF4
866// ;; any UTF-8 encoded Unicode character except NUL
867//
868// When credentials are valid, Kafka returns a 4 byte array of null characters.
869// When credentials are invalid, Kafka closes the connection. This does not seem to be the ideal way
870// of responding to bad credentials but thats how its being done today.
871func (b *Broker) sendAndReceiveSASLPlainAuth() error {
872 if b.conf.Net.SASL.Handshake {
873 handshakeErr := b.sendAndReceiveSASLHandshake(SASLTypePlaintext, SASLHandshakeV0)
874 if handshakeErr != nil {
875 Logger.Printf("Error while performing SASL handshake %s\n", b.addr)
876 return handshakeErr
877 }
878 }
879 length := 1 + len(b.conf.Net.SASL.User) + 1 + len(b.conf.Net.SASL.Password)
880 authBytes := make([]byte, length+4) //4 byte length header + auth data
881 binary.BigEndian.PutUint32(authBytes, uint32(length))
882 copy(authBytes[4:], []byte("\x00"+b.conf.Net.SASL.User+"\x00"+b.conf.Net.SASL.Password))
883
884 err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout))
885 if err != nil {
886 Logger.Printf("Failed to set write deadline when doing SASL auth with broker %s: %s\n", b.addr, err.Error())
887 return err
888 }
889
890 requestTime := time.Now()
891 bytesWritten, err := b.conn.Write(authBytes)
892 b.updateOutgoingCommunicationMetrics(bytesWritten)
893 if err != nil {
894 Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error())
895 return err
896 }
897
898 header := make([]byte, 4)
899 n, err := io.ReadFull(b.conn, header)
900 b.updateIncomingCommunicationMetrics(n, time.Since(requestTime))
901 // If the credentials are valid, we would get a 4 byte response filled with null characters.
902 // Otherwise, the broker closes the connection and we get an EOF
903 if err != nil {
904 Logger.Printf("Failed to read response while authenticating with SASL to broker %s: %s\n", b.addr, err.Error())
905 return err
906 }
907
908 Logger.Printf("SASL authentication successful with broker %s:%v - %v\n", b.addr, n, header)
909 return nil
910}
911
912// sendAndReceiveSASLOAuth performs the authentication flow as described by KIP-255
913// https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=75968876
914func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
915
916 if err := b.sendAndReceiveSASLHandshake(SASLTypeOAuth, SASLHandshakeV1); err != nil {
917 return err
918 }
919
920 token, err := provider.Token()
921
922 if err != nil {
923 return err
924 }
925
926 requestTime := time.Now()
927
928 correlationID := b.correlationID
929
930 bytesWritten, err := b.sendSASLOAuthBearerClientResponse(token, correlationID)
931
932 if err != nil {
933 return err
934 }
935
936 b.updateOutgoingCommunicationMetrics(bytesWritten)
937
938 b.correlationID++
939
940 bytesRead, err := b.receiveSASLOAuthBearerServerResponse(correlationID)
941
942 if err != nil {
943 return err
944 }
945
946 requestLatency := time.Since(requestTime)
947 b.updateIncomingCommunicationMetrics(bytesRead, requestLatency)
948
949 return nil
950}
951
952// Build SASL/OAUTHBEARER initial client response as described by RFC-7628
953// https://tools.ietf.org/html/rfc7628
954func buildClientInitialResponse(token *AccessToken) ([]byte, error) {
955
956 var ext string
957
958 if token.Extensions != nil && len(token.Extensions) > 0 {
959 if _, ok := token.Extensions[SASLExtKeyAuth]; ok {
960 return []byte{}, fmt.Errorf("The extension `%s` is invalid", SASLExtKeyAuth)
961 }
962 ext = "\x01" + mapToString(token.Extensions, "=", "\x01")
963 }
964
965 resp := []byte(fmt.Sprintf("n,,\x01auth=Bearer %s%s\x01\x01", token.Token, ext))
966
967 return resp, nil
968}
969
// mapToString renders extensions as a single string of "key<keyValSep>value"
// pairs separated by elemSep, in ascending key order. A nil or empty map
// yields "".
func mapToString(extensions map[string]string, keyValSep string, elemSep string) string {
	keys := make([]string, 0, len(extensions))
	for k := range extensions {
		keys = append(keys, k)
	}
	// Sort the keys themselves rather than the joined pairs: with pairs, a key
	// that prefixes another ("a" vs "a-b") can sort after it whenever a
	// character of the longer key orders before keyValSep.
	sort.Strings(keys)

	pairs := make([]string, len(keys))
	for i, k := range keys {
		pairs[i] = k + keyValSep + extensions[k]
	}
	return strings.Join(pairs, elemSep)
}
984
985func (b *Broker) sendSASLOAuthBearerClientResponse(token *AccessToken, correlationID int32) (int, error) {
986
987 initialResp, err := buildClientInitialResponse(token)
988
989 if err != nil {
990 return 0, err
991 }
992
993 rb := &SaslAuthenticateRequest{initialResp}
994
995 req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
996
997 buf, err := encode(req, b.conf.MetricRegistry)
998
999 if err != nil {
1000 return 0, err
1001 }
1002
1003 if err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)); err != nil {
1004 return 0, err
1005 }
1006
1007 return b.conn.Write(buf)
1008}
1009
1010func (b *Broker) receiveSASLOAuthBearerServerResponse(correlationID int32) (int, error) {
1011
1012 buf := make([]byte, 8)
1013
1014 bytesRead, err := io.ReadFull(b.conn, buf)
1015
1016 if err != nil {
1017 return bytesRead, err
1018 }
1019
1020 header := responseHeader{}
1021
1022 err = decode(buf, &header)
1023
1024 if err != nil {
1025 return bytesRead, err
1026 }
1027
1028 if header.correlationID != correlationID {
1029 return bytesRead, fmt.Errorf("correlation ID didn't match, wanted %d, got %d", b.correlationID, header.correlationID)
1030 }
1031
1032 buf = make([]byte, header.length-4)
1033
1034 c, err := io.ReadFull(b.conn, buf)
1035
1036 bytesRead += c
1037
1038 if err != nil {
1039 return bytesRead, err
1040 }
1041
1042 res := &SaslAuthenticateResponse{}
1043
1044 if err := versionedDecode(buf, res, 0); err != nil {
1045 return bytesRead, err
1046 }
1047
1048 if err != nil {
1049 return bytesRead, err
1050 }
1051
1052 if res.Err != ErrNoError {
1053 return bytesRead, res.Err
1054 }
1055
1056 if len(res.SaslAuthBytes) > 0 {
1057 Logger.Printf("Received SASL auth response: %s", res.SaslAuthBytes)
1058 }
1059
1060 return bytesRead, nil
1061}
1062
1063func (b *Broker) updateIncomingCommunicationMetrics(bytes int, requestLatency time.Duration) {
1064 b.updateRequestLatencyMetrics(requestLatency)
1065 b.responseRate.Mark(1)
1066 if b.brokerResponseRate != nil {
1067 b.brokerResponseRate.Mark(1)
1068 }
1069 responseSize := int64(bytes)
1070 b.incomingByteRate.Mark(responseSize)
1071 if b.brokerIncomingByteRate != nil {
1072 b.brokerIncomingByteRate.Mark(responseSize)
1073 }
1074 b.responseSize.Update(responseSize)
1075 if b.brokerResponseSize != nil {
1076 b.brokerResponseSize.Update(responseSize)
1077 }
1078}
1079
1080func (b *Broker) updateRequestLatencyMetrics(requestLatency time.Duration) {
1081 requestLatencyInMs := int64(requestLatency / time.Millisecond)
1082 b.requestLatency.Update(requestLatencyInMs)
1083 if b.brokerRequestLatency != nil {
1084 b.brokerRequestLatency.Update(requestLatencyInMs)
1085 }
1086}
1087
1088func (b *Broker) updateOutgoingCommunicationMetrics(bytes int) {
1089 b.requestRate.Mark(1)
1090 if b.brokerRequestRate != nil {
1091 b.brokerRequestRate.Mark(1)
1092 }
1093 requestSize := int64(bytes)
1094 b.outgoingByteRate.Mark(requestSize)
1095 if b.brokerOutgoingByteRate != nil {
1096 b.brokerOutgoingByteRate.Mark(requestSize)
1097 }
1098 b.requestSize.Update(requestSize)
1099 if b.brokerRequestSize != nil {
1100 b.brokerRequestSize.Update(requestSize)
1101 }
1102}