blob: 0c64f16a3905e6a57c13bb3a64ec669179c1e2d8 [file] [log] [blame]
// Copyright 2015 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
// Package cors handles cross-origin HTTP requests (CORS).
16package cors
17
18import (
19 "fmt"
20 "net/http"
21 "net/url"
22 "sort"
23 "strings"
24)
25
// CORSInfo is the set of origins permitted to make cross-origin requests.
// An origin is allowed when its key maps to true; the special key "*"
// allows every origin (see OriginAllowed). It can be populated from a
// comma-separated command-line flag via Set.
type CORSInfo map[string]bool
27
28// Set implements the flag.Value interface to allow users to define a list of CORS origins
29func (ci *CORSInfo) Set(s string) error {
30 m := make(map[string]bool)
31 for _, v := range strings.Split(s, ",") {
32 v = strings.TrimSpace(v)
33 if v == "" {
34 continue
35 }
36 if v != "*" {
37 if _, err := url.Parse(v); err != nil {
38 return fmt.Errorf("Invalid CORS origin: %s", err)
39 }
40 }
41 m[v] = true
42
43 }
44 *ci = CORSInfo(m)
45 return nil
46}
47
48func (ci *CORSInfo) String() string {
49 o := make([]string, 0)
50 for k := range *ci {
51 o = append(o, k)
52 }
53 sort.StringSlice(o).Sort()
54 return strings.Join(o, ",")
55}
56
57// OriginAllowed determines whether the server will allow a given CORS origin.
58func (c CORSInfo) OriginAllowed(origin string) bool {
59 return c["*"] || c[origin]
60}
61
62type CORSHandler struct {
63 Handler http.Handler
64 Info *CORSInfo
65}
66
67// addHeader adds the correct cors headers given an origin
68func (h *CORSHandler) addHeader(w http.ResponseWriter, origin string) {
69 w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
70 w.Header().Add("Access-Control-Allow-Origin", origin)
71 w.Header().Add("Access-Control-Allow-Headers", "accept, content-type, authorization")
72}
73
74// ServeHTTP adds the correct CORS headers based on the origin and returns immediately
75// with a 200 OK if the method is OPTIONS.
76func (h *CORSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
77 // Write CORS header.
78 if h.Info.OriginAllowed("*") {
79 h.addHeader(w, "*")
80 } else if origin := req.Header.Get("Origin"); h.Info.OriginAllowed(origin) {
81 h.addHeader(w, origin)
82 }
83
84 if req.Method == "OPTIONS" {
85 w.WriteHeader(http.StatusOK)
86 return
87 }
88
89 h.Handler.ServeHTTP(w, req)
90}