Add a Response helper type to handle async request completion more safely.
Fixes VOL-2286
Change-Id: Ifcbbfdf64c3614838adbbaa11ca69d3d49c44861
diff --git a/rw_core/utils/core_utils.go b/rw_core/utils/core_utils.go
index aad1348..3a71623 100644
--- a/rw_core/utils/core_utils.go
+++ b/rw_core/utils/core_utils.go
@@ -19,7 +19,6 @@
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"os"
- "reflect"
"time"
)
@@ -35,62 +34,70 @@
return os.Getenv("HOSTNAME")
}
+type (
+	// Response is a completion handle for an asynchronous request. It embeds a
+	// pointer, so every copy of a Response observes and updates the same state.
+	Response struct {
+		*response
+	}
+
+	// response is the shared state behind a Response.
+	response struct {
+		err  error         // error reported via Error; nil otherwise
+		ch   chan struct{} // closed when the request completes
+		done bool          // true once Error has been called
+	}
+)
+
+// NewResponse returns a Response that has not completed yet: its completion
+// channel stays open until Error or Done closes it.
+func NewResponse() Response {
+	state := &response{ch: make(chan struct{})}
+	return Response{response: state}
+}
+
+// Error completes the response with the given error. It may be called at most
+// once, and not after Done: either misuse closes the completion channel a
+// second time and panics, which is deliberate so the mistake surfaces at once.
+func (r Response) Error(err error) {
+	state := r.response
+	state.err, state.done = err, true
+	// Store err/done before signalling; waiters read err only after <-ch.
+	close(state.ch)
+}
+
+// Done completes the response without an error; if Error has already been
+// called it does nothing. Done does not record that it ran, so calling it a
+// second time, or calling Error after it, panics on the closed channel.
+func (r Response) Done() {
+	if r.done {
+		return // Error already completed this response
+	}
+	close(r.ch)
+}
+
//WaitForNilOrErrorResponses waits on a variadic number of channels for either a nil response or an error
//response. If an error is received from a given channel then the returned error array will contain that error.
//The error will be at the index corresponding to the order in which the channel appear in the parameter list.
//If no errors is found then nil is returned. This method also takes in a timeout in milliseconds. If a
//timeout is obtained then this function will stop waiting for the remaining responses and abort.
-func WaitForNilOrErrorResponses(timeout int64, chnls ...chan interface{}) []error {
- if len(chnls) == 0 {
- return nil
- }
- // Create a timeout channel
- tChnl := make(chan *interface{})
- go func() {
- time.Sleep(time.Duration(timeout) * time.Millisecond)
- tChnl <- nil
- }()
+func WaitForNilOrErrorResponses(timeout int64, responses ...Response) []error {
+ timedOut := make(chan struct{})
+ timer := time.AfterFunc(time.Duration(timeout)*time.Millisecond, func() { close(timedOut) })
+ defer timer.Stop()
- errorsReceived := false
- errors := make([]error, len(chnls))
- cases := make([]reflect.SelectCase, len(chnls)+1)
- for i, ch := range chnls {
- cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)}
- }
- // Add the timeout channel
- cases[len(chnls)] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(tChnl)}
-
- resultsReceived := make([]bool, len(errors)+1)
- remaining := len(cases) - 1
- for remaining > 0 {
- index, value, ok := reflect.Select(cases)
- if !ok { // closed channel
- //Set the channel at that index to nil to disable this case, hence preventing it from interfering with other cases.
- cases[index].Chan = reflect.ValueOf(nil)
- errors[index] = status.Error(codes.Internal, "channel closed")
- errorsReceived = true
- } else if index == len(chnls) { // Timeout has occurred
- for k := range errors {
- if !resultsReceived[k] {
- errors[k] = status.Error(codes.Aborted, "timeout")
- }
+ gotError := false
+ errors := make([]error, 0, len(responses))
+ for _, response := range responses {
+ var err error
+ select {
+ case <-response.ch:
+ // if a response is already available, use it
+ err = response.err
+ default:
+ // otherwise, wait for either a response or a timeout
+ select {
+ case <-response.ch:
+ err = response.err
+ case <-timedOut:
+ err = status.Error(codes.Aborted, "timeout")
}
- errorsReceived = true
- break
- } else if value.IsNil() { // Nil means a good response
- //do nothing
- } else if err, ok := value.Interface().(error); ok { // error returned
- errors[index] = err
- errorsReceived = true
- } else { // unknown value
- errors[index] = status.Errorf(codes.Internal, "%s", value)
- errorsReceived = true
}
- resultsReceived[index] = true
- remaining -= 1
+ gotError = gotError || err != nil
+ errors = append(errors, err)
}
- if errorsReceived {
+ if gotError {
return errors
}
return nil