SEBA-901 - handle adapter restart

Stops the existing message-processing goroutines when the
adapter reconnects, and does not attempt to rediscover ONUs
that are not in the initialized state.

Change-Id: Ie3951d6ad36b7c8b3a4ddfbf55850b8ed7cf35d8
diff --git a/internal/bbr/devices/olt.go b/internal/bbr/devices/olt.go
index 61af90a..7104880 100644
--- a/internal/bbr/devices/olt.go
+++ b/internal/bbr/devices/olt.go
@@ -214,7 +214,9 @@
 		}).Fatal("Cannot find ONU")
 	}
 
-	go onu.ProcessOnuMessages(nil, client)
+	ctx, cancel := context.WithCancel(context.TODO())
+	go onu.ProcessOnuMessages(ctx, nil, client)
+	defer cancel()
 
 	go func() {
 
diff --git a/internal/bbsim/devices/olt.go b/internal/bbsim/devices/olt.go
index 5d6e571..cc4965c 100644
--- a/internal/bbsim/devices/olt.go
+++ b/internal/bbsim/devices/olt.go
@@ -43,6 +43,8 @@
 })
 
 type OltDevice struct {
+	sync.Mutex
+
 	// BBSIM Internals
 	ID              int
 	SerialNumber    string
@@ -60,6 +62,9 @@
 
 	// OLT Attributes
 	OperState *fsm.FSM
+
+	enableContext       context.Context
+	enableContextCancel context.CancelFunc
 }
 
 var olt OltDevice
