| // Copyright 2015 The etcd Authors |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| // Package srv looks up DNS SRV records. |
| package srv |
| |
| import ( |
| "fmt" |
| "net" |
| "net/url" |
| "strings" |
| |
| "go.etcd.io/etcd/pkg/types" |
| ) |
| |
// Resolver hooks used by this package. They are package-level variables
// instead of direct calls so tests can substitute fakes.
var (
	// lookupSRV performs the DNS SRV query.
	// NOTE(review): the original author intended to switch this to
	// net.DefaultResolver.LookupSRV once contexts don't conflict — confirm
	// before changing.
	lookupSRV func(service, proto, name string) (string, []*net.SRV, error) = net.LookupSRV

	// resolveTCPAddr resolves a "host:port" string to a TCP address.
	resolveTCPAddr func(network, address string) (*net.TCPAddr, error) = net.ResolveTCPAddr
)
| |
| // GetCluster gets the cluster information via DNS discovery. |
| // Also sees each entry as a separate instance. |
| func GetCluster(serviceScheme, service, name, dns string, apurls types.URLs) ([]string, error) { |
| tempName := int(0) |
| tcp2ap := make(map[string]url.URL) |
| |
| // First, resolve the apurls |
| for _, url := range apurls { |
| tcpAddr, err := resolveTCPAddr("tcp", url.Host) |
| if err != nil { |
| return nil, err |
| } |
| tcp2ap[tcpAddr.String()] = url |
| } |
| |
| stringParts := []string{} |
| updateNodeMap := func(service, scheme string) error { |
| _, addrs, err := lookupSRV(service, "tcp", dns) |
| if err != nil { |
| return err |
| } |
| for _, srv := range addrs { |
| port := fmt.Sprintf("%d", srv.Port) |
| host := net.JoinHostPort(srv.Target, port) |
| tcpAddr, terr := resolveTCPAddr("tcp", host) |
| if terr != nil { |
| err = terr |
| continue |
| } |
| n := "" |
| url, ok := tcp2ap[tcpAddr.String()] |
| if ok { |
| n = name |
| } |
| if n == "" { |
| n = fmt.Sprintf("%d", tempName) |
| tempName++ |
| } |
| // SRV records have a trailing dot but URL shouldn't. |
| shortHost := strings.TrimSuffix(srv.Target, ".") |
| urlHost := net.JoinHostPort(shortHost, port) |
| if ok && url.Scheme != scheme { |
| err = fmt.Errorf("bootstrap at %s from DNS for %s has scheme mismatch with expected peer %s", scheme+"://"+urlHost, service, url.String()) |
| } else { |
| stringParts = append(stringParts, fmt.Sprintf("%s=%s://%s", n, scheme, urlHost)) |
| } |
| } |
| if len(stringParts) == 0 { |
| return err |
| } |
| return nil |
| } |
| |
| err := updateNodeMap(service, serviceScheme) |
| if err != nil { |
| return nil, fmt.Errorf("error querying DNS SRV records for _%s %s", service, err) |
| } |
| return stringParts, nil |
| } |
| |
// SRVClients holds the result of client endpoint discovery via DNS SRV.
type SRVClients struct {
	// Endpoints are the discovered endpoint URLs, formatted as
	// "scheme://target:port".
	Endpoints []string
	// SRVs are the raw SRV records the endpoints were built from.
	SRVs []*net.SRV
}
| |
| // GetClient looks up the client endpoints for a service and domain. |
| func GetClient(service, domain string, serviceName string) (*SRVClients, error) { |
| var urls []*url.URL |
| var srvs []*net.SRV |
| |
| updateURLs := func(service, scheme string) error { |
| _, addrs, err := lookupSRV(service, "tcp", domain) |
| if err != nil { |
| return err |
| } |
| for _, srv := range addrs { |
| urls = append(urls, &url.URL{ |
| Scheme: scheme, |
| Host: net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port)), |
| }) |
| } |
| srvs = append(srvs, addrs...) |
| return nil |
| } |
| |
| errHTTPS := updateURLs(GetSRVService(service, serviceName, "https"), "https") |
| errHTTP := updateURLs(GetSRVService(service, serviceName, "http"), "http") |
| |
| if errHTTPS != nil && errHTTP != nil { |
| return nil, fmt.Errorf("dns lookup errors: %s and %s", errHTTPS, errHTTP) |
| } |
| |
| endpoints := make([]string, len(urls)) |
| for i := range urls { |
| endpoints[i] = urls[i].String() |
| } |
| return &SRVClients{Endpoints: endpoints, SRVs: srvs}, nil |
| } |
| |
// GetSRVService returns the SRV service label to query. It is service, then
// "-ssl" when scheme is exactly "https", then "-"+serviceName when
// serviceName is non-empty.
func GetSRVService(service, serviceName string, scheme string) (SRVService string) {
	SRVService = service
	if scheme == "https" {
		SRVService += "-ssl"
	}
	if serviceName == "" {
		return SRVService
	}
	return fmt.Sprintf("%s-%s", SRVService, serviceName)
}