Don Newton | 379ae25 | 2019-04-01 12:17:06 -0400 | [diff] [blame^] | 1 | // Copyright (C) MongoDB, Inc. 2017-present. |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may |
| 4 | // not use this file except in compliance with the License. You may obtain |
| 5 | // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | |
| 7 | package connstring |
| 8 | |
| 9 | import ( |
| 10 | "errors" |
| 11 | "fmt" |
| 12 | "net" |
| 13 | "net/url" |
| 14 | "runtime" |
| 15 | "strconv" |
| 16 | "strings" |
| 17 | "time" |
| 18 | |
| 19 | "github.com/mongodb/mongo-go-driver/internal" |
| 20 | "github.com/mongodb/mongo-go-driver/mongo/writeconcern" |
| 21 | "github.com/mongodb/mongo-go-driver/x/network/wiremessage" |
| 22 | ) |
| 23 | |
| 24 | // Parse parses the provided uri and returns a URI object. |
| 25 | func Parse(s string) (ConnString, error) { |
| 26 | var p parser |
| 27 | err := p.parse(s) |
| 28 | if err != nil { |
| 29 | err = internal.WrapErrorf(err, "error parsing uri (%s)", s) |
| 30 | } |
| 31 | return p.ConnString, err |
| 32 | } |
| 33 | |
// ConnString represents a connection string to mongodb.
//
// Parse populates it from a "mongodb://" or "mongodb+srv://" URI. Many
// options have a companion ...Set field that records whether the option was
// present at all, so an explicit zero value (for example journal=false) can
// be told apart from an omitted option.
type ConnString struct {
	// Original is the connection string as it was passed to Parse.
	Original string
	// AppName is the value of the appName option.
	AppName string
	// AuthMechanism is the authMechanism value as written; it is compared
	// case-insensitively when auth defaults are applied and validated.
	AuthMechanism string
	// AuthMechanismProperties holds the comma-separated "key:value" pairs of
	// authMechanismProperties. For GSSAPI, SERVICE_NAME defaults to "mongodb".
	AuthMechanismProperties map[string]string
	// AuthSource is the authSource option, or a mechanism-dependent default
	// (the URI database, "admin" or "$external") when it is omitted.
	AuthSource string
	// Compressors is the compressors option split on commas.
	Compressors []string
	// Connect is set from the connect option ("automatic" or "direct").
	Connect    ConnectMode
	ConnectSet bool
	// ConnectTimeout is connectTimeoutMS, converted from milliseconds.
	ConnectTimeout    time.Duration
	ConnectTimeoutSet bool
	// Database is the decoded database name from the URI path.
	Database string
	// HeartbeatInterval is heartbeatIntervalMS (or heartbeatFrequencyMS),
	// converted from milliseconds.
	HeartbeatInterval    time.Duration
	HeartbeatIntervalSet bool
	// Hosts lists the "host[:port]" entries from the URI. For mongodb+srv
	// URIs it holds the "target:port" pairs from the SRV lookup instead.
	Hosts []string
	// J is the journal option.
	J    bool
	JSet bool
	// LocalThreshold is localThresholdMS, converted from milliseconds.
	LocalThreshold    time.Duration
	LocalThresholdSet bool
	// MaxConnIdleTime is maxIdleTimeMS, converted from milliseconds.
	MaxConnIdleTime    time.Duration
	MaxConnIdleTimeSet bool
	// MaxPoolSize is the maxPoolSize option.
	MaxPoolSize    uint16
	MaxPoolSizeSet bool
	// Password is the decoded password from the user info. PasswordSet
	// reports whether the user info contained a ':' separator.
	Password    string
	PasswordSet bool
	// ReadConcernLevel and ReadPreference are stored as written.
	ReadConcernLevel string
	ReadPreference   string
	// ReadPreferenceTagSets has one entry per readPreferenceTags occurrence,
	// in the order they appear.
	ReadPreferenceTagSets []map[string]string
	// RetryWrites is true only when retryWrites is exactly "true";
	// RetryWritesSet reports whether the option appeared at all.
	RetryWrites    bool
	RetryWritesSet bool
	// MaxStaleness is the maxStaleness option, interpreted as seconds.
	MaxStaleness    time.Duration
	MaxStalenessSet bool
	// ReplicaSet is the replicaSet option.
	ReplicaSet string
	// ServerSelectionTimeout is serverSelectionTimeoutMS, converted from
	// milliseconds.
	ServerSelectionTimeout    time.Duration
	ServerSelectionTimeoutSet bool
	// SocketTimeout is socketTimeoutMS, converted from milliseconds.
	SocketTimeout    time.Duration
	SocketTimeoutSet bool
	// SSL is the ssl option. It is also switched on for mongodb+srv URIs and
	// whenever sslClientCertificateKeyFile or sslCertificateAuthorityFile is
	// given.
	SSL    bool
	SSLSet bool
	// SSLClientCertificateKeyFile is the sslClientCertificateKeyFile option.
	SSLClientCertificateKeyFile    string
	SSLClientCertificateKeyFileSet bool
	// SSLClientCertificateKeyPassword returns the
	// sslClientCertificateKeyPassword option value.
	SSLClientCertificateKeyPassword    func() string
	SSLClientCertificateKeyPasswordSet bool
	// SSLInsecure is the sslInsecure option.
	SSLInsecure    bool
	SSLInsecureSet bool
	// SSLCaFile is the sslCertificateAuthorityFile option.
	SSLCaFile    string
	SSLCaFileSet bool
	// The w option: an integer value is stored in WNumber (with WNumberSet
	// true), any other value in WString.
	WString    string
	WNumber    int
	WNumberSet bool
	// Username is the decoded user name from the user info.
	Username string
	// ZlibLevel is zlibCompressionLevel; -1 is replaced by
	// wiremessage.DefaultZlibLevel.
	ZlibLevel int

	// WTimeout is wtimeoutMS (or wtimeout, when wtimeoutMS has not set it),
	// converted from milliseconds. WTimeoutSet is set by wtimeoutMS.
	WTimeout    time.Duration
	WTimeoutSet bool
	// WTimeoutSetFromOption is never assigned in this file; parse sets
	// WTimeoutSet when it is true. NOTE(review): presumably set by code that
	// builds options manually — confirm.
	WTimeoutSetFromOption bool

	// Options records every option the parser accepted, keyed by the
	// lowercased option name, with decoded values in order of appearance.
	Options map[string][]string
	// UnknownOptions records, in the same form, the options this parser does
	// not recognize.
	UnknownOptions map[string][]string
}
| 95 | |
| 96 | func (u *ConnString) String() string { |
| 97 | return u.Original |
| 98 | } |
| 99 | |
// ConnectMode informs the driver on how to connect to the server. It is set
// from the "connect" URI option.
type ConnectMode uint8

