blob: 54d67aa67e7098343e4e4d1bd0ac08ee5915bc44 [file] [log] [blame]
// Copyright 2012-2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package gomaasapi
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
)
// singleServingServer is an httptest.Server that answers a single request.
// requestContent and requestHeader point at the variables in which the
// handler stores the body and headers of the request it served.
type singleServingServer struct {
	*httptest.Server
	requestContent *string
	requestHeader  *http.Header
}

// newSingleServingServer creates a single-serving test http server which will
// return only one response as defined by the passed arguments.
//
// The first request is recorded (body and headers). If its URI equals uri it
// is answered with the given status code and response body; otherwise it gets
// a 404 describing the expected and actual URIs. Every later request is
// rejected with 503 Service Unavailable and is NOT recorded, so
// requestContent and requestHeader keep describing the first request.
// A failure to read the request body panics (this is test code).
func newSingleServingServer(uri string, response string, code int) *singleServingServer {
	var requestContent string
	var requestHeader http.Header
	var requested bool
	handler := func(writer http.ResponseWriter, request *http.Request) {
		if requested {
			// http.Error only writes the reply; without this return the
			// handler would go on to overwrite the recorded first request and
			// append the canned response to the error body.
			http.Error(writer, "Already requested", http.StatusServiceUnavailable)
			return
		}
		res, err := readAndClose(request.Body)
		if err != nil {
			panic(err)
		}
		requestContent = string(res)
		requestHeader = request.Header
		if request.URL.String() != uri {
			errorMsg := fmt.Sprintf("Error 404: page not found (expected '%v', got '%v').", uri, request.URL.String())
			http.Error(writer, errorMsg, http.StatusNotFound)
		} else {
			writer.WriteHeader(code)
			fmt.Fprint(writer, response)
		}
		requested = true
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	return &singleServingServer{server, &requestContent, &requestHeader}
}
// flakyServer is an httptest.Server that fails a fixed number of times before
// succeeding. nbRequests points at the count of requests received so far and
// requests at the bodies of those requests, in arrival order.
type flakyServer struct {
	*httptest.Server
	nbRequests *int
	requests   *[][]byte
}

// newFlakyServer creates a "flaky" test http server which will
// return `nbFlakyResponses` responses with the given code and then a 200 response.
//
// Every request is counted and its body recorded, including requests to a URI
// other than uri, which are answered with a 404. When code is 503 the flaky
// responses carry a "Retry-After: 0" header. Room for nbFlakyResponses+1
// bodies is preallocated; any further requests are appended to the record
// instead of making the handler panic with an index out of range.
// A failure to read the request body panics (this is test code).
func newFlakyServer(uri string, code int, nbFlakyResponses int) *flakyServer {
	nbRequests := 0
	requests := make([][]byte, nbFlakyResponses+1)
	handler := func(writer http.ResponseWriter, request *http.Request) {
		nbRequests++
		body, err := readAndClose(request.Body)
		if err != nil {
			panic(err)
		}
		if nbRequests <= len(requests) {
			requests[nbRequests-1] = body
		} else {
			// More requests than anticipated: grow the record. The returned
			// flakyServer holds &requests, so it observes the reassignment.
			requests = append(requests, body)
		}
		if request.URL.String() != uri {
			errorMsg := fmt.Sprintf("Error 404: page not found (expected '%v', got '%v').", uri, request.URL.String())
			http.Error(writer, errorMsg, http.StatusNotFound)
		} else if nbRequests <= nbFlakyResponses {
			if code == http.StatusServiceUnavailable {
				writer.Header().Set("Retry-After", "0")
			}
			writer.WriteHeader(code)
			fmt.Fprint(writer, "flaky")
		} else {
			writer.WriteHeader(http.StatusOK)
			fmt.Fprint(writer, "ok")
		}
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	return &flakyServer{server, &nbRequests, &requests}
}
// simpleResponse is one canned reply queued on a SimpleTestServer.
type simpleResponse struct {
// status is the HTTP status code passed to WriteHeader.
status int
// body is written to the response after the status code.
body string
}
// SimpleTestServer is a test HTTP server that replays canned responses
// registered per HTTP method (GET, PUT, POST, DELETE) and request URI.
// NewSimpleServer creates the embedded httptest.Server unstarted.
type SimpleTestServer struct {
*httptest.Server
// For each method: the queued responses keyed by request URI, and the index
// of the next queued response to serve for that URI.
getResponses map[string][]simpleResponse
getResponseIndex map[string]int
putResponses map[string][]simpleResponse
putResponseIndex map[string]int
postResponses map[string][]simpleResponse
postResponseIndex map[string]int
deleteResponses map[string][]simpleResponse
deleteResponseIndex map[string]int
// requests records the requests handled so far, in arrival order;
// ResetRequests clears it.
requests []*http.Request
}
// NewSimpleServer returns a SimpleTestServer whose GET, PUT, POST and DELETE
// response tables start out empty. The embedded httptest.Server is built with
// httptest.NewUnstartedServer, so the caller must Start it (and Close it when
// finished) before sending requests.
func NewSimpleServer() *SimpleTestServer {
	s := new(SimpleTestServer)

	s.getResponses = make(map[string][]simpleResponse)
	s.putResponses = make(map[string][]simpleResponse)
	s.postResponses = make(map[string][]simpleResponse)
	s.deleteResponses = make(map[string][]simpleResponse)

	s.getResponseIndex = make(map[string]int)
	s.putResponseIndex = make(map[string]int)
	s.postResponseIndex = make(map[string]int)
	s.deleteResponseIndex = make(map[string]int)

	// s.handler is a method value bound to s, so responses added after this
	// point are still visible to it.
	s.Server = httptest.NewUnstartedServer(http.HandlerFunc(s.handler))
	return s
}
// AddGetResponse queues a response with the given status and body for GET
// requests whose full URI (path plus any query string) equals path. Responses
// queued for the same path are served in the order they were added.
func (s *SimpleTestServer) AddGetResponse(path string, status int, body string) {
	logger.Debugf("add get response for: %s, %d", path, status)
	canned := simpleResponse{status: status, body: body}
	s.getResponses[path] = append(s.getResponses[path], canned)
}

// AddPutResponse queues a response with the given status and body for PUT
// requests whose full URI (path plus any query string) equals path. Responses
// queued for the same path are served in the order they were added.
func (s *SimpleTestServer) AddPutResponse(path string, status int, body string) {
	logger.Debugf("add put response for: %s, %d", path, status)
	canned := simpleResponse{status: status, body: body}
	s.putResponses[path] = append(s.putResponses[path], canned)
}

// AddPostResponse queues a response with the given status and body for POST
// requests whose full URI (path plus any query string) equals path. Responses
// queued for the same path are served in the order they were added.
func (s *SimpleTestServer) AddPostResponse(path string, status int, body string) {
	logger.Debugf("add post response for: %s, %d", path, status)
	canned := simpleResponse{status: status, body: body}
	s.postResponses[path] = append(s.postResponses[path], canned)
}

// AddDeleteResponse queues a response with the given status and body for
// DELETE requests whose full URI (path plus any query string) equals path.
// Responses queued for the same path are served in the order they were added.
func (s *SimpleTestServer) AddDeleteResponse(path string, status int, body string) {
	logger.Debugf("add delete response for: %s, %d", path, status)
	canned := simpleResponse{status: status, body: body}
	s.deleteResponses[path] = append(s.deleteResponses[path], canned)
}
// LastRequest returns the most recent request the server recorded, or nil if
// none has been recorded since creation or the last ResetRequests.
func (s *SimpleTestServer) LastRequest() *http.Request {
	if len(s.requests) == 0 {
		return nil
	}
	return s.requests[len(s.requests)-1]
}

// LastNRequests returns up to the n most recent recorded requests, oldest
// first. If fewer than n requests were recorded, all of them are returned.
// A zero or negative n yields an empty result (a negative n used to panic
// with a slice-bounds error). The result is a sub-slice of the server's own
// record and shares its storage, so callers should not modify it.
func (s *SimpleTestServer) LastNRequests(n int) []*http.Request {
	if n < 0 {
		n = 0
	}
	start := len(s.requests) - n
	if start < 0 {
		start = 0
	}
	return s.requests[start:]
}

// RequestCount returns how many requests have been recorded since creation or
// the last ResetRequests.
func (s *SimpleTestServer) RequestCount() int {
	return len(s.requests)
}

// ResetRequests forgets all recorded requests. The record is replaced rather
// than truncated, so slices previously returned by LastNRequests are not
// overwritten by later requests.
func (s *SimpleTestServer) ResetRequests() {
	s.requests = nil
}
// handler serves a request from the canned responses registered for its
// method.
//
// The body is consumed first. GET and DELETE bodies are read and discarded.
// PUT bodies go through ParseForm. POST bodies go through ParseMultipartForm
// (for multipart/form-data, with 2 MiB of in-memory buffering) or ParseForm
// otherwise. As a result, the recorded request carries its parsed form values.
// An unsupported method, or a failure to read or parse the body, panics
// (this is test code).
//
// The request is then recorded and looked up by its full URI (path plus query
// string). An unknown URI gets a 404. Otherwise the next queued response for
// that URI is written. Once all queued responses for the URI have been served,
// further requests get a 500 that names the method and URI, instead of the
// handler panicking with an index out of range.
func (s *SimpleTestServer) handler(writer http.ResponseWriter, request *http.Request) {
	method := request.Method
	var (
		err           error
		responses     map[string][]simpleResponse
		responseIndex map[string]int
	)
	switch method {
	case http.MethodGet:
		responses = s.getResponses
		responseIndex = s.getResponseIndex
		_, err = readAndClose(request.Body)
	case http.MethodPut:
		responses = s.putResponses
		responseIndex = s.putResponseIndex
		err = request.ParseForm()
	case http.MethodPost:
		responses = s.postResponses
		responseIndex = s.postResponseIndex
		contentType := request.Header.Get("Content-Type")
		if strings.HasPrefix(contentType, "multipart/form-data;") {
			err = request.ParseMultipartForm(2 << 20)
		} else {
			err = request.ParseForm()
		}
	case http.MethodDelete:
		responses = s.deleteResponses
		responseIndex = s.deleteResponseIndex
		// Assign the outer err; the previous `:=` shadowed it.
		_, err = readAndClose(request.Body)
	default:
		panic("unsupported method " + method)
	}
	if err != nil {
		panic(err) // it is a test, panic should be fine
	}
	s.requests = append(s.requests, request)
	uri := request.URL.String()
	testResponses, found := responses[uri]
	if !found {
		errorMsg := fmt.Sprintf("Error 404: page not found ('%v').", uri)
		http.Error(writer, errorMsg, http.StatusNotFound)
		return
	}
	index := responseIndex[uri]
	if index >= len(testResponses) {
		errorMsg := fmt.Sprintf("Error 500: all %d queued %s responses for '%v' have been served.", len(testResponses), method, uri)
		http.Error(writer, errorMsg, http.StatusInternalServerError)
		return
	}
	response := testResponses[index]
	responseIndex[uri] = index + 1
	writer.WriteHeader(response.status)
	fmt.Fprint(writer, response.body)
}