David K. Bainbridge | 528b318 | 2017-01-23 08:51:59 -0800 | [diff] [blame] | 1 | // Copyright 2012 The Gorilla Authors. All rights reserved. |
| 2 | // Use of this source code is governed by a BSD-style |
| 3 | // license that can be found in the LICENSE file. |
| 4 | |
| 5 | package mux |
| 6 | |
import (
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"path"
	"regexp"
	"strings"
)
| 15 | |
| 16 | // NewRouter returns a new router instance. |
| 17 | func NewRouter() *Router { |
| 18 | return &Router{namedRoutes: make(map[string]*Route), KeepContext: false} |
| 19 | } |
| 20 | |
// Router registers routes to be matched and dispatches a handler.
//
// It implements the http.Handler interface, so it can be registered to serve
// requests:
//
//	var router = mux.NewRouter()
//
//	func main() {
//		http.Handle("/", router)
//	}
//
// Or, for Google App Engine, register it in an init() function:
//
//	func init() {
//		http.Handle("/", router)
//	}
//
// This will send all incoming requests to the router.
type Router struct {
	// Configurable Handler to be used when no route matches. When set,
	// Router.Match reports a match and hands back this handler instead.
	NotFoundHandler http.Handler
	// Parent route, if this is a subrouter.
	parent parentRoute
	// Routes to be matched, in order.
	routes []*Route
	// Routes by name for URL building. When nil, getNamedRoutes fills it
	// lazily, reusing the parent's map if there is a parent.
	namedRoutes map[string]*Route
	// See Router.StrictSlash(). This defines the flag for new routes.
	strictSlash bool
	// See Router.SkipClean(). This defines the flag for new routes; when set,
	// ServeHTTP also skips its clean-path-and-redirect step.
	skipClean bool
	// If true, do not clear the request context after handling the request.
	// This has no effect when go1.7+ is used, since the context is stored
	// on the request itself.
	KeepContext bool
	// See Router.UseEncodedPath(). This defines a flag for all routes.
	useEncodedPath bool
}
| 59 | |
| 60 | // Match matches registered routes against the request. |
| 61 | func (r *Router) Match(req *http.Request, match *RouteMatch) bool { |
| 62 | for _, route := range r.routes { |
| 63 | if route.Match(req, match) { |
| 64 | return true |
| 65 | } |
| 66 | } |
| 67 | |
| 68 | // Closest match for a router (includes sub-routers) |
| 69 | if r.NotFoundHandler != nil { |
| 70 | match.Handler = r.NotFoundHandler |
| 71 | return true |
| 72 | } |
| 73 | return false |
| 74 | } |
| 75 | |
| 76 | // ServeHTTP dispatches the handler registered in the matched route. |
| 77 | // |
| 78 | // When there is a match, the route variables can be retrieved calling |
| 79 | // mux.Vars(request). |
| 80 | func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { |
| 81 | if !r.skipClean { |
| 82 | path := req.URL.Path |
| 83 | if r.useEncodedPath { |
| 84 | path = getPath(req) |
| 85 | } |
| 86 | // Clean path to canonical form and redirect. |
| 87 | if p := cleanPath(path); p != path { |
| 88 | |
| 89 | // Added 3 lines (Philip Schlump) - It was dropping the query string and #whatever from query. |
| 90 | // This matches with fix in go 1.2 r.c. 4 for same problem. Go Issue: |
| 91 | // http://code.google.com/p/go/issues/detail?id=5252 |
| 92 | url := *req.URL |
| 93 | url.Path = p |
| 94 | p = url.String() |
| 95 | |
| 96 | w.Header().Set("Location", p) |
| 97 | w.WriteHeader(http.StatusMovedPermanently) |
| 98 | return |
| 99 | } |
| 100 | } |
| 101 | var match RouteMatch |
| 102 | var handler http.Handler |
| 103 | if r.Match(req, &match) { |
| 104 | handler = match.Handler |
| 105 | req = setVars(req, match.Vars) |
| 106 | req = setCurrentRoute(req, match.Route) |
| 107 | } |
| 108 | if handler == nil { |
| 109 | handler = http.NotFoundHandler() |
| 110 | } |
| 111 | if !r.KeepContext { |
| 112 | defer contextClear(req) |
| 113 | } |
| 114 | handler.ServeHTTP(w, req) |
| 115 | } |
| 116 | |
| 117 | // Get returns a route registered with the given name. |
| 118 | func (r *Router) Get(name string) *Route { |
| 119 | return r.getNamedRoutes()[name] |
| 120 | } |
| 121 | |
| 122 | // GetRoute returns a route registered with the given name. This method |
| 123 | // was renamed to Get() and remains here for backwards compatibility. |
| 124 | func (r *Router) GetRoute(name string) *Route { |
| 125 | return r.getNamedRoutes()[name] |
| 126 | } |
| 127 | |
| 128 | // StrictSlash defines the trailing slash behavior for new routes. The initial |
| 129 | // value is false. |
| 130 | // |
| 131 | // When true, if the route path is "/path/", accessing "/path" will redirect |
| 132 | // to the former and vice versa. In other words, your application will always |
| 133 | // see the path as specified in the route. |
| 134 | // |
| 135 | // When false, if the route path is "/path", accessing "/path/" will not match |
| 136 | // this route and vice versa. |
| 137 | // |
| 138 | // Special case: when a route sets a path prefix using the PathPrefix() method, |
| 139 | // strict slash is ignored for that route because the redirect behavior can't |
| 140 | // be determined from a prefix alone. However, any subrouters created from that |
| 141 | // route inherit the original StrictSlash setting. |
| 142 | func (r *Router) StrictSlash(value bool) *Router { |
| 143 | r.strictSlash = value |
| 144 | return r |
| 145 | } |
| 146 | |
| 147 | // SkipClean defines the path cleaning behaviour for new routes. The initial |
| 148 | // value is false. Users should be careful about which routes are not cleaned |
| 149 | // |
| 150 | // When true, if the route path is "/path//to", it will remain with the double |
| 151 | // slash. This is helpful if you have a route like: /fetch/http://xkcd.com/534/ |
| 152 | // |
| 153 | // When false, the path will be cleaned, so /fetch/http://xkcd.com/534/ will |
| 154 | // become /fetch/http/xkcd.com/534 |
| 155 | func (r *Router) SkipClean(value bool) *Router { |
| 156 | r.skipClean = value |
| 157 | return r |
| 158 | } |
| 159 | |
| 160 | // UseEncodedPath tells the router to match the encoded original path |
| 161 | // to the routes. |
| 162 | // For eg. "/path/foo%2Fbar/to" will match the path "/path/{var}/to". |
| 163 | // This behavior has the drawback of needing to match routes against |
| 164 | // r.RequestURI instead of r.URL.Path. Any modifications (such as http.StripPrefix) |
| 165 | // to r.URL.Path will not affect routing when this flag is on and thus may |
| 166 | // induce unintended behavior. |
| 167 | // |
| 168 | // If not called, the router will match the unencoded path to the routes. |
| 169 | // For eg. "/path/foo%2Fbar/to" will match the path "/path/foo/bar/to" |
| 170 | func (r *Router) UseEncodedPath() *Router { |
| 171 | r.useEncodedPath = true |
| 172 | return r |
| 173 | } |
| 174 | |
| 175 | // ---------------------------------------------------------------------------- |
| 176 | // parentRoute |
| 177 | // ---------------------------------------------------------------------------- |
| 178 | |
| 179 | // getNamedRoutes returns the map where named routes are registered. |
| 180 | func (r *Router) getNamedRoutes() map[string]*Route { |
| 181 | if r.namedRoutes == nil { |
| 182 | if r.parent != nil { |
| 183 | r.namedRoutes = r.parent.getNamedRoutes() |
| 184 | } else { |
| 185 | r.namedRoutes = make(map[string]*Route) |
| 186 | } |
| 187 | } |
| 188 | return r.namedRoutes |
| 189 | } |
| 190 | |
| 191 | // getRegexpGroup returns regexp definitions from the parent route, if any. |
| 192 | func (r *Router) getRegexpGroup() *routeRegexpGroup { |
| 193 | if r.parent != nil { |
| 194 | return r.parent.getRegexpGroup() |
| 195 | } |
| 196 | return nil |
| 197 | } |
| 198 | |
| 199 | func (r *Router) buildVars(m map[string]string) map[string]string { |
| 200 | if r.parent != nil { |
| 201 | m = r.parent.buildVars(m) |
| 202 | } |
| 203 | return m |
| 204 | } |
| 205 | |
| 206 | // ---------------------------------------------------------------------------- |
| 207 | // Route factories |
| 208 | // ---------------------------------------------------------------------------- |
| 209 | |
| 210 | // NewRoute registers an empty route. |
| 211 | func (r *Router) NewRoute() *Route { |
| 212 | route := &Route{parent: r, strictSlash: r.strictSlash, skipClean: r.skipClean, useEncodedPath: r.useEncodedPath} |
| 213 | r.routes = append(r.routes, route) |
| 214 | return route |
| 215 | } |
| 216 | |
| 217 | // Handle registers a new route with a matcher for the URL path. |
| 218 | // See Route.Path() and Route.Handler(). |
| 219 | func (r *Router) Handle(path string, handler http.Handler) *Route { |
| 220 | return r.NewRoute().Path(path).Handler(handler) |
| 221 | } |
| 222 | |
| 223 | // HandleFunc registers a new route with a matcher for the URL path. |
| 224 | // See Route.Path() and Route.HandlerFunc(). |
| 225 | func (r *Router) HandleFunc(path string, f func(http.ResponseWriter, |
| 226 | *http.Request)) *Route { |
| 227 | return r.NewRoute().Path(path).HandlerFunc(f) |
| 228 | } |
| 229 | |
| 230 | // Headers registers a new route with a matcher for request header values. |
| 231 | // See Route.Headers(). |
| 232 | func (r *Router) Headers(pairs ...string) *Route { |
| 233 | return r.NewRoute().Headers(pairs...) |
| 234 | } |
| 235 | |
| 236 | // Host registers a new route with a matcher for the URL host. |
| 237 | // See Route.Host(). |
| 238 | func (r *Router) Host(tpl string) *Route { |
| 239 | return r.NewRoute().Host(tpl) |
| 240 | } |
| 241 | |
| 242 | // MatcherFunc registers a new route with a custom matcher function. |
| 243 | // See Route.MatcherFunc(). |
| 244 | func (r *Router) MatcherFunc(f MatcherFunc) *Route { |
| 245 | return r.NewRoute().MatcherFunc(f) |
| 246 | } |
| 247 | |
| 248 | // Methods registers a new route with a matcher for HTTP methods. |
| 249 | // See Route.Methods(). |
| 250 | func (r *Router) Methods(methods ...string) *Route { |
| 251 | return r.NewRoute().Methods(methods...) |
| 252 | } |
| 253 | |
| 254 | // Path registers a new route with a matcher for the URL path. |
| 255 | // See Route.Path(). |
| 256 | func (r *Router) Path(tpl string) *Route { |
| 257 | return r.NewRoute().Path(tpl) |
| 258 | } |
| 259 | |
| 260 | // PathPrefix registers a new route with a matcher for the URL path prefix. |
| 261 | // See Route.PathPrefix(). |
| 262 | func (r *Router) PathPrefix(tpl string) *Route { |
| 263 | return r.NewRoute().PathPrefix(tpl) |
| 264 | } |
| 265 | |
| 266 | // Queries registers a new route with a matcher for URL query values. |
| 267 | // See Route.Queries(). |
| 268 | func (r *Router) Queries(pairs ...string) *Route { |
| 269 | return r.NewRoute().Queries(pairs...) |
| 270 | } |
| 271 | |
| 272 | // Schemes registers a new route with a matcher for URL schemes. |
| 273 | // See Route.Schemes(). |
| 274 | func (r *Router) Schemes(schemes ...string) *Route { |
| 275 | return r.NewRoute().Schemes(schemes...) |
| 276 | } |
| 277 | |
| 278 | // BuildVarsFunc registers a new route with a custom function for modifying |
| 279 | // route variables before building a URL. |
| 280 | func (r *Router) BuildVarsFunc(f BuildVarsFunc) *Route { |
| 281 | return r.NewRoute().BuildVarsFunc(f) |
| 282 | } |
| 283 | |
| 284 | // Walk walks the router and all its sub-routers, calling walkFn for each route |
| 285 | // in the tree. The routes are walked in the order they were added. Sub-routers |
| 286 | // are explored depth-first. |
| 287 | func (r *Router) Walk(walkFn WalkFunc) error { |
| 288 | return r.walk(walkFn, []*Route{}) |
| 289 | } |
| 290 | |
| 291 | // SkipRouter is used as a return value from WalkFuncs to indicate that the |
| 292 | // router that walk is about to descend down to should be skipped. |
| 293 | var SkipRouter = errors.New("skip this router") |
| 294 | |
| 295 | // WalkFunc is the type of the function called for each route visited by Walk. |
| 296 | // At every invocation, it is given the current route, and the current router, |
| 297 | // and a list of ancestor routes that lead to the current route. |
| 298 | type WalkFunc func(route *Route, router *Router, ancestors []*Route) error |
| 299 | |
| 300 | func (r *Router) walk(walkFn WalkFunc, ancestors []*Route) error { |
| 301 | for _, t := range r.routes { |
| 302 | if t.regexp == nil || t.regexp.path == nil || t.regexp.path.template == "" { |
| 303 | continue |
| 304 | } |
| 305 | |
| 306 | err := walkFn(t, r, ancestors) |
| 307 | if err == SkipRouter { |
| 308 | continue |
| 309 | } |
| 310 | if err != nil { |
| 311 | return err |
| 312 | } |
| 313 | for _, sr := range t.matchers { |
| 314 | if h, ok := sr.(*Router); ok { |
| 315 | err := h.walk(walkFn, ancestors) |
| 316 | if err != nil { |
| 317 | return err |
| 318 | } |
| 319 | } |
| 320 | } |
| 321 | if h, ok := t.handler.(*Router); ok { |
| 322 | ancestors = append(ancestors, t) |
| 323 | err := h.walk(walkFn, ancestors) |
| 324 | if err != nil { |
| 325 | return err |
| 326 | } |
| 327 | ancestors = ancestors[:len(ancestors)-1] |
| 328 | } |
| 329 | } |
| 330 | return nil |
| 331 | } |
| 332 | |
| 333 | // ---------------------------------------------------------------------------- |
| 334 | // Context |
| 335 | // ---------------------------------------------------------------------------- |
| 336 | |
// RouteMatch stores information about a matched route. Router.Match passes
// it to each Route.Match in turn. ServeHTTP then serves the request with
// Handler and stores Vars and Route in the request context.
type RouteMatch struct {
	// Route is the matched route; ServeHTTP stores it for CurrentRoute.
	Route *Route
	// Handler serves the request. When Router.Match falls back to a
	// NotFoundHandler it sets only this field.
	Handler http.Handler
	// Vars holds the route variables later returned by Vars.
	Vars map[string]string
}
| 343 | |
// contextKey is the unexported type of the keys this package stores in the
// request context. Because the type is private, other packages cannot build
// equal keys by accident.
type contextKey int

