blob: c3560026d4817c26ff1e11a95dba81f7d17a6b03 [file] [log] [blame]
// Copyright 2015 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package srv looks up DNS SRV records.
16package srv
17
18import (
19 "fmt"
20 "net"
21 "net/url"
22 "strings"
23
24 "go.etcd.io/etcd/pkg/types"
25)
26
// lookupSRV is the SRV lookup function used by this package. It is a
// variable rather than a direct call to net.LookupSRV so that tests can
// substitute their own implementation.
// (net.DefaultResolver.LookupSRV when ctxs don't conflict.)
var lookupSRV = net.LookupSRV

// resolveTCPAddr is the TCP address resolver used by this package. Like
// lookupSRV, it is a variable so that tests can replace it.
var resolveTCPAddr = net.ResolveTCPAddr
32
33// GetCluster gets the cluster information via DNS discovery.
34// Also sees each entry as a separate instance.
35func GetCluster(serviceScheme, service, name, dns string, apurls types.URLs) ([]string, error) {
36 tempName := int(0)
37 tcp2ap := make(map[string]url.URL)
38
39 // First, resolve the apurls
40 for _, url := range apurls {
41 tcpAddr, err := resolveTCPAddr("tcp", url.Host)
42 if err != nil {
43 return nil, err
44 }
45 tcp2ap[tcpAddr.String()] = url
46 }
47
48 stringParts := []string{}
49 updateNodeMap := func(service, scheme string) error {
50 _, addrs, err := lookupSRV(service, "tcp", dns)
51 if err != nil {
52 return err
53 }
54 for _, srv := range addrs {
55 port := fmt.Sprintf("%d", srv.Port)
56 host := net.JoinHostPort(srv.Target, port)
57 tcpAddr, terr := resolveTCPAddr("tcp", host)
58 if terr != nil {
59 err = terr
60 continue
61 }
62 n := ""
63 url, ok := tcp2ap[tcpAddr.String()]
64 if ok {
65 n = name
66 }
67 if n == "" {
68 n = fmt.Sprintf("%d", tempName)
69 tempName++
70 }
71 // SRV records have a trailing dot but URL shouldn't.
72 shortHost := strings.TrimSuffix(srv.Target, ".")
73 urlHost := net.JoinHostPort(shortHost, port)
74 if ok && url.Scheme != scheme {
75 err = fmt.Errorf("bootstrap at %s from DNS for %s has scheme mismatch with expected peer %s", scheme+"://"+urlHost, service, url.String())
76 } else {
77 stringParts = append(stringParts, fmt.Sprintf("%s=%s://%s", n, scheme, urlHost))
78 }
79 }
80 if len(stringParts) == 0 {
81 return err
82 }
83 return nil
84 }
85
86 err := updateNodeMap(service, serviceScheme)
87 if err != nil {
88 return nil, fmt.Errorf("error querying DNS SRV records for _%s %s", service, err)
89 }
90 return stringParts, nil
91}
92
// SRVClients is the result of a client endpoint lookup via DNS SRV records.
type SRVClients struct {
	// Endpoints holds the discovered endpoints as URL strings.
	Endpoints []string
	// SRVs holds the SRV records the lookup returned.
	SRVs []*net.SRV
}
97
98// GetClient looks up the client endpoints for a service and domain.
99func GetClient(service, domain string, serviceName string) (*SRVClients, error) {
100 var urls []*url.URL
101 var srvs []*net.SRV
102
103 updateURLs := func(service, scheme string) error {
104 _, addrs, err := lookupSRV(service, "tcp", domain)
105 if err != nil {
106 return err
107 }
108 for _, srv := range addrs {
109 urls = append(urls, &url.URL{
110 Scheme: scheme,
111 Host: net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port)),
112 })
113 }
114 srvs = append(srvs, addrs...)
115 return nil
116 }
117
118 errHTTPS := updateURLs(GetSRVService(service, serviceName, "https"), "https")
119 errHTTP := updateURLs(GetSRVService(service, serviceName, "http"), "http")
120
121 if errHTTPS != nil && errHTTP != nil {
122 return nil, fmt.Errorf("dns lookup errors: %s and %s", errHTTPS, errHTTP)
123 }
124
125 endpoints := make([]string, len(urls))
126 for i := range urls {
127 endpoints[i] = urls[i].String()
128 }
129 return &SRVClients{Endpoints: endpoints, SRVs: srvs}, nil
130}
131
// GetSRVService generates a SRV service including an optional suffix.
//
// The result is service, followed by "-ssl" when scheme is "https", followed
// by "-" and serviceName when serviceName is non-empty.
func GetSRVService(service, serviceName string, scheme string) (SRVService string) {
	secure := ""
	if scheme == "https" {
		secure = "-ssl"
	}

	if serviceName == "" {
		return service + secure
	}
	return fmt.Sprintf("%s%s-%s", service, secure, serviceName)
}