David K. Bainbridge | 528b318 | 2017-01-23 08:51:59 -0800 | [diff] [blame^] | 1 | // Copyright 2012-2016 Canonical Ltd. |
| 2 | // Licensed under the LGPLv3, see LICENCE file for details. |
| 3 | |
| 4 | package gomaasapi |
| 5 | |
| 6 | import ( |
| 7 | "fmt" |
| 8 | "net/http" |
| 9 | "net/http/httptest" |
| 10 | "strings" |
| 11 | ) |
| 12 | |
| 13 | type singleServingServer struct { |
| 14 | *httptest.Server |
| 15 | requestContent *string |
| 16 | requestHeader *http.Header |
| 17 | } |
| 18 | |
| 19 | // newSingleServingServer creates a single-serving test http server which will |
| 20 | // return only one response as defined by the passed arguments. |
| 21 | func newSingleServingServer(uri string, response string, code int) *singleServingServer { |
| 22 | var requestContent string |
| 23 | var requestHeader http.Header |
| 24 | var requested bool |
| 25 | handler := func(writer http.ResponseWriter, request *http.Request) { |
| 26 | if requested { |
| 27 | http.Error(writer, "Already requested", http.StatusServiceUnavailable) |
| 28 | } |
| 29 | res, err := readAndClose(request.Body) |
| 30 | if err != nil { |
| 31 | panic(err) |
| 32 | } |
| 33 | requestContent = string(res) |
| 34 | requestHeader = request.Header |
| 35 | if request.URL.String() != uri { |
| 36 | errorMsg := fmt.Sprintf("Error 404: page not found (expected '%v', got '%v').", uri, request.URL.String()) |
| 37 | http.Error(writer, errorMsg, http.StatusNotFound) |
| 38 | } else { |
| 39 | writer.WriteHeader(code) |
| 40 | fmt.Fprint(writer, response) |
| 41 | } |
| 42 | requested = true |
| 43 | } |
| 44 | server := httptest.NewServer(http.HandlerFunc(handler)) |
| 45 | return &singleServingServer{server, &requestContent, &requestHeader} |
| 46 | } |
| 47 | |
| 48 | type flakyServer struct { |
| 49 | *httptest.Server |
| 50 | nbRequests *int |
| 51 | requests *[][]byte |
| 52 | } |
| 53 | |
| 54 | // newFlakyServer creates a "flaky" test http server which will |
| 55 | // return `nbFlakyResponses` responses with the given code and then a 200 response. |
| 56 | func newFlakyServer(uri string, code int, nbFlakyResponses int) *flakyServer { |
| 57 | nbRequests := 0 |
| 58 | requests := make([][]byte, nbFlakyResponses+1) |
| 59 | handler := func(writer http.ResponseWriter, request *http.Request) { |
| 60 | nbRequests += 1 |
| 61 | body, err := readAndClose(request.Body) |
| 62 | if err != nil { |
| 63 | panic(err) |
| 64 | } |
| 65 | requests[nbRequests-1] = body |
| 66 | if request.URL.String() != uri { |
| 67 | errorMsg := fmt.Sprintf("Error 404: page not found (expected '%v', got '%v').", uri, request.URL.String()) |
| 68 | http.Error(writer, errorMsg, http.StatusNotFound) |
| 69 | } else if nbRequests <= nbFlakyResponses { |
| 70 | if code == http.StatusServiceUnavailable { |
| 71 | writer.Header().Set("Retry-After", "0") |
| 72 | } |
| 73 | writer.WriteHeader(code) |
| 74 | fmt.Fprint(writer, "flaky") |
| 75 | } else { |
| 76 | writer.WriteHeader(http.StatusOK) |
| 77 | fmt.Fprint(writer, "ok") |
| 78 | } |
| 79 | |
| 80 | } |
| 81 | server := httptest.NewServer(http.HandlerFunc(handler)) |
| 82 | return &flakyServer{server, &nbRequests, &requests} |
| 83 | } |
| 84 | |
| 85 | type simpleResponse struct { |
| 86 | status int |
| 87 | body string |
| 88 | } |
| 89 | |
| 90 | type SimpleTestServer struct { |
| 91 | *httptest.Server |
| 92 | |
| 93 | getResponses map[string][]simpleResponse |
| 94 | getResponseIndex map[string]int |
| 95 | putResponses map[string][]simpleResponse |
| 96 | putResponseIndex map[string]int |
| 97 | postResponses map[string][]simpleResponse |
| 98 | postResponseIndex map[string]int |
| 99 | deleteResponses map[string][]simpleResponse |
| 100 | deleteResponseIndex map[string]int |
| 101 | |
| 102 | requests []*http.Request |
| 103 | } |
| 104 | |
| 105 | func NewSimpleServer() *SimpleTestServer { |
| 106 | server := &SimpleTestServer{ |
| 107 | getResponses: make(map[string][]simpleResponse), |
| 108 | getResponseIndex: make(map[string]int), |
| 109 | putResponses: make(map[string][]simpleResponse), |
| 110 | putResponseIndex: make(map[string]int), |
| 111 | postResponses: make(map[string][]simpleResponse), |
| 112 | postResponseIndex: make(map[string]int), |
| 113 | deleteResponses: make(map[string][]simpleResponse), |
| 114 | deleteResponseIndex: make(map[string]int), |
| 115 | } |
| 116 | server.Server = httptest.NewUnstartedServer(http.HandlerFunc(server.handler)) |
| 117 | return server |
| 118 | } |
| 119 | |
| 120 | func (s *SimpleTestServer) AddGetResponse(path string, status int, body string) { |
| 121 | logger.Debugf("add get response for: %s, %d", path, status) |
| 122 | s.getResponses[path] = append(s.getResponses[path], simpleResponse{status: status, body: body}) |
| 123 | } |
| 124 | |
| 125 | func (s *SimpleTestServer) AddPutResponse(path string, status int, body string) { |
| 126 | logger.Debugf("add put response for: %s, %d", path, status) |
| 127 | s.putResponses[path] = append(s.putResponses[path], simpleResponse{status: status, body: body}) |
| 128 | } |
| 129 | |
| 130 | func (s *SimpleTestServer) AddPostResponse(path string, status int, body string) { |
| 131 | logger.Debugf("add post response for: %s, %d", path, status) |
| 132 | s.postResponses[path] = append(s.postResponses[path], simpleResponse{status: status, body: body}) |
| 133 | } |
| 134 | |
| 135 | func (s *SimpleTestServer) AddDeleteResponse(path string, status int, body string) { |
| 136 | logger.Debugf("add delete response for: %s, %d", path, status) |
| 137 | s.deleteResponses[path] = append(s.deleteResponses[path], simpleResponse{status: status, body: body}) |
| 138 | } |
| 139 | |
| 140 | func (s *SimpleTestServer) LastRequest() *http.Request { |
| 141 | pos := len(s.requests) - 1 |
| 142 | if pos < 0 { |
| 143 | return nil |
| 144 | } |
| 145 | return s.requests[pos] |
| 146 | } |
| 147 | |
| 148 | func (s *SimpleTestServer) LastNRequests(n int) []*http.Request { |
| 149 | start := len(s.requests) - n |
| 150 | if start < 0 { |
| 151 | start = 0 |
| 152 | } |
| 153 | return s.requests[start:] |
| 154 | } |
| 155 | |
| 156 | func (s *SimpleTestServer) RequestCount() int { |
| 157 | return len(s.requests) |
| 158 | } |
| 159 | |
| 160 | func (s *SimpleTestServer) ResetRequests() { |
| 161 | s.requests = nil |
| 162 | } |
| 163 | |
| 164 | func (s *SimpleTestServer) handler(writer http.ResponseWriter, request *http.Request) { |
| 165 | method := request.Method |
| 166 | var ( |
| 167 | err error |
| 168 | responses map[string][]simpleResponse |
| 169 | responseIndex map[string]int |
| 170 | ) |
| 171 | switch method { |
| 172 | case "GET": |
| 173 | responses = s.getResponses |
| 174 | responseIndex = s.getResponseIndex |
| 175 | _, err = readAndClose(request.Body) |
| 176 | if err != nil { |
| 177 | panic(err) // it is a test, panic should be fine |
| 178 | } |
| 179 | case "PUT": |
| 180 | responses = s.putResponses |
| 181 | responseIndex = s.putResponseIndex |
| 182 | err = request.ParseForm() |
| 183 | if err != nil { |
| 184 | panic(err) |
| 185 | } |
| 186 | case "POST": |
| 187 | responses = s.postResponses |
| 188 | responseIndex = s.postResponseIndex |
| 189 | contentType := request.Header.Get("Content-Type") |
| 190 | if strings.HasPrefix(contentType, "multipart/form-data;") { |
| 191 | err = request.ParseMultipartForm(2 << 20) |
| 192 | } else { |
| 193 | err = request.ParseForm() |
| 194 | } |
| 195 | if err != nil { |
| 196 | panic(err) |
| 197 | } |
| 198 | case "DELETE": |
| 199 | responses = s.deleteResponses |
| 200 | responseIndex = s.deleteResponseIndex |
| 201 | _, err := readAndClose(request.Body) |
| 202 | if err != nil { |
| 203 | panic(err) |
| 204 | } |
| 205 | default: |
| 206 | panic("unsupported method " + method) |
| 207 | } |
| 208 | s.requests = append(s.requests, request) |
| 209 | uri := request.URL.String() |
| 210 | testResponses, found := responses[uri] |
| 211 | if !found { |
| 212 | errorMsg := fmt.Sprintf("Error 404: page not found ('%v').", uri) |
| 213 | http.Error(writer, errorMsg, http.StatusNotFound) |
| 214 | } else { |
| 215 | index := responseIndex[uri] |
| 216 | response := testResponses[index] |
| 217 | responseIndex[uri] = index + 1 |
| 218 | |
| 219 | writer.WriteHeader(response.status) |
| 220 | fmt.Fprint(writer, response.body) |
| 221 | } |
| 222 | } |