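Add a Response helper type to handle async request completion more safely.

WaitForNilOrErrorResponses now takes Responses instead of chan interface{},
replacing its reflect.Select loop with a plain select per response.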

Fixes VOL-2286

Change-Id: Ifcbbfdf64c3614838adbbaa11ca69d3d49c44861
diff --git a/rw_core/utils/core_utils.go b/rw_core/utils/core_utils.go
index aad1348..3a71623 100644
--- a/rw_core/utils/core_utils.go
+++ b/rw_core/utils/core_utils.go
@@ -19,7 +19,6 @@
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
 	"os"
-	"reflect"
 	"time"
 )
 
@@ -35,62 +34,70 @@
 	return os.Getenv("HOSTNAME")
 }
 
+// Response is a one-shot completion signal for an asynchronous request; create one with NewResponse.
+type Response struct {
+	*response
+}
+type response struct {
+	err  error
+	ch   chan struct{} // closed exactly once, when the response completes
+	done bool          // set when ch has been closed by Error or Done
+}
+
+// NewResponse returns a Response that has not completed yet.
+func NewResponse() Response {
+	return Response{&response{ch: make(chan struct{})}}
+}
+
+// Error completes the response with the given error.  It may only be called once, and not after Done;
+// either misuse panics (close of closed channel).  done is unsynchronized: do not race this with Done.
+func (r Response) Error(err error) {
+	r.err = err
+	r.done = true
+	close(r.ch)
+}
+
+// Done completes the response without an error.  It is a no-op if Error or Done has already been
+// called, so it can be deferred unconditionally.  done is unsynchronized: do not race this with Error.
+func (r Response) Done() {
+	if !r.done {
+		r.done = true
+		close(r.ch)
+	}
+}
+
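-//WaitForNilOrErrorResponses waits on a variadic number of channels for either a nil response or an error
-//response. If an error is received from a given channel then the returned error array will contain that error.
-//The error will be at the index corresponding to the order in which the channel appear in the parameter list.
-//If no errors is found then nil is returned.  This method also takes in a timeout in milliseconds. If a
-//timeout is obtained then this function will stop waiting for the remaining responses and abort.
+// WaitForNilOrErrorResponses waits for each of the given responses to complete.  If any response
+// completes with an error or times out, the returned slice has one entry per response, in parameter
+// order: that response's error (nil if it succeeded).  Otherwise nil is returned.  timeout is in
+// milliseconds and bounds the whole wait; a response still pending when it expires is reported as a
+// codes.Aborted "timeout" error instead of being waited on.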
-func WaitForNilOrErrorResponses(timeout int64, chnls ...chan interface{}) []error {
-	if len(chnls) == 0 {
-		return nil
-	}
-	// Create a timeout channel
-	tChnl := make(chan *interface{})
-	go func() {
-		time.Sleep(time.Duration(timeout) * time.Millisecond)
-		tChnl <- nil
-	}()
+func WaitForNilOrErrorResponses(timeout int64, responses ...Response) []error {
+	timedOut := make(chan struct{})
+	timer := time.AfterFunc(time.Duration(timeout)*time.Millisecond, func() { close(timedOut) })
+	defer timer.Stop()
 
-	errorsReceived := false
-	errors := make([]error, len(chnls))
-	cases := make([]reflect.SelectCase, len(chnls)+1)
-	for i, ch := range chnls {
-		cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)}
-	}
-	// Add the timeout channel
-	cases[len(chnls)] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(tChnl)}
-
-	resultsReceived := make([]bool, len(errors)+1)
-	remaining := len(cases) - 1
-	for remaining > 0 {
-		index, value, ok := reflect.Select(cases)
-		if !ok { // closed channel
-			//Set the channel at that index to nil to disable this case, hence preventing it from interfering with other cases.
-			cases[index].Chan = reflect.ValueOf(nil)
-			errors[index] = status.Error(codes.Internal, "channel closed")
-			errorsReceived = true
-		} else if index == len(chnls) { // Timeout has occurred
-			for k := range errors {
-				if !resultsReceived[k] {
-					errors[k] = status.Error(codes.Aborted, "timeout")
-				}
+	gotError := false
+	errors := make([]error, 0, len(responses))
+	for _, response := range responses {
+		var err error
+		select {
+		case <-response.ch:
+			// if a response is already available, use it
+			err = response.err
+		default:
+			// otherwise, wait for either a response or a timeout
+			select {
+			case <-response.ch:
+				err = response.err
+			case <-timedOut:
+				err = status.Error(codes.Aborted, "timeout")
 			}
-			errorsReceived = true
-			break
-		} else if value.IsNil() { // Nil means a good response
-			//do nothing
-		} else if err, ok := value.Interface().(error); ok { // error returned
-			errors[index] = err
-			errorsReceived = true
-		} else { // unknown value
-			errors[index] = status.Errorf(codes.Internal, "%s", value)
-			errorsReceived = true
 		}
-		resultsReceived[index] = true
-		remaining -= 1
+		gotError = gotError || err != nil
+		errors = append(errors, err)
 	}
 
-	if errorsReceived {
+	if gotError {
 		return errors
 	}
 	return nil
diff --git a/rw_core/utils/core_utils_test.go b/rw_core/utils/core_utils_test.go
index e0d0e75..9f8dd87 100644
--- a/rw_core/utils/core_utils_test.go
+++ b/rw_core/utils/core_utils_test.go
@@ -36,14 +36,14 @@
 	taskFailureError = status.Error(codes.Internal, "test failure task")
 }
 
-func runSuccessfulTask(ch chan interface{}, durationRange int) {
+func runSuccessfulTask(resp Response, durationRange int) {
 	time.Sleep(time.Duration(rand.Intn(durationRange)) * time.Millisecond)
-	ch <- nil
+	resp.Done() // report success once the simulated work has finished
 }
 
-func runFailureTask(ch chan interface{}, durationRange int) {
+func runFailureTask(resp Response, durationRange int) {
 	time.Sleep(time.Duration(rand.Intn(durationRange)) * time.Millisecond)
-	ch <- taskFailureError
+	resp.Error(taskFailureError) // report taskFailureError once the simulated work has finished
 }
 
 func runMultipleTasks(timeout, numTasks, taskDurationRange, numSuccessfulTask, numFailuretask int) []error {
@@ -51,17 +51,17 @@
 		return []error{status.Error(codes.FailedPrecondition, "invalid-num-tasks")}
 	}
 	numSuccessfulTaskCreated := 0
-	chnls := make([]chan interface{}, numTasks)
+	responses := make([]Response, numTasks)
 	for i := 0; i < numTasks; i++ {
-		chnls[i] = make(chan interface{})
+		responses[i] = NewResponse()
 		if numSuccessfulTaskCreated < numSuccessfulTask {
-			go runSuccessfulTask(chnls[i], taskDurationRange)
+			go runSuccessfulTask(responses[i], taskDurationRange)
 			numSuccessfulTaskCreated += 1
 			continue
 		}
-		go runFailureTask(chnls[i], taskDurationRange)
+		go runFailureTask(responses[i], taskDurationRange)
 	}
-	return WaitForNilOrErrorResponses(int64(timeout), chnls...)
+	return WaitForNilOrErrorResponses(int64(timeout), responses...)
 }
 
 func getNumSuccessFailure(inputs []error) (numSuccess, numFailure, numTimeout int) {