blob: 54d67aa67e7098343e4e4d1bd0ac08ee5915bc44 [file] [log] [blame]
David K. Bainbridge528b3182017-01-23 08:51:59 -08001// Copyright 2012-2016 Canonical Ltd.
2// Licensed under the LGPLv3, see LICENCE file for details.
3
4package gomaasapi
5
6import (
7 "fmt"
8 "net/http"
9 "net/http/httptest"
10 "strings"
11)
12
// singleServingServer is the test server produced by newSingleServingServer.
// Field order matters: newSingleServingServer builds it with a positional
// composite literal.
type singleServingServer struct {
	*httptest.Server

	// requestContent points at the recorded body of the served request.
	requestContent *string
	// requestHeader points at the recorded headers of the served request.
	requestHeader *http.Header
}
18
19// newSingleServingServer creates a single-serving test http server which will
20// return only one response as defined by the passed arguments.
21func newSingleServingServer(uri string, response string, code int) *singleServingServer {
22 var requestContent string
23 var requestHeader http.Header
24 var requested bool
25 handler := func(writer http.ResponseWriter, request *http.Request) {
26 if requested {
27 http.Error(writer, "Already requested", http.StatusServiceUnavailable)
28 }
29 res, err := readAndClose(request.Body)
30 if err != nil {
31 panic(err)
32 }
33 requestContent = string(res)
34 requestHeader = request.Header
35 if request.URL.String() != uri {
36 errorMsg := fmt.Sprintf("Error 404: page not found (expected '%v', got '%v').", uri, request.URL.String())
37 http.Error(writer, errorMsg, http.StatusNotFound)
38 } else {
39 writer.WriteHeader(code)
40 fmt.Fprint(writer, response)
41 }
42 requested = true
43 }
44 server := httptest.NewServer(http.HandlerFunc(handler))
45 return &singleServingServer{server, &requestContent, &requestHeader}
46}
47
// flakyServer is the test server produced by newFlakyServer. Field order
// matters: newFlakyServer builds it with a positional composite literal.
type flakyServer struct {
	*httptest.Server

	// nbRequests points at the number of requests handled so far.
	nbRequests *int
	// requests points at the recorded request bodies, in arrival order.
	requests *[][]byte
}
53
54// newFlakyServer creates a "flaky" test http server which will
55// return `nbFlakyResponses` responses with the given code and then a 200 response.
56func newFlakyServer(uri string, code int, nbFlakyResponses int) *flakyServer {
57 nbRequests := 0
58 requests := make([][]byte, nbFlakyResponses+1)
59 handler := func(writer http.ResponseWriter, request *http.Request) {
60 nbRequests += 1
61 body, err := readAndClose(request.Body)
62 if err != nil {
63 panic(err)
64 }
65 requests[nbRequests-1] = body
66 if request.URL.String() != uri {
67 errorMsg := fmt.Sprintf("Error 404: page not found (expected '%v', got '%v').", uri, request.URL.String())
68 http.Error(writer, errorMsg, http.StatusNotFound)
69 } else if nbRequests <= nbFlakyResponses {
70 if code == http.StatusServiceUnavailable {
71 writer.Header().Set("Retry-After", "0")
72 }
73 writer.WriteHeader(code)
74 fmt.Fprint(writer, "flaky")
75 } else {
76 writer.WriteHeader(http.StatusOK)
77 fmt.Fprint(writer, "ok")
78 }
79
80 }
81 server := httptest.NewServer(http.HandlerFunc(handler))
82 return &flakyServer{server, &nbRequests, &requests}
83}
84
// simpleResponse is one canned reply served by SimpleTestServer: the HTTP
// status code to write, followed by the literal body text.
type simpleResponse struct {
	status int    // HTTP status code passed to WriteHeader
	body   string // response body written verbatim
}
89
90type SimpleTestServer struct {
91 *httptest.Server
92
93 getResponses map[string][]simpleResponse
94 getResponseIndex map[string]int
95 putResponses map[string][]simpleResponse
96 putResponseIndex map[string]int
97 postResponses map[string][]simpleResponse
98 postResponseIndex map[string]int
99 deleteResponses map[string][]simpleResponse
100 deleteResponseIndex map[string]int
101
102 requests []*http.Request
103}
104
105func NewSimpleServer() *SimpleTestServer {
106 server := &SimpleTestServer{
107 getResponses: make(map[string][]simpleResponse),
108 getResponseIndex: make(map[string]int),
109 putResponses: make(map[string][]simpleResponse),
110 putResponseIndex: make(map[string]int),
111 postResponses: make(map[string][]simpleResponse),
112 postResponseIndex: make(map[string]int),
113 deleteResponses: make(map[string][]simpleResponse),
114 deleteResponseIndex: make(map[string]int),
115 }
116 server.Server = httptest.NewUnstartedServer(http.HandlerFunc(server.handler))
117 return server
118}
119
120func (s *SimpleTestServer) AddGetResponse(path string, status int, body string) {
121 logger.Debugf("add get response for: %s, %d", path, status)
122 s.getResponses[path] = append(s.getResponses[path], simpleResponse{status: status, body: body})
123}
124
125func (s *SimpleTestServer) AddPutResponse(path string, status int, body string) {
126 logger.Debugf("add put response for: %s, %d", path, status)
127 s.putResponses[path] = append(s.putResponses[path], simpleResponse{status: status, body: body})
128}
129
130func (s *SimpleTestServer) AddPostResponse(path string, status int, body string) {
131 logger.Debugf("add post response for: %s, %d", path, status)
132 s.postResponses[path] = append(s.postResponses[path], simpleResponse{status: status, body: body})
133}
134
135func (s *SimpleTestServer) AddDeleteResponse(path string, status int, body string) {
136 logger.Debugf("add delete response for: %s, %d", path, status)
137 s.deleteResponses[path] = append(s.deleteResponses[path], simpleResponse{status: status, body: body})
138}
139
140func (s *SimpleTestServer) LastRequest() *http.Request {
141 pos := len(s.requests) - 1
142 if pos < 0 {
143 return nil
144 }
145 return s.requests[pos]
146}
147
148func (s *SimpleTestServer) LastNRequests(n int) []*http.Request {
149 start := len(s.requests) - n
150 if start < 0 {
151 start = 0
152 }
153 return s.requests[start:]
154}
155
156func (s *SimpleTestServer) RequestCount() int {
157 return len(s.requests)
158}
159
160func (s *SimpleTestServer) ResetRequests() {
161 s.requests = nil
162}
163
164func (s *SimpleTestServer) handler(writer http.ResponseWriter, request *http.Request) {
165 method := request.Method
166 var (
167 err error
168 responses map[string][]simpleResponse
169 responseIndex map[string]int
170 )
171 switch method {
172 case "GET":
173 responses = s.getResponses
174 responseIndex = s.getResponseIndex
175 _, err = readAndClose(request.Body)
176 if err != nil {
177 panic(err) // it is a test, panic should be fine
178 }
179 case "PUT":
180 responses = s.putResponses
181 responseIndex = s.putResponseIndex
182 err = request.ParseForm()
183 if err != nil {
184 panic(err)
185 }
186 case "POST":
187 responses = s.postResponses
188 responseIndex = s.postResponseIndex
189 contentType := request.Header.Get("Content-Type")
190 if strings.HasPrefix(contentType, "multipart/form-data;") {
191 err = request.ParseMultipartForm(2 << 20)
192 } else {
193 err = request.ParseForm()
194 }
195 if err != nil {
196 panic(err)
197 }
198 case "DELETE":
199 responses = s.deleteResponses
200 responseIndex = s.deleteResponseIndex
201 _, err := readAndClose(request.Body)
202 if err != nil {
203 panic(err)
204 }
205 default:
206 panic("unsupported method " + method)
207 }
208 s.requests = append(s.requests, request)
209 uri := request.URL.String()
210 testResponses, found := responses[uri]
211 if !found {
212 errorMsg := fmt.Sprintf("Error 404: page not found ('%v').", uri)
213 http.Error(writer, errorMsg, http.StatusNotFound)
214 } else {
215 index := responseIndex[uri]
216 response := testResponses[index]
217 responseIndex[uri] = index + 1
218
219 writer.WriteHeader(response.status)
220 fmt.Fprint(writer, response.body)
221 }
222}