[VOL-3999] Correctly handling context during OLT Reconcile

Change-Id: I22b8cca74eba3574adee4ed9dae48808ca9af889
diff --git a/internal/bbr/devices/olt.go b/internal/bbr/devices/olt.go
index b65a174..b40bf58 100644
--- a/internal/bbr/devices/olt.go
+++ b/internal/bbr/devices/olt.go
@@ -48,6 +48,18 @@
 	CompletedOnus int // Number of ONUs that have received a DHCPAck
 }
 
+type MockStream struct {
+	grpc.ServerStream
+}
+
+func (*MockStream) Send(ind *openolt.Indication) error {
+	return nil
+}
+
+func (*MockStream) Context() context.Context {
+	return context.Background()
+}
+
 // trigger an enable call and start the same listeners on the gRPC stream that VOLTHA would create
 // this method is blocking
 func (o *OltMock) Start() {
@@ -211,7 +223,11 @@
 	}
 
 	ctx, cancel := context.WithCancel(context.TODO())
-	go onu.ProcessOnuMessages(ctx, nil, client)
+	// NOTE we need to create a fake stream for ProcessOnuMessages
+	// as it listen on the context to cancel the loop
+	// In the BBR case it's not used for anything else
+	mockStream := MockStream{}
+	go onu.ProcessOnuMessages(ctx, &mockStream, client)
 
 	go func() {
 
diff --git a/internal/bbsim/devices/olt.go b/internal/bbsim/devices/olt.go
index 6ad068d..1c4e099 100644
--- a/internal/bbsim/devices/olt.go
+++ b/internal/bbsim/devices/olt.go
@@ -413,19 +413,20 @@
 	// new ones
 	o.Lock()
 	if o.enableContext != nil && o.enableContextCancel != nil {
-		oltLogger.Info("This is an OLT reboot")
+		oltLogger.Info("This is an OLT reboot or a reconcile")
 		o.enableContextCancel()
 		rebootFlag = true
+		time.Sleep(1 * time.Second)
 	}
 	o.enableContext, o.enableContextCancel = context.WithCancel(context.TODO())
 	o.Unlock()
 
 	wg := sync.WaitGroup{}
-	wg.Add(3)
 
 	o.OpenoltStream = stream
 
 	// create Go routine to process all OLT events
+	wg.Add(1)
 	go o.processOltMessages(o.enableContext, stream, &wg)
 
 	// enable the OLT
@@ -461,6 +462,15 @@
 				}
 				o.channel <- msg
 			}
+			// when the enableContext was canceled the ONUs stopped listening on the channel
+			for _, onu := range pon.Onus {
+				go onu.ProcessOnuMessages(o.enableContext, stream, nil)
+
+				// update the stream on all the services
+				for _, service := range onu.Services {
+					service.UpdateStream(stream)
+				}
+			}
 		}
 	} else {
 
@@ -482,18 +492,22 @@
 		}
 	}
 
-	oltLogger.Debug("Enable OLT Done")
-
 	if !o.enablePerf {
 		// Start a go routine to send periodic port stats to openolt adapter
-		go o.periodicPortStats(o.enableContext)
+		wg.Add(1)
+		go o.periodicPortStats(o.enableContext, &wg, stream)
 	}
 
 	wg.Wait()
+	oltLogger.WithFields(log.Fields{
+		"stream": stream,
+	}).Debug("OpenOLT Stream closed")
 }
 
-func (o *OltDevice) periodicPortStats(ctx context.Context) {
+func (o *OltDevice) periodicPortStats(ctx context.Context, wg *sync.WaitGroup, stream openolt.Openolt_EnableIndicationServer) {
 	var portStats *openolt.PortStatistics
+
+loop:
 	for {
 		select {
 		case <-time.After(time.Duration(o.PortStatsInterval) * time.Second):
@@ -504,7 +518,7 @@
 					incrementStat = false
 				}
 				portStats, port.PacketCount = getPortStats(port.PacketCount, incrementStat)
-				o.sendPortStatsIndication(portStats, port.ID, port.Type)
+				o.sendPortStatsIndication(portStats, port.ID, port.Type, stream)
 			}
 
 			// send PON port stats
@@ -515,14 +529,14 @@
 					incrementStat = false
 				}
 				portStats, port.PacketCount = getPortStats(port.PacketCount, incrementStat)
-				o.sendPortStatsIndication(portStats, port.ID, port.Type)
+				o.sendPortStatsIndication(portStats, port.ID, port.Type, stream)
 			}
 		case <-ctx.Done():