// ConnectMode constants.
const (
	// AutoConnect is the zero value and corresponds to connect=automatic.
	AutoConnect ConnectMode = 0
	// SingleConnect corresponds to connect=direct.
	SingleConnect ConnectMode = 1
)
| 109 | |
// parser accumulates the result of parsing a connection string. It embeds
// ConnString so the parsing methods can assign its fields directly; Parse
// returns the embedded value once parsing finishes.
type parser struct {
	ConnString
}
| 113 | |
| 114 | func (p *parser) parse(original string) error { |
| 115 | p.Original = original |
| 116 | uri := original |
| 117 | |
| 118 | var err error |
| 119 | var isSRV bool |
| 120 | if strings.HasPrefix(uri, "mongodb+srv://") { |
| 121 | isSRV = true |
| 122 | // remove the scheme |
| 123 | uri = uri[14:] |
| 124 | } else if strings.HasPrefix(uri, "mongodb://") { |
| 125 | // remove the scheme |
| 126 | uri = uri[10:] |
| 127 | } else { |
| 128 | return fmt.Errorf("scheme must be \"mongodb\" or \"mongodb+srv\"") |
| 129 | } |
| 130 | |
| 131 | if idx := strings.Index(uri, "@"); idx != -1 { |
| 132 | userInfo := uri[:idx] |
| 133 | uri = uri[idx+1:] |
| 134 | |
| 135 | username := userInfo |
| 136 | var password string |
| 137 | |
| 138 | if idx := strings.Index(userInfo, ":"); idx != -1 { |
| 139 | username = userInfo[:idx] |
| 140 | password = userInfo[idx+1:] |
| 141 | p.PasswordSet = true |
| 142 | } |
| 143 | |
| 144 | if len(username) > 1 { |
| 145 | if strings.Contains(username, "/") { |
| 146 | return fmt.Errorf("unescaped slash in username") |
| 147 | } |
| 148 | } |
| 149 | |
| 150 | p.Username, err = url.QueryUnescape(username) |
| 151 | if err != nil { |
| 152 | return internal.WrapErrorf(err, "invalid username") |
| 153 | } |
| 154 | if len(password) > 1 { |
| 155 | if strings.Contains(password, ":") { |
| 156 | return fmt.Errorf("unescaped colon in password") |
| 157 | } |
| 158 | if strings.Contains(password, "/") { |
| 159 | return fmt.Errorf("unescaped slash in password") |
| 160 | } |
| 161 | p.Password, err = url.QueryUnescape(password) |
| 162 | if err != nil { |
| 163 | return internal.WrapErrorf(err, "invalid password") |
| 164 | } |
| 165 | } |
| 166 | } |
| 167 | |
| 168 | // fetch the hosts field |
| 169 | hosts := uri |
| 170 | if idx := strings.IndexAny(uri, "/?@"); idx != -1 { |
| 171 | if uri[idx] == '@' { |
| 172 | return fmt.Errorf("unescaped @ sign in user info") |
| 173 | } |
| 174 | if uri[idx] == '?' { |
| 175 | return fmt.Errorf("must have a / before the query ?") |
| 176 | } |
| 177 | hosts = uri[:idx] |
| 178 | } |
| 179 | |
| 180 | var connectionArgsFromTXT []string |
| 181 | parsedHosts := strings.Split(hosts, ",") |
| 182 | |
| 183 | if isSRV { |
| 184 | parsedHosts = strings.Split(hosts, ",") |
| 185 | if len(parsedHosts) != 1 { |
| 186 | return fmt.Errorf("URI with SRV must include one and only one hostname") |
| 187 | } |
| 188 | parsedHosts, err = fetchSeedlistFromSRV(parsedHosts[0]) |
| 189 | if err != nil { |
| 190 | return err |
| 191 | } |
| 192 | |
| 193 | // error ignored because finding a TXT record should not be |
| 194 | // considered an error. |
| 195 | recordsFromTXT, _ := net.LookupTXT(hosts) |
| 196 | |
| 197 | // This is a temporary fix to get around bug https://github.com/golang/go/issues/21472. |
| 198 | // It will currently incorrectly concatenate multiple TXT records to one |
| 199 | // on windows. |
| 200 | if runtime.GOOS == "windows" { |
| 201 | recordsFromTXT = []string{strings.Join(recordsFromTXT, "")} |
| 202 | } |
| 203 | |
| 204 | if len(recordsFromTXT) > 1 { |
| 205 | return errors.New("multiple records from TXT not supported") |
| 206 | } |
| 207 | if len(recordsFromTXT) > 0 { |
| 208 | connectionArgsFromTXT = strings.FieldsFunc(recordsFromTXT[0], func(r rune) bool { return r == ';' || r == '&' }) |
| 209 | |
| 210 | err := validateTXTResult(connectionArgsFromTXT) |
| 211 | if err != nil { |
| 212 | return err |
| 213 | } |
| 214 | |
| 215 | } |
| 216 | |
| 217 | // SSL is enabled by default for SRV, but can be manually disabled with "ssl=false". |
| 218 | p.SSL = true |
| 219 | p.SSLSet = true |
| 220 | } |
| 221 | |
| 222 | for _, host := range parsedHosts { |
| 223 | err = p.addHost(host) |
| 224 | if err != nil { |
| 225 | return internal.WrapErrorf(err, "invalid host \"%s\"", host) |
| 226 | } |
| 227 | } |
| 228 | if len(p.Hosts) == 0 { |
| 229 | return fmt.Errorf("must have at least 1 host") |
| 230 | } |
| 231 | |
| 232 | uri = uri[len(hosts):] |
| 233 | |
| 234 | extractedDatabase, err := extractDatabaseFromURI(uri) |
| 235 | if err != nil { |
| 236 | return err |
| 237 | } |
| 238 | |
| 239 | uri = extractedDatabase.uri |
| 240 | p.Database = extractedDatabase.db |
| 241 | |
| 242 | connectionArgsFromQueryString, err := extractQueryArgsFromURI(uri) |
| 243 | connectionArgPairs := append(connectionArgsFromTXT, connectionArgsFromQueryString...) |
| 244 | |
| 245 | for _, pair := range connectionArgPairs { |
| 246 | err = p.addOption(pair) |
| 247 | if err != nil { |
| 248 | return err |
| 249 | } |
| 250 | } |
| 251 | |
| 252 | err = p.setDefaultAuthParams(extractedDatabase.db) |
| 253 | if err != nil { |
| 254 | return err |
| 255 | } |
| 256 | |
| 257 | err = p.validateAuth() |
| 258 | if err != nil { |
| 259 | return err |
| 260 | } |
| 261 | |
| 262 | // Check for invalid write concern (i.e. w=0 and j=true) |
| 263 | if p.WNumberSet && p.WNumber == 0 && p.JSet && p.J { |
| 264 | return writeconcern.ErrInconsistent |
| 265 | } |
| 266 | |
| 267 | // If WTimeout was set from manual options passed in, set WTImeoutSet to true. |
| 268 | if p.WTimeoutSetFromOption { |
| 269 | p.WTimeoutSet = true |
| 270 | } |
| 271 | |
| 272 | return nil |
| 273 | } |
| 274 | |
| 275 | func (p *parser) setDefaultAuthParams(dbName string) error { |
| 276 | switch strings.ToLower(p.AuthMechanism) { |
| 277 | case "plain": |
| 278 | if p.AuthSource == "" { |
| 279 | p.AuthSource = dbName |
| 280 | if p.AuthSource == "" { |
| 281 | p.AuthSource = "$external" |
| 282 | } |
| 283 | } |
| 284 | case "gssapi": |
| 285 | if p.AuthMechanismProperties == nil { |
| 286 | p.AuthMechanismProperties = map[string]string{ |
| 287 | "SERVICE_NAME": "mongodb", |
| 288 | } |
| 289 | } else if v, ok := p.AuthMechanismProperties["SERVICE_NAME"]; !ok || v == "" { |
| 290 | p.AuthMechanismProperties["SERVICE_NAME"] = "mongodb" |
| 291 | } |
| 292 | fallthrough |
| 293 | case "mongodb-x509": |
| 294 | if p.AuthSource == "" { |
| 295 | p.AuthSource = "$external" |
| 296 | } else if p.AuthSource != "$external" { |
| 297 | return fmt.Errorf("auth source must be $external") |
| 298 | } |
| 299 | case "mongodb-cr": |
| 300 | fallthrough |
| 301 | case "scram-sha-1": |
| 302 | fallthrough |
| 303 | case "scram-sha-256": |
| 304 | if p.AuthSource == "" { |
| 305 | p.AuthSource = dbName |
| 306 | if p.AuthSource == "" { |
| 307 | p.AuthSource = "admin" |
| 308 | } |
| 309 | } |
| 310 | case "": |
| 311 | if p.AuthSource == "" { |
| 312 | p.AuthSource = dbName |
| 313 | if p.AuthSource == "" { |
| 314 | p.AuthSource = "admin" |
| 315 | } |
| 316 | } |
| 317 | default: |
| 318 | return fmt.Errorf("invalid auth mechanism") |
| 319 | } |
| 320 | return nil |
| 321 | } |
| 322 | |
| 323 | func (p *parser) validateAuth() error { |
| 324 | switch strings.ToLower(p.AuthMechanism) { |
| 325 | case "mongodb-cr": |
| 326 | if p.Username == "" { |
| 327 | return fmt.Errorf("username required for MONGO-CR") |
| 328 | } |
| 329 | if p.Password == "" { |
| 330 | return fmt.Errorf("password required for MONGO-CR") |
| 331 | } |
| 332 | if p.AuthMechanismProperties != nil { |
| 333 | return fmt.Errorf("MONGO-CR cannot have mechanism properties") |
| 334 | } |
| 335 | case "mongodb-x509": |
| 336 | if p.Password != "" { |
| 337 | return fmt.Errorf("password cannot be specified for MONGO-X509") |
| 338 | } |
| 339 | if p.AuthMechanismProperties != nil { |
| 340 | return fmt.Errorf("MONGO-X509 cannot have mechanism properties") |
| 341 | } |
| 342 | case "gssapi": |
| 343 | if p.Username == "" { |
| 344 | return fmt.Errorf("username required for GSSAPI") |
| 345 | } |
| 346 | for k := range p.AuthMechanismProperties { |
| 347 | if k != "SERVICE_NAME" && k != "CANONICALIZE_HOST_NAME" && k != "SERVICE_REALM" { |
| 348 | return fmt.Errorf("invalid auth property for GSSAPI") |
| 349 | } |
| 350 | } |
| 351 | case "plain": |
| 352 | if p.Username == "" { |
| 353 | return fmt.Errorf("username required for PLAIN") |
| 354 | } |
| 355 | if p.Password == "" { |
| 356 | return fmt.Errorf("password required for PLAIN") |
| 357 | } |
| 358 | if p.AuthMechanismProperties != nil { |
| 359 | return fmt.Errorf("PLAIN cannot have mechanism properties") |
| 360 | } |
| 361 | case "scram-sha-1": |
| 362 | if p.Username == "" { |
| 363 | return fmt.Errorf("username required for SCRAM-SHA-1") |
| 364 | } |
| 365 | if p.Password == "" { |
| 366 | return fmt.Errorf("password required for SCRAM-SHA-1") |
| 367 | } |
| 368 | if p.AuthMechanismProperties != nil { |
| 369 | return fmt.Errorf("SCRAM-SHA-1 cannot have mechanism properties") |
| 370 | } |
| 371 | case "scram-sha-256": |
| 372 | if p.Username == "" { |
| 373 | return fmt.Errorf("username required for SCRAM-SHA-256") |
| 374 | } |
| 375 | if p.Password == "" { |
| 376 | return fmt.Errorf("password required for SCRAM-SHA-256") |
| 377 | } |
| 378 | if p.AuthMechanismProperties != nil { |
| 379 | return fmt.Errorf("SCRAM-SHA-256 cannot have mechanism properties") |
| 380 | } |
| 381 | case "": |
| 382 | default: |
| 383 | return fmt.Errorf("invalid auth mechanism") |
| 384 | } |
| 385 | return nil |
| 386 | } |
| 387 | |
| 388 | func fetchSeedlistFromSRV(host string) ([]string, error) { |
| 389 | var err error |
| 390 | |
| 391 | _, _, err = net.SplitHostPort(host) |
| 392 | |
| 393 | if err == nil { |
| 394 | // we were able to successfully extract a port from the host, |
| 395 | // but should not be able to when using SRV |
| 396 | return nil, fmt.Errorf("URI with srv must not include a port number") |
| 397 | } |
| 398 | |
| 399 | _, addresses, err := net.LookupSRV("mongodb", "tcp", host) |
| 400 | if err != nil { |
| 401 | return nil, err |
| 402 | } |
| 403 | parsedHosts := make([]string, len(addresses)) |
| 404 | for i, address := range addresses { |
| 405 | trimmedAddressTarget := strings.TrimSuffix(address.Target, ".") |
| 406 | err := validateSRVResult(trimmedAddressTarget, host) |
| 407 | if err != nil { |
| 408 | return nil, err |
| 409 | } |
| 410 | parsedHosts[i] = fmt.Sprintf("%s:%d", trimmedAddressTarget, address.Port) |
| 411 | } |
| 412 | |
| 413 | return parsedHosts, nil |
| 414 | } |
| 415 | |
| 416 | func (p *parser) addHost(host string) error { |
| 417 | if host == "" { |
| 418 | return nil |
| 419 | } |
| 420 | host, err := url.QueryUnescape(host) |
| 421 | if err != nil { |
| 422 | return internal.WrapErrorf(err, "invalid host \"%s\"", host) |
| 423 | } |
| 424 | |
| 425 | _, port, err := net.SplitHostPort(host) |
| 426 | // this is unfortunate that SplitHostPort actually requires |
| 427 | // a port to exist. |
| 428 | if err != nil { |
| 429 | if addrError, ok := err.(*net.AddrError); !ok || addrError.Err != "missing port in address" { |
| 430 | return err |
| 431 | } |
| 432 | } |
| 433 | |
| 434 | if port != "" { |
| 435 | d, err := strconv.Atoi(port) |
| 436 | if err != nil { |
| 437 | return internal.WrapErrorf(err, "port must be an integer") |
| 438 | } |
| 439 | if d <= 0 || d >= 65536 { |
| 440 | return fmt.Errorf("port must be in the range [1, 65535]") |
| 441 | } |
| 442 | } |
| 443 | p.Hosts = append(p.Hosts, host) |
| 444 | return nil |
| 445 | } |
| 446 | |
| 447 | func (p *parser) addOption(pair string) error { |
| 448 | kv := strings.SplitN(pair, "=", 2) |
| 449 | if len(kv) != 2 || kv[0] == "" { |
| 450 | return fmt.Errorf("invalid option") |
| 451 | } |
| 452 | |
| 453 | key, err := url.QueryUnescape(kv[0]) |
| 454 | if err != nil { |
| 455 | return internal.WrapErrorf(err, "invalid option key \"%s\"", kv[0]) |
| 456 | } |
| 457 | |
| 458 | value, err := url.QueryUnescape(kv[1]) |
| 459 | if err != nil { |
| 460 | return internal.WrapErrorf(err, "invalid option value \"%s\"", kv[1]) |
| 461 | } |
| 462 | |
| 463 | lowerKey := strings.ToLower(key) |
| 464 | switch lowerKey { |
| 465 | case "appname": |
| 466 | p.AppName = value |
| 467 | case "authmechanism": |
| 468 | p.AuthMechanism = value |
| 469 | case "authmechanismproperties": |
| 470 | p.AuthMechanismProperties = make(map[string]string) |
| 471 | pairs := strings.Split(value, ",") |
| 472 | for _, pair := range pairs { |
| 473 | kv := strings.SplitN(pair, ":", 2) |
| 474 | if len(kv) != 2 || kv[0] == "" { |
| 475 | return fmt.Errorf("invalid authMechanism property") |
| 476 | } |
| 477 | p.AuthMechanismProperties[kv[0]] = kv[1] |
| 478 | } |
| 479 | case "authsource": |
| 480 | p.AuthSource = value |
| 481 | case "compressors": |
| 482 | compressors := strings.Split(value, ",") |
| 483 | if len(compressors) < 1 { |
| 484 | return fmt.Errorf("must have at least 1 compressor") |
| 485 | } |
| 486 | p.Compressors = compressors |
| 487 | case "connect": |
| 488 | switch strings.ToLower(value) { |
| 489 | case "automatic": |
| 490 | case "direct": |
| 491 | p.Connect = SingleConnect |
| 492 | default: |
| 493 | return fmt.Errorf("invalid 'connect' value: %s", value) |
| 494 | } |
| 495 | |
| 496 | p.ConnectSet = true |
| 497 | case "connecttimeoutms": |
| 498 | n, err := strconv.Atoi(value) |
| 499 | if err != nil || n < 0 { |
| 500 | return fmt.Errorf("invalid value for %s: %s", key, value) |
| 501 | } |
| 502 | p.ConnectTimeout = time.Duration(n) * time.Millisecond |
| 503 | p.ConnectTimeoutSet = true |
| 504 | case "heartbeatintervalms", "heartbeatfrequencyms": |
| 505 | n, err := strconv.Atoi(value) |
| 506 | if err != nil || n < 0 { |
| 507 | return fmt.Errorf("invalid value for %s: %s", key, value) |
| 508 | } |
| 509 | p.HeartbeatInterval = time.Duration(n) * time.Millisecond |
| 510 | p.HeartbeatIntervalSet = true |
| 511 | case "journal": |
| 512 | switch value { |
| 513 | case "true": |
| 514 | p.J = true |
| 515 | case "false": |
| 516 | p.J = false |
| 517 | default: |
| 518 | return fmt.Errorf("invalid value for %s: %s", key, value) |
| 519 | } |
| 520 | |
| 521 | p.JSet = true |
| 522 | case "localthresholdms": |
| 523 | n, err := strconv.Atoi(value) |
| 524 | if err != nil || n < 0 { |
| 525 | return fmt.Errorf("invalid value for %s: %s", key, value) |
| 526 | } |
| 527 | p.LocalThreshold = time.Duration(n) * time.Millisecond |
| 528 | p.LocalThresholdSet = true |
| 529 | case "maxidletimems": |
| 530 | n, err := strconv.Atoi(value) |
| 531 | if err != nil || n < 0 { |
| 532 | return fmt.Errorf("invalid value for %s: %s", key, value) |
| 533 | } |
| 534 | p.MaxConnIdleTime = time.Duration(n) * time.Millisecond |
| 535 | p.MaxConnIdleTimeSet = true |
| 536 | case "maxpoolsize": |
| 537 | n, err := strconv.Atoi(value) |
| 538 | if err != nil || n < 0 { |
| 539 | return fmt.Errorf("invalid value for %s: %s", key, value) |
| 540 | } |
| 541 | p.MaxPoolSize = uint16(n) |
| 542 | p.MaxPoolSizeSet = true |
| 543 | case "readconcernlevel": |
| 544 | p.ReadConcernLevel = value |
| 545 | case "readpreference": |
| 546 | p.ReadPreference = value |
| 547 | case "readpreferencetags": |
| 548 | tags := make(map[string]string) |
| 549 | items := strings.Split(value, ",") |
| 550 | for _, item := range items { |
| 551 | parts := strings.Split(item, ":") |
| 552 | if len(parts) != 2 { |
| 553 | return fmt.Errorf("invalid value for %s: %s", key, value) |
| 554 | } |
| 555 | tags[parts[0]] = parts[1] |
| 556 | } |
| 557 | p.ReadPreferenceTagSets = append(p.ReadPreferenceTagSets, tags) |
| 558 | case "maxstaleness": |
| 559 | n, err := strconv.Atoi(value) |
| 560 | if err != nil || n < 0 { |
| 561 | return fmt.Errorf("invalid value for %s: %s", key, value) |
| 562 | } |
| 563 | p.MaxStaleness = time.Duration(n) * time.Second |
| 564 | p.MaxStalenessSet = true |
| 565 | case "replicaset": |
| 566 | p.ReplicaSet = value |
| 567 | case "retrywrites": |
| 568 | p.RetryWrites = value == "true" |
| 569 | p.RetryWritesSet = true |
| 570 | case "serverselectiontimeoutms": |
| 571 | n, err := strconv.Atoi(value) |
| 572 | if err != nil || n < 0 { |
| 573 | return fmt.Errorf("invalid value for %s: %s", key, value) |
| 574 | } |
| 575 | p.ServerSelectionTimeout = time.Duration(n) * time.Millisecond |
| 576 | p.ServerSelectionTimeoutSet = true |
| 577 | case "sockettimeoutms": |
| 578 | n, err := strconv.Atoi(value) |
| 579 | if err != nil || n < 0 { |
| 580 | return fmt.Errorf("invalid value for %s: %s", key, value) |
| 581 | } |
| 582 | p.SocketTimeout = time.Duration(n) * time.Millisecond |
| 583 | p.SocketTimeoutSet = true |
| 584 | case "ssl": |
| 585 | switch value { |
| 586 | case "true": |
| 587 | p.SSL = true |
| 588 | case "false": |
| 589 | p.SSL = false |
| 590 | default: |
| 591 | return fmt.Errorf("invalid value for %s: %s", key, value) |
| 592 | } |
| 593 | |
| 594 | p.SSLSet = true |
| 595 | case "sslclientcertificatekeyfile": |
| 596 | p.SSL = true |
| 597 | p.SSLSet = true |
| 598 | p.SSLClientCertificateKeyFile = value |
| 599 | p.SSLClientCertificateKeyFileSet = true |
| 600 | case "sslclientcertificatekeypassword": |
| 601 | p.SSLClientCertificateKeyPassword = func() string { return value } |
| 602 | p.SSLClientCertificateKeyPasswordSet = true |
| 603 | case "sslinsecure": |
| 604 | switch value { |
| 605 | case "true": |
| 606 | p.SSLInsecure = true |
| 607 | case "false": |
| 608 | p.SSLInsecure = false |
| 609 | default: |
| 610 | return fmt.Errorf("invalid value for %s: %s", key, value) |
| 611 | } |
| 612 | |
| 613 | p.SSLInsecureSet = true |
| 614 | case "sslcertificateauthorityfile": |
| 615 | p.SSL = true |
| 616 | p.SSLSet = true |
| 617 | p.SSLCaFile = value |
| 618 | p.SSLCaFileSet = true |
| 619 | case "w": |
| 620 | if w, err := strconv.Atoi(value); err == nil { |
| 621 | if w < 0 { |
| 622 | return fmt.Errorf("invalid value for %s: %s", key, value) |
| 623 | } |
| 624 | |
| 625 | p.WNumber = w |
| 626 | p.WNumberSet = true |
| 627 | p.WString = "" |
| 628 | break |
| 629 | } |
| 630 | |
| 631 | p.WString = value |
| 632 | p.WNumberSet = false |
| 633 | |
| 634 | case "wtimeoutms": |
| 635 | n, err := strconv.Atoi(value) |
| 636 | if err != nil || n < 0 { |
| 637 | return fmt.Errorf("invalid value for %s: %s", key, value) |
| 638 | } |
| 639 | p.WTimeout = time.Duration(n) * time.Millisecond |
| 640 | p.WTimeoutSet = true |
| 641 | case "wtimeout": |
| 642 | // Defer to wtimeoutms, but not to a manually-set option. |
| 643 | if p.WTimeoutSet { |
| 644 | break |
| 645 | } |
| 646 | n, err := strconv.Atoi(value) |
| 647 | if err != nil || n < 0 { |
| 648 | return fmt.Errorf("invalid value for %s: %s", key, value) |
| 649 | } |
| 650 | p.WTimeout = time.Duration(n) * time.Millisecond |
| 651 | case "zlibcompressionlevel": |
| 652 | level, err := strconv.Atoi(value) |
| 653 | if err != nil || (level < -1 || level > 9) { |
| 654 | return fmt.Errorf("invalid value for %s: %s", key, value) |
| 655 | } |
| 656 | |
| 657 | if level == -1 { |
| 658 | level = wiremessage.DefaultZlibLevel |
| 659 | } |
| 660 | p.ZlibLevel = level |
| 661 | default: |
| 662 | if p.UnknownOptions == nil { |
| 663 | p.UnknownOptions = make(map[string][]string) |
| 664 | } |
| 665 | p.UnknownOptions[lowerKey] = append(p.UnknownOptions[lowerKey], value) |
| 666 | } |
| 667 | |
| 668 | if p.Options == nil { |
| 669 | p.Options = make(map[string][]string) |
| 670 | } |
| 671 | p.Options[lowerKey] = append(p.Options[lowerKey], value) |
| 672 | |
| 673 | return nil |
| 674 | } |
| 675 | |
// validateSRVResult checks that recordFromSRV, a target returned by the SRV
// lookup for inputHostName, has at least two DNS labels and ends with every
// label of inputHostName except the first. In other words, the target must
// sit in the input host's parent domain. Labels are compared exactly.
func validateSRVResult(recordFromSRV, inputHostName string) error {
	inputLabels := strings.Split(inputHostName, ".")
	recordLabels := strings.Split(recordFromSRV, ".")

	switch {
	case len(recordLabels) < 2:
		return errors.New("DNS name must contain at least 2 labels")
	case len(recordLabels) < len(inputLabels):
		return errors.New("Domain suffix from SRV record not matched input domain")
	}

	parentDomain := inputLabels[1:]
	recordTail := recordLabels[len(recordLabels)-len(parentDomain):]
	for i := range parentDomain {
		if recordTail[i] != parentDomain[i] {
			return errors.New("Domain suffix from SRV record not matched input domain")
		}
	}
	return nil
}
| 697 | |
// allowedTXTOptions lists the (lowercased) option keys a TXT record may set.
var allowedTXTOptions = map[string]struct{}{
	"authsource": {},
	"replicaset": {},
}