@@ -270,12 +275,23 @@
 func (o *OltDevice) Enable(stream openolt.Openolt_EnableIndicationServer) error {
 	oltLogger.Debug("Enable OLT called")
 
+	// If enabled has already been called then an enabled context has
+	// been created. If this is the case then we want to cancel all the
+	// processing loops associated with that enable before we recreate
+	// new ones
+	o.Lock()
+	if o.enableContext != nil && o.enableContextCancel != nil {
+		o.enableContextCancel()
+	}
+	o.enableContext, o.enableContextCancel = context.WithCancel(context.TODO())
+	o.Unlock()
+
 	wg := sync.WaitGroup{}
 	wg.Add(3)
 
 	// create Go routine to process all OLT events
-	go o.processOltMessages(stream, &wg)
-	go o.processNniPacketIns(stream, &wg)
+	go o.processOltMessages(o.enableContext, stream, &wg)
+	go o.processNniPacketIns(o.enableContext, stream, &wg)
 
 	// enable the OLT
 	oltMsg := Message{
@@ -298,7 +314,7 @@
 		o.channel <- msg
 	}
 
-	go o.processOmciMessages()
+	go o.processOmciMessages(o.enableContext, &wg)
 
 	// send PON Port indications
 	for i, pon := range o.Pons {
@@ -312,7 +328,10 @@
 		o.channel <- msg
 
 		for _, onu := range o.Pons[i].Onus {
-			go onu.ProcessOnuMessages(stream, nil)
+			go onu.ProcessOnuMessages(o.enableContext, stream, nil)
+			if onu.InternalState.Current() != "initialized" {
+				continue
+			}
 			if err := onu.InternalState.Event("discover"); err != nil {
 				log.Errorf("Error discover ONU: %v", err)
 				return err
@@ -325,20 +344,34 @@
 	return nil
 }
 
-func (o *OltDevice) processOmciMessages() {
+func (o *OltDevice) processOmciMessages(ctx context.Context, wg *sync.WaitGroup) {
 	ch := omcisim.GetChannel()
 
 	oltLogger.Debug("Starting OMCI Indication Channel")
 
-	for message := range ch {
-		onuId := message.Data.OnuId
-		intfId := message.Data.IntfId
-		onu, err := o.FindOnuById(intfId, onuId)
-		if err != nil {
-			oltLogger.Errorf("Failed to find onu: %v", err)
+loop:
+	for {
+		select {
+		case <-ctx.Done():
+			oltLogger.Debug("OMCI processing canceled via context")
+			break loop
+		case message, ok := <-ch:
+			if !ok || ctx.Err() != nil {
+				oltLogger.Debug("OMCI processing canceled via channel close")
+				break loop
+			}
+			onuId := message.Data.OnuId
+			intfId := message.Data.IntfId
+			onu, err := o.FindOnuById(intfId, onuId)
+			if err != nil {
+				oltLogger.Errorf("Failed to find onu: %v", err)
+				continue
+			}
+			go onu.processOmciMessage(message)
 		}
-		go onu.processOmciMessage(message)
 	}
+
+	wg.Done()
 }
 
 // Helpers method
@@ -432,94 +465,120 @@
 }
 
 // processOltMessages handles messages received over the OpenOLT interface
-func (o *OltDevice) processOltMessages(stream openolt.Openolt_EnableIndicationServer, wg *sync.WaitGroup) {
+func (o *OltDevice) processOltMessages(ctx context.Context, stream openolt.Openolt_EnableIndicationServer, wg *sync.WaitGroup) {
 	oltLogger.Debug("Starting OLT Indication Channel")
-	for message := range o.channel {
+	ch := o.channel
 
-		oltLogger.WithFields(log.Fields{
-			"oltId":       o.ID,
-			"messageType": message.Type,
-		}).Trace("Received message")
-
-		switch message.Type {
-		case OltIndication:
-			msg, _ := message.Data.(OltIndicationMessage)
-			if msg.OperState == UP {
-				o.InternalState.Event("enable")
-				o.OperState.Event("enable")
-			} else if msg.OperState == DOWN {
-				o.InternalState.Event("disable")
-				o.OperState.Event("disable")
+loop:
+	for {
+		select {
+		case <-ctx.Done():
+			oltLogger.Debug("OLT Indication processing canceled via context")
+			break loop
+		case message, ok := <-ch:
+			if !ok || ctx.Err() != nil {
+				oltLogger.Debug("OLT Indication processing canceled via closed channel")
+				break loop
 			}
-			o.sendOltIndication(msg, stream)
-		case NniIndication:
-			msg, _ := message.Data.(NniIndicationMessage)
-			o.sendNniIndication(msg, stream)
-		case PonIndication:
-			msg, _ := message.Data.(PonIndicationMessage)
-			o.sendPonIndication(msg, stream)
-		default:
-			oltLogger.Warnf("Received unknown message data %v for type %v in OLT Channel", message.Data, message.Type)
-		}
 
+			oltLogger.WithFields(log.Fields{
+				"oltId":       o.ID,
+				"messageType": message.Type,
+			}).Trace("Received message")
+
+			switch message.Type {
+			case OltIndication:
+				msg, _ := message.Data.(OltIndicationMessage)
+				if msg.OperState == UP {
+					o.InternalState.Event("enable")
+					o.OperState.Event("enable")
+				} else if msg.OperState == DOWN {
+					o.InternalState.Event("disable")
+					o.OperState.Event("disable")
+				}
+				o.sendOltIndication(msg, stream)
+			case NniIndication:
+				msg, _ := message.Data.(NniIndicationMessage)
+				o.sendNniIndication(msg, stream)
+			case PonIndication:
+				msg, _ := message.Data.(PonIndicationMessage)
+				o.sendPonIndication(msg, stream)
+			default:
+				oltLogger.Warnf("Received unknown message data %v for type %v in OLT Channel", message.Data, message.Type)
+			}
+		}
 	}
 	wg.Done()
 	oltLogger.Warn("Stopped handling OLT Indication Channel")
 }
 
 // processNniPacketIns handles messages received over the NNI interface
-func (o *OltDevice) processNniPacketIns(stream openolt.Openolt_EnableIndicationServer, wg *sync.WaitGroup) {
+func (o *OltDevice) processNniPacketIns(ctx context.Context, stream openolt.Openolt_EnableIndicationServer, wg *sync.WaitGroup) {
 	oltLogger.WithFields(log.Fields{
 		"nniChannel": o.nniPktInChannel,
 	}).Debug("Started NNI Channel")
 	nniId := o.Nnis[0].ID // FIXME we are assuming we have only one NNI
-	for message := range o.nniPktInChannel {
-		oltLogger.Tracef("Received packets on NNI Channel")
 
-		onuMac, err := packetHandlers.GetDstMacAddressFromPacket(message.Pkt)
+	ch := o.nniPktInChannel
 
-		if err != nil {
-			log.WithFields(log.Fields{
-				"IntfType": "nni",
-				"IntfId":   nniId,
-				"Pkt":      message.Pkt.Data(),
-			}).Error("Can't find Dst MacAddress in packet")
-			return
-		}
+loop:
+	for {
+		select {
+		case <-ctx.Done():
+			oltLogger.Debug("NNI Indication processing canceled via context")
+			break loop
+		case message, ok := <-ch:
+			if !ok || ctx.Err() != nil {
+				oltLogger.Debug("NNI Indication processing canceled via channel closed")
+				break loop
+			}
+			oltLogger.Tracef("Received packets on NNI Channel")
 
-		onu, err := o.FindOnuByMacAddress(onuMac)
-		if err != nil {
-			log.WithFields(log.Fields{
-				"IntfType":   "nni",
-				"IntfId":     nniId,
-				"Pkt":        message.Pkt.Data(),
-				"MacAddress": onuMac.String(),
-			}).Error("Can't find ONU with MacAddress")
-			return
-		}
+			onuMac, err := packetHandlers.GetDstMacAddressFromPacket(message.Pkt)
 
-		doubleTaggedPkt, err := packetHandlers.PushDoubleTag(onu.STag, onu.CTag, message.Pkt)
-		if err != nil {
-			log.Error("Fail to add double tag to packet")
-		}
+			if err != nil {
+				log.WithFields(log.Fields{
+					"IntfType": "nni",
+					"IntfId":   nniId,
+					"Pkt":      message.Pkt.Data(),
+				}).Error("Can't find Dst MacAddress in packet")
+				return
+			}
 
-		data := &openolt.Indication_PktInd{PktInd: &openolt.PacketIndication{
-			IntfType: "nni",
-			IntfId:   nniId,
-			Pkt:      doubleTaggedPkt.Data()}}
-		if err := stream.Send(&openolt.Indication{Data: data}); err != nil {
+			onu, err := o.FindOnuByMacAddress(onuMac)
+			if err != nil {
+				log.WithFields(log.Fields{
+					"IntfType":   "nni",
+					"IntfId":     nniId,
+					"Pkt":        message.Pkt.Data(),
+					"MacAddress": onuMac.String(),
+				}).Error("Can't find ONU with MacAddress")
+				return
+			}
+
+			doubleTaggedPkt, err := packetHandlers.PushDoubleTag(onu.STag, onu.CTag, message.Pkt)
+			if err != nil {
+				log.Error("Fail to add double tag to packet")
+			}
+
+			data := &openolt.Indication_PktInd{PktInd: &openolt.PacketIndication{
+				IntfType: "nni",
+				IntfId:   nniId,
+				Pkt:      doubleTaggedPkt.Data()}}
+			if err := stream.Send(&openolt.Indication{Data: data}); err != nil {
+				oltLogger.WithFields(log.Fields{
+					"IntfType": data.PktInd.IntfType,
+					"IntfId":   nniId,
+					"Pkt":      doubleTaggedPkt.Data(),
+				}).Errorf("Fail to send PktInd indication: %v", err)
+			}
 			oltLogger.WithFields(log.Fields{
 				"IntfType": data.PktInd.IntfType,
 				"IntfId":   nniId,
 				"Pkt":      doubleTaggedPkt.Data(),
-			}).Errorf("Fail to send PktInd indication: %v", err)
+				"OnuSn":    onu.Sn(),
+			}).Tracef("Sent PktInd indication")
 		}
-		oltLogger.WithFields(log.Fields{
-			"IntfType": data.PktInd.IntfType,
-			"IntfId":   nniId,
-			"Pkt":      doubleTaggedPkt.Data(),
-			"OnuSn":    onu.Sn(),
-		}).Tracef("Sent PktInd indication")
 	}
 	wg.Done()
 	oltLogger.WithFields(log.Fields{
diff --git a/internal/bbsim/devices/onu.go b/internal/bbsim/devices/onu.go
index 4fd18cf..9ab83fa 100644
--- a/internal/bbsim/devices/onu.go
+++ b/internal/bbsim/devices/onu.go
@@ -246,88 +246,105 @@
 }
 
 // ProcessOnuMessages starts indication channel for each ONU
-func (o *Onu) ProcessOnuMessages(stream openolt.Openolt_EnableIndicationServer, client openolt.OpenoltClient) {
+func (o *Onu) ProcessOnuMessages(ctx context.Context, stream openolt.Openolt_EnableIndicationServer, client openolt.OpenoltClient) {
 	onuLogger.WithFields(log.Fields{
 		"onuID":   o.ID,
 		"onuSN":   o.Sn(),
 		"ponPort": o.PonPortID,
 	}).Debug("Starting ONU Indication Channel")
 
-	for message := range o.Channel {
-		onuLogger.WithFields(log.Fields{
-			"onuID":       o.ID,
-			"onuSN":       o.Sn(),
-			"messageType": message.Type,
-		}).Tracef("Received message on ONU Channel")
-
-		switch message.Type {
-		case OnuDiscIndication:
-			msg, _ := message.Data.(OnuDiscIndicationMessage)
-			// NOTE we need to slow down and send ONU Discovery Indication in batches to better emulate a real scenario
-			time.Sleep(time.Duration(int(o.ID)*o.PonPort.Olt.Delay) * time.Millisecond)
-			o.sendOnuDiscIndication(msg, stream)
-		case OnuIndication:
-			msg, _ := message.Data.(OnuIndicationMessage)
-			o.sendOnuIndication(msg, stream)
-		case OMCI:
-			msg, _ := message.Data.(OmciMessage)
-			o.handleOmciMessage(msg, stream)
-		case FlowUpdate:
-			msg, _ := message.Data.(OnuFlowUpdateMessage)
-			o.handleFlowUpdate(msg)
-		case StartEAPOL:
-			log.Infof("Receive StartEAPOL message on ONU Channel")
-			eapol.SendEapStart(o.ID, o.PonPortID, o.Sn(), o.PortNo, o.HwAddress, o.InternalState, stream)
-		case StartDHCP:
-			log.Infof("Receive StartDHCP message on ONU Channel")
-			// FIXME use id, ponId as SendEapStart
-			dhcp.SendDHCPDiscovery(o.PonPortID, o.ID, o.Sn(), o.PortNo, o.InternalState, o.HwAddress, o.CTag, stream)
-		case OnuPacketOut:
-
-			msg, _ := message.Data.(OnuPacketMessage)
-
-			log.WithFields(log.Fields{
-				"IntfId":  msg.IntfId,
-				"OnuId":   msg.OnuId,
-				"pktType": msg.Type,
-			}).Trace("Received OnuPacketOut Message")
-
-			if msg.Type == packetHandlers.EAPOL {
-				eapol.HandleNextPacket(msg.OnuId, msg.IntfId, o.Sn(), o.PortNo, o.InternalState, msg.Packet, stream, client)
-			} else if msg.Type == packetHandlers.DHCP {
-				// NOTE here we receive packets going from the DHCP Server to the ONU
-				// for now we expect them to be double-tagged, but ideally the should be single tagged
-				dhcp.HandleNextPacket(o.ID, o.PonPortID, o.Sn(), o.PortNo, o.HwAddress, o.CTag, o.InternalState, msg.Packet, stream)
+loop:
+	for {
+		select {
+		case <-ctx.Done():
+			onuLogger.WithFields(log.Fields{
+				"onuID": o.ID,
+				"onuSN": o.Sn(),
+			}).Tracef("ONU message handling canceled via context")
+			break loop
+		case message, ok := <-o.Channel:
+			if !ok || ctx.Err() != nil {
+				onuLogger.WithFields(log.Fields{
+					"onuID": o.ID,
+					"onuSN": o.Sn(),
+				}).Tracef("ONU message handling canceled via channel close")
+				break loop
 			}
-		case OnuPacketIn:
-			// NOTE we only receive BBR packets here.
-			// Eapol.HandleNextPacket can handle both BBSim and BBr cases so the call is the same
-			// in the DHCP case VOLTHA only act as a proxy, the behaviour is completely different thus we have a dhcp.HandleNextBbrPacket
-			msg, _ := message.Data.(OnuPacketMessage)
+			onuLogger.WithFields(log.Fields{
+				"onuID":       o.ID,
+				"onuSN":       o.Sn(),
+				"messageType": message.Type,
+			}).Tracef("Received message on ONU Channel")
 
-			log.WithFields(log.Fields{
-				"IntfId":  msg.IntfId,
-				"OnuId":   msg.OnuId,
-				"pktType": msg.Type,
-			}).Trace("Received OnuPacketIn Message")
+			switch message.Type {
+			case OnuDiscIndication:
+				msg, _ := message.Data.(OnuDiscIndicationMessage)
+				// NOTE we need to slow down and send ONU Discovery Indication in batches to better emulate a real scenario
+				time.Sleep(time.Duration(int(o.ID)*o.PonPort.Olt.Delay) * time.Millisecond)
+				o.sendOnuDiscIndication(msg, stream)
+			case OnuIndication:
+				msg, _ := message.Data.(OnuIndicationMessage)
+				o.sendOnuIndication(msg, stream)
+			case OMCI:
+				msg, _ := message.Data.(OmciMessage)
+				o.handleOmciMessage(msg, stream)
+			case FlowUpdate:
+				msg, _ := message.Data.(OnuFlowUpdateMessage)
+				o.handleFlowUpdate(msg)
+			case StartEAPOL:
+				log.Infof("Receive StartEAPOL message on ONU Channel")
+				eapol.SendEapStart(o.ID, o.PonPortID, o.Sn(), o.PortNo, o.HwAddress, o.InternalState, stream)
+			case StartDHCP:
+				log.Infof("Receive StartDHCP message on ONU Channel")
+				// FIXME use id, ponId as SendEapStart
+				dhcp.SendDHCPDiscovery(o.PonPortID, o.ID, o.Sn(), o.PortNo, o.InternalState, o.HwAddress, o.CTag, stream)
+			case OnuPacketOut:
 
-			if msg.Type == packetHandlers.EAPOL {
-				eapol.HandleNextPacket(msg.OnuId, msg.IntfId, o.Sn(), o.PortNo, o.InternalState, msg.Packet, stream, client)
-			} else if msg.Type == packetHandlers.DHCP {
-				dhcp.HandleNextBbrPacket(o.ID, o.PonPortID, o.Sn(), o.STag, o.HwAddress, o.DoneChannel, msg.Packet, client)
+				msg, _ := message.Data.(OnuPacketMessage)
+
+				log.WithFields(log.Fields{
+					"IntfId":  msg.IntfId,
+					"OnuId":   msg.OnuId,
+					"pktType": msg.Type,
+				}).Trace("Received OnuPacketOut Message")
+
+				if msg.Type == packetHandlers.EAPOL {
+					eapol.HandleNextPacket(msg.OnuId, msg.IntfId, o.Sn(), o.PortNo, o.InternalState, msg.Packet, stream, client)
+				} else if msg.Type == packetHandlers.DHCP {
+					// NOTE here we receive packets going from the DHCP Server to the ONU
+					// for now we expect them to be double-tagged, but ideally they should be single tagged
+					dhcp.HandleNextPacket(o.ID, o.PonPortID, o.Sn(), o.PortNo, o.HwAddress, o.CTag, o.InternalState, msg.Packet, stream)
+				}
+			case OnuPacketIn:
+				// NOTE we only receive BBR packets here.
+				// Eapol.HandleNextPacket can handle both BBSim and BBr cases so the call is the same
+				// in the DHCP case VOLTHA only act as a proxy, the behaviour is completely different thus we have a dhcp.HandleNextBbrPacket
+				msg, _ := message.Data.(OnuPacketMessage)
+
+				log.WithFields(log.Fields{
+					"IntfId":  msg.IntfId,
+					"OnuId":   msg.OnuId,
+					"pktType": msg.Type,
+				}).Trace("Received OnuPacketIn Message")
+
+				if msg.Type == packetHandlers.EAPOL {
+					eapol.HandleNextPacket(msg.OnuId, msg.IntfId, o.Sn(), o.PortNo, o.InternalState, msg.Packet, stream, client)
+				} else if msg.Type == packetHandlers.DHCP {
+					dhcp.HandleNextBbrPacket(o.ID, o.PonPortID, o.Sn(), o.STag, o.HwAddress, o.DoneChannel, msg.Packet, client)
+				}
+			case DyingGaspIndication:
+				msg, _ := message.Data.(DyingGaspIndicationMessage)
+				o.sendDyingGaspInd(msg, stream)
+			case OmciIndication:
+				msg, _ := message.Data.(OmciIndicationMessage)
+				o.handleOmci(msg, client)
+			case SendEapolFlow:
+				o.sendEapolFlow(client)
+			case SendDhcpFlow:
+				o.sendDhcpFlow(client)
+			default:
+				onuLogger.Warnf("Received unknown message data %v for type %v in OLT Channel", message.Data, message.Type)
 			}
-		case DyingGaspIndication:
-			msg, _ := message.Data.(DyingGaspIndicationMessage)
-			o.sendDyingGaspInd(msg, stream)
-		case OmciIndication:
-			msg, _ := message.Data.(OmciIndicationMessage)
-			o.handleOmci(msg, client)
-		case SendEapolFlow:
-			o.sendEapolFlow(client)
-		case SendDhcpFlow:
-			o.sendDhcpFlow(client)
-		default:
-			onuLogger.Warnf("Received unknown message data %v for type %v in OLT Channel", message.Data, message.Type)
 		}
 	}
 	onuLogger.WithFields(log.Fields{
diff --git a/internal/bbsim/devices/onu_indications_test.go b/internal/bbsim/devices/onu_indications_test.go
index d12c5a6..37008a2 100644
--- a/internal/bbsim/devices/onu_indications_test.go
+++ b/internal/bbsim/devices/onu_indications_test.go
@@ -17,6 +17,7 @@
 package devices
 
 import (
+	"context"
 	"errors"
 	"github.com/opencord/voltha-protos/v2/go/openolt"
 	"google.golang.org/grpc"
@@ -52,7 +53,8 @@
 		fail:      false,
 		channel:   make(chan int, 10),
 	}
-	go onu.ProcessOnuMessages(stream, nil)
+	ctx, cancel := context.WithCancel(context.TODO())
+	go onu.ProcessOnuMessages(ctx, stream, nil)
 	onu.InternalState.SetState("initialized")
 	onu.InternalState.Event("discover")
 
@@ -62,6 +64,7 @@
 		assert.Equal(t, stream.Calls[1].IntfId, onu.PonPortID)
 		assert.Equal(t, stream.Calls[1].SerialNumber, onu.SerialNumber)
 	}
+	cancel()
 }
 
 // test that if the discovery indication is not acknowledge we'll keep sending new ones
@@ -73,7 +76,8 @@
 		fail:      false,
 		channel:   make(chan int, 10),
 	}
-	go onu.ProcessOnuMessages(stream, nil)
+	ctx, cancel := context.WithCancel(context.TODO())
+	go onu.ProcessOnuMessages(ctx, stream, nil)
 	onu.InternalState.SetState("initialized")
 	onu.InternalState.Event("discover")
 
@@ -81,6 +85,7 @@
 	case <-time.After(400 * time.Millisecond):
 		assert.Equal(t, stream.CallCount, 4)
 	}
+	cancel()
 }
 
 // test that if the discovery indication is not acknowledge we'll send a new one
@@ -93,7 +98,8 @@
 		fail:      false,
 		channel:   make(chan int, 10),
 	}
-	go onu.ProcessOnuMessages(stream, nil)
+	ctx, cancel := context.WithCancel(context.TODO())
+	go onu.ProcessOnuMessages(ctx, stream, nil)
 	onu.InternalState.SetState("initialized")
 	onu.InternalState.Event("discover")
 
@@ -110,4 +116,5 @@
 
 		assert.Equal(t, stream.CallCount, 2)
 	}
+	cancel()
 }