const (
	// varsKey holds the route variables (see setVars and Vars).
	varsKey contextKey = iota
	// routeKey holds the matched route (see setCurrentRoute and CurrentRoute).
	routeKey
)
| 350 | |
| 351 | // Vars returns the route variables for the current request, if any. |
| 352 | func Vars(r *http.Request) map[string]string { |
| 353 | if rv := contextGet(r, varsKey); rv != nil { |
| 354 | return rv.(map[string]string) |
| 355 | } |
| 356 | return nil |
| 357 | } |
| 358 | |
| 359 | // CurrentRoute returns the matched route for the current request, if any. |
| 360 | // This only works when called inside the handler of the matched route |
| 361 | // because the matched route is stored in the request context which is cleared |
| 362 | // after the handler returns, unless the KeepContext option is set on the |
| 363 | // Router. |
| 364 | func CurrentRoute(r *http.Request) *Route { |
| 365 | if rv := contextGet(r, routeKey); rv != nil { |
| 366 | return rv.(*Route) |
| 367 | } |
| 368 | return nil |
| 369 | } |
| 370 | |
| 371 | func setVars(r *http.Request, val interface{}) *http.Request { |
| 372 | return contextSet(r, varsKey, val) |
| 373 | } |
| 374 | |
| 375 | func setCurrentRoute(r *http.Request, val interface{}) *http.Request { |
| 376 | return contextSet(r, routeKey, val) |
| 377 | } |
| 378 | |
| 379 | // ---------------------------------------------------------------------------- |
| 380 | // Helpers |
| 381 | // ---------------------------------------------------------------------------- |
| 382 | |
// getPath returns the escaped (still percent-encoded) request path if
// possible. This mirrors URL.EscapedPath(), which was added in go1.5.
//
// When req.RequestURI is set (server-side requests), the path is extracted
// from it, because unlike URL.Path it has not been decoded. Any scheme and
// host prefix (absolute-form requests) is trimmed, and everything from the
// first '?' or '#' on is dropped:
//
//	http://localhost/path/here?v=1 -> /path/here
//
// Otherwise req.URL.Path is returned.
func getPath(req *http.Request) string {
	if req.RequestURI == "" {
		return req.URL.Path
	}
	path := req.RequestURI
	path = strings.TrimPrefix(path, req.URL.Scheme+`://`)
	path = strings.TrimPrefix(path, req.URL.Host)
	// An encoded path cannot contain a literal '?' or '#', so the first one
	// ends it. Cutting at the last one would keep part of a query string
	// that itself contains '?' (e.g. "/p?a=b?c").
	if i := strings.IndexAny(path, "?#"); i > -1 {
		path = path[:i]
	}
	return path
}
| 404 | |
// cleanPath returns the canonical form of the URL path p:
//   - p is made absolute;
//   - "." and ".." elements and repeated slashes are resolved by path.Clean;
//   - a trailing slash on the input is kept, except that the root stays "/".
//
// An empty p yields "/". Borrowed from the net/http package.
func cleanPath(p string) string {
	if len(p) == 0 {
		return "/"
	}
	if !strings.HasPrefix(p, "/") {
		p = "/" + p
	}
	cleaned := path.Clean(p)
	if strings.HasSuffix(p, "/") && cleaned != "/" {
		return cleaned + "/"
	}
	return cleaned
}
| 423 | |
// uniqueVars returns an error naming the first element of s1 (in order) that
// also occurs in s2. It returns nil when the two slices share no strings.
func uniqueVars(s1, s2 []string) error {
	for i := range s1 {
		for j := range s2 {
			if s1[i] != s2[j] {
				continue
			}
			return fmt.Errorf("mux: duplicated route variable %q", s2[j])
		}
	}
	return nil
}
| 435 | |
// checkPairs returns the number of strings passed in. When that number is
// odd it also returns an error; the count is returned in both cases.
func checkPairs(pairs ...string) (int, error) {
	n := len(pairs)
	if n%2 == 1 {
		return n, fmt.Errorf(
			"mux: number of parameters must be multiple of 2, got %v", pairs)
	}
	return n, nil
}