-			log.Debug("Stop sending port stats")
-			return
+			oltLogger.Debug("Stop sending port stats")
+			break loop
 		}
-
 	}
+	wg.Done()
 }
 
 // Helpers method
@@ -695,7 +709,7 @@
 	}).Debug("Sent Indication_IntfOperInd for PON")
 }
 
-func (o *OltDevice) sendPortStatsIndication(stats *openolt.PortStatistics, portID uint32, portType string) {
+func (o *OltDevice) sendPortStatsIndication(stats *openolt.PortStatistics, portID uint32, portType string, stream openolt.Openolt_EnableIndicationServer) {
 	if o.InternalState.Current() == "enabled" {
 		oltLogger.WithFields(log.Fields{
 			"Type":   portType,
@@ -705,7 +719,7 @@
 		data := &openolt.Indication_PortStats{
 			PortStats: stats,
 		}
-		stream := o.OpenoltStream
+
 		if err := stream.Send(&openolt.Indication{Data: data}); err != nil {
 			oltLogger.Errorf("Failed to send PortStats: %v", err)
 			return
@@ -714,8 +728,10 @@
 }
 
 // processOltMessages handles messages received over the OpenOLT interface
-func (o *OltDevice) processOltMessages(ctx context.Context, stream openolt.Openolt_EnableIndicationServer, wg *sync.WaitGroup) {
-	oltLogger.Debug("Starting OLT Indication Channel")
+func (o *OltDevice) processOltMessages(ctx context.Context, stream types.Stream, wg *sync.WaitGroup) {
+	oltLogger.WithFields(log.Fields{
+		"stream": stream,
+	}).Debug("Starting OLT Indication Channel")
 	ch := o.channel
 
 loop:
@@ -724,9 +740,15 @@
 		case <-ctx.Done():
 			oltLogger.Debug("OLT Indication processing canceled via context")
 			break loop
+		case <-stream.Context().Done():
+			oltLogger.Debug("OLT Indication processing canceled via stream context")
+			break loop
 		case message, ok := <-ch:
-			if !ok || ctx.Err() != nil {
-				oltLogger.Debug("OLT Indication processing canceled via closed channel")
+			if !ok {
+				if ctx.Err() != nil {
+					oltLogger.WithField("err", ctx.Err()).Error("OLT EnableContext error")
+				}
+				oltLogger.Warn("OLT Indication processing canceled via closed channel")
 				break loop
 			}
 
@@ -788,7 +810,9 @@
 		}
 	}
 	wg.Done()
-	oltLogger.Warn("Stopped handling OLT Indication Channel")
+	oltLogger.WithFields(log.Fields{
+		"stream": stream,
+	}).Warn("Stopped handling OLT Indication Channel")
 }
 
 // returns an ONU with a given Serial Number
diff --git a/internal/bbsim/devices/onu.go b/internal/bbsim/devices/onu.go
index f9e2e32..a8d67d6 100644
--- a/internal/bbsim/devices/onu.go
+++ b/internal/bbsim/devices/onu.go
@@ -329,6 +329,7 @@
 		"onuID":   o.ID,
 		"onuSN":   o.Sn(),
 		"ponPort": o.PonPortID,
+		"stream":  stream,
 	}).Debug("Starting ONU Indication Channel")
 
 loop:
@@ -338,14 +339,20 @@
 			onuLogger.WithFields(log.Fields{
 				"onuID": o.ID,
 				"onuSN": o.Sn(),
-			}).Tracef("ONU message handling canceled via context")
+			}).Debug("ONU message handling canceled via context")
+			break loop
+		case <-stream.Context().Done():
+			onuLogger.WithFields(log.Fields{
+				"onuID": o.ID,
+				"onuSN": o.Sn(),
+			}).Debug("ONU message handling canceled via stream context")
 			break loop
 		case message, ok := <-o.Channel:
 			if !ok || ctx.Err() != nil {
 				onuLogger.WithFields(log.Fields{
 					"onuID": o.ID,
 					"onuSN": o.Sn(),
-				}).Tracef("ONU message handling canceled via channel close")
+				}).Debug("ONU message handling canceled via channel close")
 				break loop
 			}
 			onuLogger.WithFields(log.Fields{
@@ -477,8 +484,9 @@
 		}
 	}
 	onuLogger.WithFields(log.Fields{
-		"onuID": o.ID,
-		"onuSN": o.Sn(),
+		"onuID":  o.ID,
+		"onuSN":  o.Sn(),
+		"stream": stream,
 	}).Debug("Stopped handling ONU Indication Channel")
 }
 
