blob: bb7faeadedea445e91b60a5dd4ca685892bb3fdf [file] [log] [blame]
/*
* Copyright 2019-present Open Networking Foundation
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package afrouter
import (
"context"
"encoding/hex"
"errors"
"github.com/opencord/voltha-go/common/log"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"sort"
"sync"
)
// streams is the set of southbound client streams opened for one northbound
// call, together with the stream selected to answer it.
type streams struct {
	// mutex is held by getActive and setThenGetActive while they access
	// activeStream.
	mutex sync.Mutex
	// activeStream is the first stream to claim the response through
	// setThenGetActive; nil until then.
	activeStream *stream
	// streams holds every southbound stream by name; nil entries are allowed
	// and skipped by the methods below. NOTE(review): the key is presumably
	// the backend connection name — confirm where this map is built.
	streams map[string]*stream
	// sortedStreams holds the values of streams in ascending key order. It is
	// filled by sortStreams and iterated by sendAll.
	sortedStreams []*stream
}
// stream is one southbound client stream plus the state used to coordinate
// forwarding over it.
type stream struct {
	// stream is the underlying gRPC client stream.
	stream grpc.ClientStream
	// ctx is presumably the context the client stream was opened with —
	// TODO confirm at the creation site.
	ctx context.Context
	// cancel is invoked by clientCancel.
	cancel context.CancelFunc
	// ok2Close is closed by the receive goroutine that forwardClientToServer
	// starts for this stream; closeSend waits on it before calling CloseSend.
	ok2Close chan struct{}
	// c2sReturn carries this stream's client-to-server result: nil for a
	// redundant (inactive) responder, otherwise the terminal error (io.EOF on
	// normal completion). NOTE(review): forwardClientToServer reads a value
	// and writes it back from the same goroutine, so this channel appears to
	// need a buffer of at least 1 — confirm at the creation site.
	c2sReturn chan error
	// s2cReturn records the error from a failed SendMsg in sendAll. It is not
	// reset by later successful sends in this file.
	s2cReturn error
}
// clientCancel calls the cancel function of every non-nil southbound stream
// held in s.streams.
func (s *streams) clientCancel() {
	for _, candidate := range s.streams {
		if candidate == nil {
			continue
		}
		candidate.cancel()
	}
}
// closeSend half-closes every non-nil southbound stream.
//
// For each stream it first blocks until the stream's ok2Close channel is
// closed by the receive goroutine that forwardClientToServer started for it,
// and only then calls CloseSend. A CloseSend failure is logged instead of
// being silently discarded. The failure does not stop the remaining streams
// from being closed.
func (s *streams) closeSend() {
	for _, strm := range s.streams {
		if strm == nil {
			continue
		}
		<-strm.ok2Close
		log.Debug("Closing southbound stream")
		if err := strm.stream.CloseSend(); err != nil {
			log.Debugf("Error on CloseSend: %s", err.Error())
		}
	}
}
// trailer returns the trailer metadata of the active southbound stream.
//
// activeStream is read through getActive, so the access happens under
// s.mutex like every other accessor of that field. If no stream has become
// active yet, nil metadata is returned instead of dereferencing a nil
// pointer.
func (s *streams) trailer() metadata.MD {
	active := s.getActive()
	if active == nil {
		return nil
	}
	return active.stream.Trailer()
}
// getActive reports which stream has been selected to answer the call, or
// nil if none has claimed it yet. The field is read while holding s.mutex.
func (s *streams) getActive() *stream {
	s.mutex.Lock()
	current := s.activeStream
	s.mutex.Unlock()
	return current
}
// setThenGetActive makes strm the active stream if no stream holds that role
// yet, and returns whichever stream is active afterwards. The check and the
// assignment happen under s.mutex, so exactly one caller wins.
func (s *streams) setThenGetActive(strm *stream) *stream {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	if s.activeStream != nil {
		return s.activeStream
	}
	s.activeStream = strm
	return strm
}
// forwardClientToServer relays messages received on the southbound client
// streams to the northbound server stream dst, using f as the receive buffer.
//
// A receive goroutine is started for every non-nil stream. The first stream
// whose RecvMsg returns (with a message or an error) claims the active slot
// via setThenGetActive. Only that stream's messages are forwarded to dst;
// messages arriving on the other (redundant) streams are received and
// dropped. Before the active stream's first message is forwarded, its header
// metadata is sent to dst. When f.metaKey is not NoMeta, that key is set in
// the header first.
//
// The returned channel has capacity 1. It receives the active stream's
// terminal error; io.EOF is the normal end of the client stream. If there is
// no non-nil stream, an error is logged and delivered immediately.
//
// Every receive goroutine closes its stream's ok2Close when it exits, which
// is what closeSend waits on.
//
// NOTE(review): f is shared by all receive goroutines. A redundant stream can
// therefore receive into it while the active stream is forwarding it —
// confirm this is acceptable for sbFrame.
func (s *streams) forwardClientToServer(dst grpc.ServerStream, f *sbFrame) chan error {
	fc2s := func(srcS *stream) {
		// Release closeSend however this loop exits. Closing only on the
		// RecvMsg-error path left closeSend blocked forever after a header or
		// SendMsg failure.
		defer close(srcS.ok2Close)

		// A redundant stream reports nil exactly once so that its result
		// watcher (below) can exit. Nothing reads further values, so repeated
		// sends would fill the channel and wedge this goroutine before it
		// reached the deferred close.
		reportedInactive := false
		reportInactive := func() {
			if !reportedInactive {
				reportedInactive = true
				srcS.c2sReturn <- nil
			}
		}

		for i := 0; ; i++ {
			if err := srcS.stream.RecvMsg(f); err != nil {
				if s.setThenGetActive(srcS) == srcS {
					srcS.c2sReturn <- err // this can be io.EOF which is the success case
				} else {
					reportInactive() // Inactive responder
				}
				return
			}
			if s.setThenGetActive(srcS) != srcS {
				reportInactive()
				continue
			}
			if i == 0 {
				// Client-to-server headers are only readable after the first
				// client message is received, but must be written to the
				// server stream before the first message is flushed. This is
				// the one place where both conditions hold.
				md, err := srcS.stream.Header()
				if err != nil {
					srcS.c2sReturn <- err
					return
				}
				// Update the metadata for the response.
				if f.metaKey != NoMeta {
					if f.metaVal == "" {
						// We could also always just do this
						md.Set(f.metaKey, f.backend.name)
					} else {
						md.Set(f.metaKey, f.metaVal)
					}
				}
				if err := dst.SendHeader(md); err != nil {
					srcS.c2sReturn <- err
					return
				}
			}
			log.Debugf("Northbound frame %s", hex.EncodeToString(f.payload))
			if err := dst.SendMsg(f); err != nil {
				srcS.c2sReturn <- err
				return
			}
		}
	}

	// There should be at least one open stream at this point. If there isn't,
	// it's a grave error in the code. Report it instead of letting the caller
	// block forever on the returned channel.
	ret := make(chan error, 1)
	agg := make(chan *stream)
	atLeastOne := false
	for _, strm := range s.streams {
		if strm == nil {
			continue
		}
		go fc2s(strm)
		go func(st *stream) { // Wait on result and aggregate
			r := <-st.c2sReturn // got the return code
			if r == nil {
				return // We're the redundant stream, just die
			}
			st.c2sReturn <- r // put it back to pass it along
			agg <- st         // send the stream to the aggregator
		}(strm)
		atLeastOne = true
	}
	if !atLeastOne {
		err := errors.New("There are no open streams. Unable to forward message.")
		log.Error(err)
		ret <- err
		return ret
	}
	go func() { // Wait on aggregated result
		active := <-agg
		ret <- <-active.c2sReturn
	}()
	return ret
}
// sendAll sends frame f on every non-nil stream in sortedStreams, in order.
//
// A failed SendMsg is logged and recorded in that stream's s2cReturn. That
// field is never cleared here, so a stream that failed earlier still counts
// as failed. The call returns nil if any non-nil stream has a nil s2cReturn;
// otherwise it returns the s2cReturn of the last non-nil stream. If
// sortedStreams holds no non-nil stream, an error is logged and returned.
func (s *streams) sendAll(f *nbFrame) error {
	sent := 0
	for _, target := range s.sortedStreams {
		if target == nil {
			log.Debugf("Nil stream")
			continue
		}
		sent++
		if err := target.stream.SendMsg(f); err != nil {
			log.Debugf("Error on SendMsg: %s", err.Error())
			target.s2cReturn = err
		}
	}

	if sent == 0 {
		noStreams := errors.New("There are no open streams, this should never happen")
		log.Error(noStreams)
		return noStreams
	}

	// A single stream without a recorded error means overall success.
	var lastErr error
	for _, target := range s.sortedStreams {
		if target == nil {
			continue
		}
		if target.s2cReturn == nil {
			return nil
		}
		lastErr = target.s2cReturn
	}
	return lastErr
}
// forwardServerToClient starts a goroutine that pumps northbound messages from
// src to every southbound stream via sendAll.
//
// f must already hold the result of a first src.RecvMsg, so each iteration
// sends the buffered frame before reading the next one. The returned channel
// (capacity 1) receives the first error from either sendAll or src.RecvMsg.
// io.EOF from RecvMsg is the normal end of the stream.
func (s *streams) forwardServerToClient(src grpc.ServerStream, f *nbFrame) chan error {
	result := make(chan error, 1)
	go func() {
		for {
			if sendErr := s.sendAll(f); sendErr != nil {
				result <- sendErr
				log.Debugf("SendAll failed %s", sendErr.Error())
				return
			}
			log.Debugf("Southbound frame %s", hex.EncodeToString(f.payload))
			if recvErr := src.RecvMsg(f); recvErr != nil {
				result <- recvErr
				return
			}
		}
	}()
	return result
}
// sortStreams rebuilds sortedStreams from the streams map, ordered by
// ascending map key, so that sendAll visits streams in a deterministic order.
//
// The slice is rebuilt from scratch on each call. The previous version
// appended to the existing slice, so calling it twice duplicated every
// stream.
func (s *streams) sortStreams() {
	keys := make([]string, 0, len(s.streams))
	for k := range s.streams {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	sorted := make([]*stream, 0, len(keys))
	for _, k := range keys {
		sorted = append(sorted, s.streams[k])
	}
	s.sortedStreams = sorted
}