// mapFromPairsToString turns a flat key, value, key, value, ... argument list
// into a map; a repeated key keeps its last value. An odd number of arguments
// yields checkPairs' error and a nil map.
func mapFromPairsToString(pairs ...string) (map[string]string, error) {
	n, err := checkPairs(pairs...)
	if err != nil {
		return nil, err
	}
	result := make(map[string]string, n/2)
	for i := 1; i < n; i += 2 {
		key, value := pairs[i-1], pairs[i]
		result[key] = value
	}
	return result, nil
}

// mapFromPairsToRegex is like mapFromPairsToString, except that every value
// is compiled with regexp.Compile. The first compile error is returned
// together with a nil map.
func mapFromPairsToRegex(pairs ...string) (map[string]*regexp.Regexp, error) {
	n, err := checkPairs(pairs...)
	if err != nil {
		return nil, err
	}
	result := make(map[string]*regexp.Regexp, n/2)
	for i := 1; i < n; i += 2 {
		re, compileErr := regexp.Compile(pairs[i])
		if compileErr != nil {
			return nil, compileErr
		}
		result[pairs[i-1]] = re
	}
	return result, nil
}
| 478 | |
// matchInArray reports whether value equals one of the elements of arr.
func matchInArray(arr []string, value string) bool {
	for i := 0; i < len(arr); i++ {
		if arr[i] == value {
			return true
		}
	}
	return false
}
| 488 | |
// matchMapWithString reports whether every key of toCheck is present in
// toMatch. A key counts as absent when its entry in toMatch is nil. If
// canonicalKey is true, keys are first normalised with
// http.CanonicalHeaderKey. A non-empty expected value must also equal at
// least one of the values listed for its key; an empty expected value only
// requires the key to be present.
func matchMapWithString(toCheck map[string]string, toMatch map[string][]string, canonicalKey bool) bool {
next:
	for key, want := range toCheck {
		if canonicalKey {
			key = http.CanonicalHeaderKey(key)
		}
		got := toMatch[key]
		if got == nil {
			return false
		}
		if want == "" {
			continue
		}
		for _, candidate := range got {
			if candidate == want {
				continue next
			}
		}
		return false
	}
	return true
}
| 515 | |
// matchMapWithRegex reports whether every key of toCheck is present in
// toMatch. A key counts as absent when its entry in toMatch is nil. If
// canonicalKey is true, keys are first normalised with
// http.CanonicalHeaderKey. A nil regexp only requires the key to be present;
// a non-nil one must match at least one of the key's values (unanchored, via
// MatchString).
func matchMapWithRegex(toCheck map[string]*regexp.Regexp, toMatch map[string][]string, canonicalKey bool) bool {
next:
	for key, pattern := range toCheck {
		if canonicalKey {
			key = http.CanonicalHeaderKey(key)
		}
		got := toMatch[key]
		if got == nil {
			return false
		}
		if pattern == nil {
			continue
		}
		for _, candidate := range got {
			if pattern.MatchString(candidate) {
				continue next
			}
		}
		return false
	}
	return true
}