diff --git a/internal/bbsim/devices/onu_indications_test.go b/internal/bbsim/devices/onu_indications_test.go
index 04af678..5cc3533 100644
--- a/internal/bbsim/devices/onu_indications_test.go
+++ b/internal/bbsim/devices/onu_indications_test.go
@@ -47,6 +47,10 @@
 	return nil
 }
 
+func (s *mockStream) Context() context.Context {
+	return context.Background()
+}
+
 // test that we're sending a Discovery indication to VOLTHA
 func Test_Onu_DiscoverIndication_send_on_discovery(t *testing.T) {
 	onu := createTestOnu()
diff --git a/internal/bbsim/devices/service_test.go b/internal/bbsim/devices/service_test.go
index a71a50c..ec081db 100644
--- a/internal/bbsim/devices/service_test.go
+++ b/internal/bbsim/devices/service_test.go
@@ -46,8 +46,9 @@
 	s.HandlePacketsCallCount = s.HandlePacketsCallCount + 1
 }
 
-func (s *mockService) Initialize(stream types.Stream) {}
-func (s *mockService) Disable()                       {}
+func (s *mockService) Initialize(stream types.Stream)   {}
+func (s *mockService) UpdateStream(stream types.Stream) {}
+func (s *mockService) Disable()                         {}
 
 func createTestService(needsEapol bool, needsDchp bool) (*Service, error) {
 
diff --git a/internal/bbsim/devices/services.go b/internal/bbsim/devices/services.go
index 4391414..8c4ddfd 100644
--- a/internal/bbsim/devices/services.go
+++ b/internal/bbsim/devices/services.go
@@ -45,6 +45,7 @@
 	HandleDhcp(pbit uint8, cTag int) // Sends the DHCPDiscover packet
 
 	Initialize(stream bbsimTypes.Stream)
+	UpdateStream(stream bbsimTypes.Stream)
 	Disable()
 }
 
@@ -115,7 +116,7 @@
 					serviceLogger.Fatal("initialize invoke with wrong arguments")
 				}
 
-				service.Stream = stream
+				service.UpdateStream(stream)
 
 				service.PacketCh = make(chan bbsimTypes.OnuPacketMessage)
 				service.Channel = make(chan bbsimTypes.Message)
@@ -290,6 +291,10 @@
 	return &service, nil
 }
 
+func (s *Service) UpdateStream(stream bbsimTypes.Stream) {
+	s.Stream = stream
+}
+
 // HandleAuth is used to start EAPOL for a particular Service when the corresponding flow is received
 func (s *Service) HandleAuth() {
 
diff --git a/internal/bbsim/types/interfaces.go b/internal/bbsim/types/interfaces.go
index 380af34..946d507 100644
--- a/internal/bbsim/types/interfaces.go
+++ b/internal/bbsim/types/interfaces.go
@@ -18,9 +18,11 @@
 
 import (
 	"github.com/opencord/voltha-protos/v4/go/openolt"
+	"google.golang.org/grpc"
 )
 
 // represent a gRPC stream
 type Stream interface {
 	Send(*openolt.Indication) error
+	grpc.ServerStream
 }