// validateTXTResult verifies that every entry of paramsFromTXT is a
// "key=value" pair whose key, compared case-insensitively, is listed in
// allowedTXTOptions. The error message quotes the key as it was written.
func validateTXTResult(paramsFromTXT []string) error {
	for _, param := range paramsFromTXT {
		eq := strings.Index(param, "=")
		if eq < 0 {
			return errors.New("Invalid TXT record")
		}
		rawKey := param[:eq]
		if _, allowed := allowedTXTOptions[strings.ToLower(rawKey)]; !allowed {
			return fmt.Errorf("Cannot specify option '%s' in TXT record", rawKey)
		}
	}
	return nil
}
| 716 | |
// extractQueryArgsFromURI splits the query part of the URI into its
// "key=value" pairs.
//
// uri is whatever follows the database path. It must be empty or start with
// '?'. Pairs may be separated by '&' or ';', and empty pairs are dropped. An
// empty uri, or a bare "?", yields no pairs.
func extractQueryArgsFromURI(uri string) ([]string, error) {
	switch {
	case uri == "":
		return nil, nil
	case uri[0] != '?':
		return nil, errors.New("must have a ? separator between path and query")
	}

	query := uri[1:]
	if query == "" {
		return nil, nil
	}
	isSeparator := func(r rune) bool { return r == ';' || r == '&' }
	return strings.FieldsFunc(query, isSeparator), nil
}
| 733 | |
| 734 | type extractedDatabase struct { |
| 735 | uri string |
| 736 | db string |
| 737 | } |
| 738 | |
| 739 | // extractDatabaseFromURI is a helper function to retrieve information about |
| 740 | // the database from the passed in URI. It accepts as an argument the currently |
| 741 | // parsed URI and returns the remainder of the uri, the database it found, |
| 742 | // and any error it encounters while parsing. |
| 743 | func extractDatabaseFromURI(uri string) (extractedDatabase, error) { |
| 744 | if len(uri) == 0 { |
| 745 | return extractedDatabase{}, nil |
| 746 | } |
| 747 | |
| 748 | if uri[0] != '/' { |
| 749 | return extractedDatabase{}, errors.New("must have a / separator between hosts and path") |
| 750 | } |
| 751 | |
| 752 | uri = uri[1:] |
| 753 | if len(uri) == 0 { |
| 754 | return extractedDatabase{}, nil |
| 755 | } |
| 756 | |
| 757 | database := uri |
| 758 | if idx := strings.IndexRune(uri, '?'); idx != -1 { |
| 759 | database = uri[:idx] |
| 760 | } |
| 761 | |
| 762 | escapedDatabase, err := url.QueryUnescape(database) |
| 763 | if err != nil { |
| 764 | return extractedDatabase{}, internal.WrapErrorf(err, "invalid database \"%s\"", database) |
| 765 | } |
| 766 | |
| 767 | uri = uri[len(database):] |
| 768 | |
| 769 | return extractedDatabase{ |
| 770 | uri: uri, |
| 771 | db: escapedDatabase, |
| 772 | }, nil |
| 773 | } |