[SEBA-873] Add OLT reboot support

Change-Id: I1570d05313661a6d66e1596b9f9a1a1cc17d1a73
diff --git a/internal/bbsim/devices/olt.go b/internal/bbsim/devices/olt.go
index e810ce1..3264bc3 100644
--- a/internal/bbsim/devices/olt.go
+++ b/internal/bbsim/devices/olt.go
@@ -22,6 +22,7 @@
 	"fmt"
 	"net"
 	"sync"
+	"time"
 
 	"github.com/google/gopacket"
 	"github.com/google/gopacket/layers"
@@ -33,6 +34,7 @@
 	"github.com/opencord/voltha-protos/v2/go/tech_profile"
 	log "github.com/sirupsen/logrus"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/reflection"
 )
 
 var oltLogger = log.WithFields(log.Fields{
@@ -48,9 +50,7 @@
 	NumOnuPerPon    int
 	InternalState   *fsm.FSM
 	channel         chan Message
-	oltDoneChannel  *chan bool
-	apiDoneChannel  *chan bool
-	nniPktInChannel chan *bbsim.PacketMsg
+	nniPktInChannel chan *bbsim.PacketMsg // packets coming in from the NNI and going to VOLTHA; recreated by InitOlt, closed by disableOlt
 
 	Delay int
 
@@ -62,12 +62,13 @@
 }
 
 var olt OltDevice
+var oltServer *grpc.Server // OpenOLT gRPC server; set by InitOlt and reset to nil by StopOltServer
 
 func GetOLT() *OltDevice {
 	return &olt
 }
 
-func CreateOLT(oltId int, nni int, pon int, onuPerPon int, sTag int, cTagInit int, oltDoneChannel *chan bool, apiDoneChannel *chan bool, auth bool, dhcp bool, delay int, isMock bool) *OltDevice {
+func CreateOLT(oltId int, nni int, pon int, onuPerPon int, sTag int, cTagInit int, auth bool, dhcp bool, delay int, isMock bool) *OltDevice {
 	oltLogger.WithFields(log.Fields{
 		"ID":           oltId,
 		"NumNni":       nni,
@@ -81,16 +82,12 @@
 		OperState: getOperStateFSM(func(e *fsm.Event) {
 			oltLogger.Debugf("Changing OLT OperState from %s to %s", e.Src, e.Dst)
 		}),
-		NumNni:          nni,
-		NumPon:          pon,
-		NumOnuPerPon:    onuPerPon,
-		Pons:            []*PonPort{},
-		Nnis:            []*NniPort{},
-		channel:         make(chan Message),
-		oltDoneChannel:  oltDoneChannel,
-		apiDoneChannel:  apiDoneChannel,
-		nniPktInChannel: make(chan *bbsim.PacketMsg, 1024), // packets coming in from the NNI and going to VOLTHA
-		Delay:           delay,
+		NumNni:       nni,
+		NumPon:       pon,
+		NumOnuPerPon: onuPerPon,
+		Pons:         []*PonPort{},
+		Nnis:         []*NniPort{},
+		Delay:        delay,
 	}
 
 	// OLT State machine
@@ -98,13 +95,28 @@
 	olt.InternalState = fsm.NewFSM(
 		"created",
 		fsm.Events{
-			{Name: "enable", Src: []string{"created"}, Dst: "enabled"},
+			{Name: "initialize", Src: []string{"disabled", "created"}, Dst: "initialized"},
+			// NOTE enabling straight from "disabled" never enters "initialized", so
+			// InitOlt does not recreate the channels that disableOlt closed
+			{Name: "enable", Src: []string{"initialized", "disabled"}, Dst: "enabled"},
 			{Name: "disable", Src: []string{"enabled"}, Dst: "disabled"},
 		},
 		fsm.Callbacks{
 			"enter_state": func(e *fsm.Event) {
 				oltLogger.Debugf("Changing OLT InternalState from %s to %s", e.Src, e.Dst)
 			},
+			// fsm callbacks cannot return an error, so failures of the state
+			// entry handlers are logged here instead of being silently dropped
+			"enter_disabled": func(e *fsm.Event) {
+				if err := olt.disableOlt(); err != nil {
+					oltLogger.Errorf("Error disabling OLT: %v", err)
+				}
+			},
+			"enter_initialized": func(e *fsm.Event) {
+				if err := olt.InitOlt(); err != nil {
+					oltLogger.Errorf("Error initializing OLT: %v", err)
+				}
+			},
 		},
 	)
 
@@ -143,16 +143,122 @@
 
 		olt.Pons = append(olt.Pons, &p)
 	}
+
+	if err := olt.InternalState.Event("initialize"); err != nil {
+		oltLogger.Errorf("Error initializing OLT: %v", err)
+		return nil
+	}
+
 	return &olt
 }
 
-// this function start the OLT gRPC server and blocks until it's done
-func StartOlt(olt *OltDevice, group *sync.WaitGroup) {
-	newOltServer(*olt)
-	group.Done()
+// InitOlt starts the OpenOLT gRPC server (unless it is already running),
+// recreates the channels consumed by the processOltMessages and
+// processNniPacketIns goroutines and fires the "initialize" event on every
+// ONU. It is called when the OLT InternalState FSM enters "initialized".
+func (o *OltDevice) InitOlt() error {
+
+	if oltServer == nil {
+		srv, err := newOltServer()
+		if err != nil {
+			oltLogger.Errorf("Error starting OLT gRPC server: %v", err)
+			return err
+		}
+		oltServer = srv
+	} else {
+		oltLogger.Warn("OLT server already running.")
+	}
+
+	// create new channel for processOltMessages Go routine
+	o.channel = make(chan Message)
+
+	// default packet-in channel, replaced by the NNI veth channel when available
+	o.nniPktInChannel = make(chan *bbsim.PacketMsg, 1024)
+	// FIXME we are assuming we have only one NNI; the length check avoids an
+	// index-out-of-range panic when the OLT has no NNI
+	if len(o.Nnis) > 0 && o.Nnis[0] != nil {
+		ch, err := o.Nnis[0].NewVethChan()
+		if err == nil {
+			o.nniPktInChannel = ch
+		} else {
+			oltLogger.Errorf("Error getting NNI channel: %v", err)
+		}
+	}
+
+	// use the receiver's PONs rather than the package-level olt variable
+	for _, pon := range o.Pons {
+		for _, onu := range pon.Onus {
+			if err := onu.InternalState.Event("initialize"); err != nil {
+				oltLogger.Errorf("Error initializing ONU: %v", err)
+				return err
+			}
+		}
+	}
+
+	return nil
 }
 
-func newOltServer(o OltDevice) error {
+// disableOlt is the callback run when the OLT InternalState FSM enters
+// "disabled". It fires "disable" on every ONU, stops the OpenOLT gRPC server
+// and closes the channels read by the processOltMessages and
+// processNniPacketIns goroutines so that both of them return.
+func (o *OltDevice) disableOlt() error {
+
+	// disable all onus; failures are logged and do not abort the shutdown
+	for _, pon := range o.Pons {
+		for _, onu := range pon.Onus {
+			// NOTE order of these is important.
+			if err := onu.OperState.Event("disable"); err != nil {
+				oltLogger.WithFields(log.Fields{"OnuSn": onu.Sn()}).Debugf("Cannot disable ONU OperState: %v", err)
+			}
+			if err := onu.InternalState.Event("disable"); err != nil {
+				oltLogger.WithFields(log.Fields{"OnuSn": onu.Sn()}).Debugf("Cannot disable ONU InternalState: %v", err)
+			}
+		}
+	}
+
+	// TODO handle hard poweroff (i.e. no indications sent to Voltha) vs soft poweroff
+	if err := StopOltServer(); err != nil {
+		return err
+	}
+
+	// NOTE(review): a goroutine that is still sending on one of these channels
+	// when it is closed will panic; confirm all senders have stopped.
+	// terminate the OLT's processOltMessages go routine
+	close(o.channel)
+	// terminate the OLT's processNniPacketIns go routine
+	close(o.nniPktInChannel)
+	return nil
+}
+
+// RestartOLT simulates an OLT reboot: it moves the InternalState FSM to
+// "disabled" (unless it is already there), waits OltRebootDelay seconds and
+// then fires "initialize" so that the OLT is set up again.
+func (o *OltDevice) RestartOLT() error {
+	oltLogger.Infof("Simulating OLT restart... (%ds)", OltRebootDelay)
+
+	// transition internal state to disable
+	// NOTE "disable" is only defined from "enabled", so this fails while the
+	// OLT is still "created" or "initialized"
+	if !o.InternalState.Is("disabled") {
+		if err := o.InternalState.Event("disable"); err != nil {
+			oltLogger.Errorf("Error disabling OLT: %v", err)
+			return err
+		}
+	}
+
+	time.Sleep(OltRebootDelay * time.Second)
+
+	if err := o.InternalState.Event("initialize"); err != nil {
+		oltLogger.Errorf("Error initializing OLT: %v", err)
+		return err
+	}
+	oltLogger.Info("OLT restart completed")
+	return nil
+}
+
+// newOltServer launches a new grpc server for OpenOLT
+func newOltServer() (*grpc.Server, error) {
 	// TODO make configurable
 	address := "0.0.0.0:50060"
 	lis, err := net.Listen("tcp", address)
@@ -160,51 +240,50 @@
 		oltLogger.Fatalf("OLT failed to listen: %v", err)
 	}
 	grpcServer := grpc.NewServer()
+
+	o := GetOLT() // register the shared package-level OLT through a pointer
 	openolt.RegisterOpenoltServer(grpcServer, o)
 
-	wg := sync.WaitGroup{}
-	wg.Add(1)
+	reflection.Register(grpcServer) // expose the gRPC server reflection service
 
 	go grpcServer.Serve(lis)
 	oltLogger.Debugf("OLT Listening on: %v", address)
 
-	for {
-		_, ok := <-*o.oltDoneChannel
-		if !ok {
-			// if the olt Channel is closed, stop the gRPC server
-			log.Warnf("Stopping OLT gRPC server")
-			grpcServer.Stop()
-			wg.Done()
-			break
-		}
+	return grpcServer, nil // a listen failure is fatal above, so no error is ever returned
+}
+
+// StopOltServer stops the OpenOLT gRPC server, if one is running, and resets oltServer to nil so a new one can be started
+func StopOltServer() error {
+	// TODO handle poweroff vs graceful shutdown
+	if oltServer != nil {
+		log.Warnf("Stopping OLT gRPC server")
+		oltServer.Stop()
+		oltServer = nil
 	}
-
-	wg.Wait()
-
 	return nil
 }
 
 // Device Methods
 
-func (o OltDevice) Enable(stream openolt.Openolt_EnableIndicationServer) error {
-
+// Enable implements the OpenOLT EnableIndication stream: it queues the OLT, NNI and PON indications, triggers ONU discovery and waits for the OLT and NNI goroutines to stop
+func (o *OltDevice) Enable(stream openolt.Openolt_EnableIndicationServer) error {
 	oltLogger.Debug("Enable OLT called")
 
 	wg := sync.WaitGroup{}
 	wg.Add(2)
 
-	// create a Channel for all the OLT events
-	go o.processOltMessages(stream)
-	go o.processNniPacketIns(stream)
+	// create Go routine to process all OLT events
+	go o.processOltMessages(stream, &wg)
+	go o.processNniPacketIns(stream, &wg)
 
 	// enable the OLT
-	olt_msg := Message{
+	oltMsg := Message{
 		Type: OltIndication,
 		Data: OltIndicationMessage{
 			OperState: UP,
 		},
 	}
-	o.channel <- olt_msg
+	o.channel <- oltMsg
 
 	// send NNI Port Indications
 	for _, nni := range o.Nnis {
@@ -217,9 +296,11 @@
 		}
 		o.channel <- msg
 	}
+
 	go o.processOmciMessages()
+
 	// send PON Port indications
-	for _, pon := range o.Pons {
+	for i, pon := range o.Pons {
 		msg := Message{
 			Type: PonIndication,
 			Data: PonIndicationMessage{
@@ -229,29 +310,24 @@
 		}
 		o.channel <- msg
 
-		for _, onu := range pon.Onus {
+		for _, onu := range o.Pons[i].Onus {
 			go onu.ProcessOnuMessages(stream, nil)
-			// FIXME move the message generation in the state transition
-			// from here only invoke the state transition
-			msg := Message{
-				Type: OnuDiscIndication,
-				Data: OnuDiscIndicationMessage{
-					Onu:       onu,
-					OperState: UP,
-				},
+			if err := onu.InternalState.Event("discover"); err != nil {
+				log.Errorf("Error discover ONU: %v", err)
+				return err // NOTE(review): returns without wg.Wait() while the goroutines started above keep running
 			}
-			onu.Channel <- msg
 		}
 	}
 
+	oltLogger.Warn("Enable OLT Done")
 	wg.Wait()
 	return nil
 }
 
-func (o OltDevice) processOmciMessages() {
+func (o *OltDevice) processOmciMessages() {
 	ch := omcisim.GetChannel()
 
-	oltLogger.Debug("Started OMCI Indication Channel")
+	oltLogger.Debug("Starting OMCI Indication Channel")
 
 	for message := range ch {
 		onuId := message.Data.OnuId
@@ -284,7 +360,7 @@
 	return nil, errors.New(fmt.Sprintf("Cannot find NniPort with id %d in OLT %d", id, o.ID))
 }
 
-func (o OltDevice) sendOltIndication(msg OltIndicationMessage, stream openolt.Openolt_EnableIndicationServer) {
+func (o *OltDevice) sendOltIndication(msg OltIndicationMessage, stream openolt.Openolt_EnableIndicationServer) {
 	data := &openolt.Indication_OltInd{OltInd: &openolt.OltIndication{OperState: msg.OperState.String()}}
 	if err := stream.Send(&openolt.Indication{Data: data}); err != nil {
 		oltLogger.Errorf("Failed to send Indication_OltInd: %v", err)
@@ -295,7 +371,7 @@
 	}).Debug("Sent Indication_OltInd")
 }
 
-func (o OltDevice) sendNniIndication(msg NniIndicationMessage, stream openolt.Openolt_EnableIndicationServer) {
+func (o *OltDevice) sendNniIndication(msg NniIndicationMessage, stream openolt.Openolt_EnableIndicationServer) {
 	nni, _ := o.getNniById(msg.NniPortID)
 	nni.OperState.Event("enable")
 	// NOTE Operstate may need to be an integer
@@ -316,7 +392,7 @@
 	}).Debug("Sent Indication_IntfOperInd for NNI")
 }
 
-func (o OltDevice) sendPonIndication(msg PonIndicationMessage, stream openolt.Openolt_EnableIndicationServer) {
+func (o *OltDevice) sendPonIndication(msg PonIndicationMessage, stream openolt.Openolt_EnableIndicationServer) {
 	pon, _ := o.GetPonById(msg.PonPortID)
 	pon.OperState.Event("enable")
 	discoverData := &openolt.Indication_IntfInd{IntfInd: &openolt.IntfIndication{
@@ -350,8 +426,9 @@
 	}).Debug("Sent Indication_IntfOperInd for PON")
 }
 
-func (o OltDevice) processOltMessages(stream openolt.Openolt_EnableIndicationServer) {
-	oltLogger.Debug("Started OLT Indication Channel")
+// processOltMessages handles the Messages received on o.channel and calls wg.Done() once that loop ends
+func (o *OltDevice) processOltMessages(stream openolt.Openolt_EnableIndicationServer, wg *sync.WaitGroup) {
+	oltLogger.Debug("Starting OLT Indication Channel")
 	for message := range o.channel {
 
 		oltLogger.WithFields(log.Fields{
@@ -381,9 +458,12 @@
 		}
 
 	}
+	wg.Done() // signal the WaitGroup that Enable is waiting on
+	oltLogger.Warn("Stopped handling OLT Indication Channel")
 }
 
-func (o OltDevice) processNniPacketIns(stream openolt.Openolt_EnableIndicationServer) {
+// processNniPacketIns handles the packets coming in from the NNI (see o.nniPktInChannel) and calls wg.Done() before returning
+func (o *OltDevice) processNniPacketIns(stream openolt.Openolt_EnableIndicationServer, wg *sync.WaitGroup) {
 	oltLogger.WithFields(log.Fields{
 		"nniChannel": o.nniPktInChannel,
 	}).Debug("Started NNI Channel")
@@ -436,6 +516,10 @@
 			"OnuSn":    onu.Sn(),
 		}).Tracef("Sent PktInd indication")
 	}
+	wg.Done() // signal the WaitGroup that Enable is waiting on
+	oltLogger.WithFields(log.Fields{
+		"nniChannel": o.nniPktInChannel,
+	}).Warn("Stopped handling NNI Channel")
 }
 
 // returns an ONU with a given Serial Number
@@ -527,13 +611,13 @@
 
 func (o OltDevice) DisableOlt(context.Context, *openolt.Empty) (*openolt.Empty, error) {
 	// NOTE when we disable the OLT should we disable NNI, PONs and ONUs altogether?
-	olt_msg := Message{
+	oltMsg := Message{
 		Type: OltIndication,
 		Data: OltIndicationMessage{
 			OperState: DOWN,
 		},
 	}
-	o.channel <- olt_msg
+	o.channel <- oltMsg // NOTE(review): unbuffered send; blocks with no reader and panics if disableOlt already closed o.channel
 	return new(openolt.Empty), nil
 }
 
@@ -542,7 +626,7 @@
 	return new(openolt.Empty), nil
 }
 
-func (o OltDevice) EnableIndication(_ *openolt.Empty, stream openolt.Openolt_EnableIndicationServer) error {
+func (o *OltDevice) EnableIndication(_ *openolt.Empty, stream openolt.Openolt_EnableIndicationServer) error {
 	oltLogger.WithField("oltId", o.ID).Info("OLT receives EnableIndication call from VOLTHA")
 	o.Enable(stream)
 	return nil
@@ -703,11 +787,13 @@
 }
 
-func (o OltDevice) Reboot(context.Context, *openolt.Empty) (*openolt.Empty, error) {
-	defer func() {
-		oltLogger.Info("Shutting Down")
-		close(*o.oltDoneChannel)
-		close(*o.apiDoneChannel)
-	}()
+// Reboot implements the OpenOLT Reboot RPC. RestartOLT stops the gRPC server
+// that serves this call and then sleeps for OltRebootDelay, so it runs in its
+// own goroutine: this lets the reply be sent before the server goes down and
+// does not block the caller for the whole reboot delay.
+func (o *OltDevice) Reboot(context.Context, *openolt.Empty) (*openolt.Empty, error) {
+	oltLogger.Info("Shutting down")
+	// RestartOLT logs its own errors, so its return value is not needed here
+	go o.RestartOLT()
 	return new(openolt.Empty), nil
 }
 
@@ -719,7 +800,12 @@
 func (o OltDevice) UplinkPacketOut(context context.Context, packet *openolt.UplinkPacket) (*openolt.Empty, error) {
 	pkt := gopacket.NewPacket(packet.Pkt, layers.LayerTypeEthernet, gopacket.Default)
 
-	sendNniPacket(pkt)
+	// FIXME we are assuming we have only one NNI; refuse the packet instead of
+	// panicking with an index out of range when the OLT has no NNI
+	if len(o.Nnis) == 0 || o.Nnis[0] == nil {
+		return nil, errors.New("cannot send uplink packet: the OLT has no NNI port")
+	}
+	o.Nnis[0].sendNniPacket(pkt)
 	// NOTE should we return an error if sendNniPakcet fails?
 	return new(openolt.Empty), nil
 }