SEBA-901 - handle adapter restart

Stops the existing message-processing goroutines when the adapter
reconnects (Enable is called again), and skips rediscovery of ONUs
that are no longer in the initialized state.

Change-Id: Ie3951d6ad36b7c8b3a4ddfbf55850b8ed7cf35d8
diff --git a/internal/bbsim/devices/olt.go b/internal/bbsim/devices/olt.go
index 5d6e571..cc4965c 100644
--- a/internal/bbsim/devices/olt.go
+++ b/internal/bbsim/devices/olt.go
@@ -43,6 +43,8 @@
 })
 
 type OltDevice struct {
+	sync.Mutex // held by Enable while it replaces enableContext/enableContextCancel
+
 	// BBSIM Internals
 	ID              int
 	SerialNumber    string
@@ -60,6 +62,9 @@
 
 	// OLT Attributes
 	OperState *fsm.FSM
+
+	enableContext       context.Context    // context handed to the processing goroutines started by the latest Enable call
+	enableContextCancel context.CancelFunc // cancels enableContext; invoked by the next Enable before it starts new goroutines
 }
 
 var olt OltDevice
@@ -270,12 +275,23 @@
 func (o *OltDevice) Enable(stream openolt.Openolt_EnableIndicationServer) error {
 	oltLogger.Debug("Enable OLT called")
 
+	// Cancel the processing loops started by any previous Enable call
+	// (e.g. before an adapter restart) and create a context for the new
+	// ones; ctx is captured under the lock so a concurrent Enable can't race.
+	o.Lock()
+	if o.enableContextCancel != nil {
+		o.enableContextCancel()
+	}
+	o.enableContext, o.enableContextCancel = context.WithCancel(context.Background())
+	ctx := o.enableContext
+	o.Unlock()
+
 	wg := sync.WaitGroup{}
 	wg.Add(3)
 
 	// create Go routine to process all OLT events
-	go o.processOltMessages(stream, &wg)
-	go o.processNniPacketIns(stream, &wg)
+	go o.processOltMessages(ctx, stream, &wg)
+	go o.processNniPacketIns(ctx, stream, &wg)
 
 	// enable the OLT
 	oltMsg := Message{
@@ -298,7 +314,7 @@
 		o.channel <- msg
 	}
 
-	go o.processOmciMessages()
+	go o.processOmciMessages(ctx, &wg)
 
 	// send PON Port indications
 	for i, pon := range o.Pons {
@@ -312,7 +328,10 @@
 		o.channel <- msg
 
 		for _, onu := range o.Pons[i].Onus {
-			go onu.ProcessOnuMessages(stream, nil)
+			go onu.ProcessOnuMessages(ctx, stream, nil)
+			if onu.InternalState.Current() != "initialized" { // already discovered before a reconnect: don't rediscover
+				continue
+			}
 			if err := onu.InternalState.Event("discover"); err != nil {
 				log.Errorf("Error discover ONU: %v", err)
 				return err
@@ -325,20 +344,34 @@
 	return nil
 }
 
-func (o *OltDevice) processOmciMessages() {
+func (o *OltDevice) processOmciMessages(ctx context.Context, wg *sync.WaitGroup) {
 	ch := omcisim.GetChannel()
 
 	oltLogger.Debug("Starting OMCI Indication Channel")
 
-	for message := range ch {
-		onuId := message.Data.OnuId
-		intfId := message.Data.IntfId
-		onu, err := o.FindOnuById(intfId, onuId)
-		if err != nil {
-			oltLogger.Errorf("Failed to find onu: %v", err)
+	defer wg.Done() // release Enable's WaitGroup on every exit path
+
+loop:
+	for {
+		select {
+		case <-ctx.Done():
+			oltLogger.Debug("OMCI processing canceled via context")
+			break loop
+		case message, ok := <-ch:
+			if !ok || ctx.Err() != nil { // closed channel, or a message that raced with cancellation
+				oltLogger.Debug("OMCI processing canceled via channel close")
+				break loop
+			}
+			onu, err := o.FindOnuById(message.Data.IntfId, message.Data.OnuId)
+			if err != nil {
+				oltLogger.Errorf("Failed to find onu: %v", err)
+				continue
+			}
+			go onu.processOmciMessage(message)
 		}
-		go onu.processOmciMessage(message)
 	}
+
+	oltLogger.Warn("Stopped handling OMCI Indication Channel")
 }
 
 // Helpers method
@@ -432,94 +465,120 @@
 }
 
 // processOltMessages handles messages received over the OpenOLT interface
-func (o *OltDevice) processOltMessages(stream openolt.Openolt_EnableIndicationServer, wg *sync.WaitGroup) {
+func (o *OltDevice) processOltMessages(ctx context.Context, stream openolt.Openolt_EnableIndicationServer, wg *sync.WaitGroup) {
 	oltLogger.Debug("Starting OLT Indication Channel")
-	for message := range o.channel {
+	ch := o.channel
 
-		oltLogger.WithFields(log.Fields{
-			"oltId":       o.ID,
-			"messageType": message.Type,
-		}).Trace("Received message")
-
-		switch message.Type {
-		case OltIndication:
-			msg, _ := message.Data.(OltIndicationMessage)
-			if msg.OperState == UP {
-				o.InternalState.Event("enable")
-				o.OperState.Event("enable")
-			} else if msg.OperState == DOWN {
-				o.InternalState.Event("disable")
-				o.OperState.Event("disable")
+loop:
+	for {
+		select {
+		case <-ctx.Done():
+			oltLogger.Debug("OLT Indication processing canceled via context")
+			break loop
+		case message, ok := <-ch:
+			if !ok || ctx.Err() != nil { // closed channel, or a message that raced with cancellation: drop it and stop
+				oltLogger.Debug("OLT Indication processing canceled via closed channel")
+				break loop
 			}
-			o.sendOltIndication(msg, stream)
-		case NniIndication:
-			msg, _ := message.Data.(NniIndicationMessage)
-			o.sendNniIndication(msg, stream)
-		case PonIndication:
-			msg, _ := message.Data.(PonIndicationMessage)
-			o.sendPonIndication(msg, stream)
-		default:
-			oltLogger.Warnf("Received unknown message data %v for type %v in OLT Channel", message.Data, message.Type)
-		}
 
+			oltLogger.WithFields(log.Fields{
+				"oltId":       o.ID,
+				"messageType": message.Type,
+			}).Trace("Received message")
+
+			switch message.Type {
+			case OltIndication:
+				msg, _ := message.Data.(OltIndicationMessage)
+				if msg.OperState == UP {
+					o.InternalState.Event("enable") // NOTE(review): Event's error return is discarded here and in the three calls below
+					o.OperState.Event("enable")
+				} else if msg.OperState == DOWN {
+					o.InternalState.Event("disable")
+					o.OperState.Event("disable")
+				}
+				o.sendOltIndication(msg, stream)
+			case NniIndication:
+				msg, _ := message.Data.(NniIndicationMessage)
+				o.sendNniIndication(msg, stream)
+			case PonIndication:
+				msg, _ := message.Data.(PonIndicationMessage)
+				o.sendPonIndication(msg, stream)
+			default:
+				oltLogger.Warnf("Received unknown message data %v for type %v in OLT Channel", message.Data, message.Type)
+			}
+		}
 	}
 	wg.Done()
 	oltLogger.Warn("Stopped handling OLT Indication Channel")
 }
 
 // processNniPacketIns handles messages received over the NNI interface
-func (o *OltDevice) processNniPacketIns(stream openolt.Openolt_EnableIndicationServer, wg *sync.WaitGroup) {
+func (o *OltDevice) processNniPacketIns(ctx context.Context, stream openolt.Openolt_EnableIndicationServer, wg *sync.WaitGroup) {
 	oltLogger.WithFields(log.Fields{
 		"nniChannel": o.nniPktInChannel,
 	}).Debug("Started NNI Channel")
 	nniId := o.Nnis[0].ID // FIXME we are assuming we have only one NNI
-	for message := range o.nniPktInChannel {
-		oltLogger.Tracef("Received packets on NNI Channel")
 
-		onuMac, err := packetHandlers.GetDstMacAddressFromPacket(message.Pkt)
+	ch := o.nniPktInChannel
 
-		if err != nil {
-			log.WithFields(log.Fields{
-				"IntfType": "nni",
-				"IntfId":   nniId,
-				"Pkt":      message.Pkt.Data(),
-			}).Error("Can't find Dst MacAddress in packet")
-			return
-		}
+loop:
+	for {
+		select {
+		case <-ctx.Done():
+			oltLogger.Debug("NNI Indication processing canceled via context")
+			break loop
+		case message, ok := <-ch:
+			if !ok || ctx.Err() != nil {
+				oltLogger.Debug("NNI Indication processing canceled via channel closed")
+				break loop
+			}
+			oltLogger.Tracef("Received packets on NNI Channel")
 
-		onu, err := o.FindOnuByMacAddress(onuMac)
-		if err != nil {
-			log.WithFields(log.Fields{
-				"IntfType":   "nni",
-				"IntfId":     nniId,
-				"Pkt":        message.Pkt.Data(),
-				"MacAddress": onuMac.String(),
-			}).Error("Can't find ONU with MacAddress")
-			return
-		}
+			onuMac, err := packetHandlers.GetDstMacAddressFromPacket(message.Pkt)
 
-		doubleTaggedPkt, err := packetHandlers.PushDoubleTag(onu.STag, onu.CTag, message.Pkt)
-		if err != nil {
-			log.Error("Fail to add double tag to packet")
-		}
+			if err != nil {
+				log.WithFields(log.Fields{
+					"IntfType": "nni",
+					"IntfId":   nniId,
+					"Pkt":      message.Pkt.Data(),
+				}).Error("Can't find Dst MacAddress in packet")
+				return
+			}
 
-		data := &openolt.Indication_PktInd{PktInd: &openolt.PacketIndication{
-			IntfType: "nni",
-			IntfId:   nniId,
-			Pkt:      doubleTaggedPkt.Data()}}
-		if err := stream.Send(&openolt.Indication{Data: data}); err != nil {
+			onu, err := o.FindOnuByMacAddress(onuMac)
+			if err != nil {
+				log.WithFields(log.Fields{
+					"IntfType":   "nni",
+					"IntfId":     nniId,
+					"Pkt":        message.Pkt.Data(),
+					"MacAddress": onuMac.String(),
+				}).Error("Can't find ONU with MacAddress")
+				return
+			}
+
+			doubleTaggedPkt, err := packetHandlers.PushDoubleTag(onu.STag, onu.CTag, message.Pkt)
+			if err != nil {
+				log.Error("Fail to add double tag to packet")
+			}
+
+			data := &openolt.Indication_PktInd{PktInd: &openolt.PacketIndication{
+				IntfType: "nni",
+				IntfId:   nniId,
+				Pkt:      doubleTaggedPkt.Data()}}
+			if err := stream.Send(&openolt.Indication{Data: data}); err != nil {
+				oltLogger.WithFields(log.Fields{
+					"IntfType": data.PktInd.IntfType,
+					"IntfId":   nniId,
+					"Pkt":      doubleTaggedPkt.Data(),
+				}).Errorf("Fail to send PktInd indication: %v", err)
+			}
 			oltLogger.WithFields(log.Fields{
 				"IntfType": data.PktInd.IntfType,
 				"IntfId":   nniId,
 				"Pkt":      doubleTaggedPkt.Data(),
-			}).Errorf("Fail to send PktInd indication: %v", err)
+				"OnuSn":    onu.Sn(),
+			}).Tracef("Sent PktInd indication")
 		}
-		oltLogger.WithFields(log.Fields{
-			"IntfType": data.PktInd.IntfType,
-			"IntfId":   nniId,
-			"Pkt":      doubleTaggedPkt.Data(),
-			"OnuSn":    onu.Sn(),
-		}).Tracef("Sent PktInd indication")
 	}
 	wg.Done()
 	oltLogger.WithFields(log.Fields{