Merge "[VOL-2778] Restart Auth and DHCP"
diff --git a/internal/bbsim/api/onus_handler.go b/internal/bbsim/api/onus_handler.go
index fe42b35..332f237 100644
--- a/internal/bbsim/api/onus_handler.go
+++ b/internal/bbsim/api/onus_handler.go
@@ -282,19 +282,32 @@
 		return res, err
 	}
 
-	if err := onu.InternalState.Event("start_auth"); err != nil {
-		logger.WithFields(log.Fields{
-			"OnuId":  onu.ID,
-			"IntfId": onu.PonPortID,
-			"OnuSn":  onu.Sn(),
-		}).Errorf("Cannot restart authenticaton for ONU: %s", err.Error())
-		res.StatusCode = int32(codes.FailedPrecondition)
-		res.Message = err.Error()
-		return res, err
+	// restart EAPOL on every Service that needs it, collecting per-Service failures
+	var errs []string
+
+	for _, s := range onu.Services {
+		service := s.(*devices.Service)
+		if service.NeedsEapol {
+			if err := service.EapolState.Event("start_auth"); err != nil {
+				logger.WithFields(log.Fields{
+					"OnuId":   onu.ID,
+					"IntfId":  onu.PonPortID,
+					"OnuSn":   onu.Sn(),
+					"Service": service.Name,
+				}).Errorf("Cannot restart authentication for Service: %s", err.Error())
+				errs = append(errs, fmt.Sprintf("%s: %s", service.Name, err.Error()))
+			}
+		}
 	}
 
-	res.StatusCode = int32(codes.OK)
-	res.Message = fmt.Sprintf("Authentication restarted for ONU %s.", onu.Sn())
+	if len(errs) == 0 {
+		res.StatusCode = int32(codes.OK)
+		res.Message = fmt.Sprintf("Authentication restarted for ONU %s.", onu.Sn())
+	} else {
+		// report every failed Service; the RPC itself still succeeds (nil error)
+		res.StatusCode = int32(codes.FailedPrecondition)
+		res.Message = fmt.Sprintf("%v", errs)
+	}
 
 	return res, nil
 }
@@ -316,19 +329,32 @@
 		return res, err
 	}
 
-	if err := onu.InternalState.Event("start_dhcp"); err != nil {
-		logger.WithFields(log.Fields{
-			"OnuId":  onu.ID,
-			"IntfId": onu.PonPortID,
-			"OnuSn":  onu.Sn(),
-		}).Errorf("Cannot restart DHCP for ONU: %s", err.Error())
-		res.StatusCode = int32(codes.FailedPrecondition)
-		res.Message = err.Error()
-		return res, err
+	// restart DHCP on every Service that needs it, collecting per-Service failures
+	var errs []string
+
+	for _, s := range onu.Services {
+		service := s.(*devices.Service)
+		if service.NeedsDhcp {
+			if err := service.DHCPState.Event("start_dhcp"); err != nil {
+				logger.WithFields(log.Fields{
+					"OnuId":   onu.ID,
+					"IntfId":  onu.PonPortID,
+					"OnuSn":   onu.Sn(),
+					"Service": service.Name,
+				}).Errorf("Cannot restart DHCP for Service: %s", err.Error())
+				errs = append(errs, fmt.Sprintf("%s: %s", service.Name, err.Error()))
+			}
+		}
 	}
 
-	res.StatusCode = int32(codes.OK)
-	res.Message = fmt.Sprintf("DHCP restarted for ONU %s.", onu.Sn())
+	if len(errs) == 0 {
+		res.StatusCode = int32(codes.OK)
+		res.Message = fmt.Sprintf("DHCP restarted for ONU %s.", onu.Sn())
+	} else {
+		// report every failed Service; the RPC itself still succeeds (nil error)
+		res.StatusCode = int32(codes.FailedPrecondition)
+		res.Message = fmt.Sprintf("%v", errs)
+	}
 
 	return res, nil
 }
diff --git a/internal/bbsim/devices/messageTypes.go b/internal/bbsim/devices/messageTypes.go
index d9ed620..cfe7200 100644
--- a/internal/bbsim/devices/messageTypes.go
+++ b/internal/bbsim/devices/messageTypes.go
@@ -34,6 +34,8 @@
 	OMCI              MessageType = 5
 	FlowAdd           MessageType = 6
 	FlowRemoved       MessageType = 18
+	StartEAPOL        MessageType = 7 // queued on Service.Channel to send the EapolStart packet
+	StartDHCP         MessageType = 8 // queued on Service.Channel to send the DHCPDiscover packet
 	OnuPacketOut      MessageType = 9
 
 	// BBR messages
diff --git a/internal/bbsim/devices/olt.go b/internal/bbsim/devices/olt.go
index c0e7d3b..a9420de 100644
--- a/internal/bbsim/devices/olt.go
+++ b/internal/bbsim/devices/olt.go
@@ -114,6 +114,7 @@
 	if val, ok := ControlledActivationModes[options.BBSim.ControlledActivation]; ok {
 		olt.ControlledActivation = val
 	} else {
+		// FIXME throw an error if the ControlledActivation is not valid
 		oltLogger.Warn("Unknown ControlledActivation Mode given, running in Default mode")
 		olt.ControlledActivation = Default
 	}
@@ -1296,15 +1297,25 @@
 		"IntfId": onu.PonPortID,
 		"OnuId":  onu.ID,
 		"OnuSn":  onu.Sn(),
+		"Packet": hex.EncodeToString(onuPkt.Pkt),
 	}).Trace("Received OnuPacketOut")
 
 	rawpkt := gopacket.NewPacket(onuPkt.Pkt, layers.LayerTypeEthernet, gopacket.Default)
-	pktType, _ := packetHandlers.IsEapolOrDhcp(rawpkt)
+	pktType, err := packetHandlers.IsEapolOrDhcp(rawpkt)
+	if err != nil {
+		onuLogger.WithFields(log.Fields{
+			"IntfId": onu.PonPortID,
+			"OnuId":  onu.ID,
+			"OnuSn":  onu.Sn(),
+			"Packet": hex.EncodeToString(onuPkt.Pkt),
+			"err":    err,
+		}).Error("Can't find pktType in packet, dropping it")
+		return new(openolt.Empty), nil
+	}
 
 	pktMac, err := packetHandlers.GetDstMacAddressFromPacket(rawpkt)
-
 	if err != nil {
-		log.WithFields(log.Fields{
+		onuLogger.WithFields(log.Fields{
 			"IntfId": onu.PonPortID,
 			"OnuId":  onu.ID,
 			"OnuSn":  onu.Sn(),
diff --git a/internal/bbsim/devices/onu.go b/internal/bbsim/devices/onu.go
index 458678a..80308da 100644
--- a/internal/bbsim/devices/onu.go
+++ b/internal/bbsim/devices/onu.go
@@ -182,8 +182,7 @@
 
 				// Once the ONU is enabled start listening for packets
 				for _, s := range o.Services {
-					s.Initialize()
-					go s.HandlePackets(o.PonPort.Olt.OpenoltStream)
+					s.Initialize(o.PonPort.Olt.OpenoltStream)
 				}
 			},
 			"enter_disabled": func(event *fsm.Event) {
@@ -320,7 +319,7 @@
 
 				msg, _ := message.Data.(OnuPacketMessage)
 
-				log.WithFields(log.Fields{
+				onuLogger.WithFields(log.Fields{
 					"IntfId":  msg.IntfId,
 					"OnuId":   msg.OnuId,
 					"pktType": msg.Type,
@@ -340,6 +339,13 @@
 
 				service.PacketCh <- msg
 
+				onuLogger.WithFields(log.Fields{
+					"IntfId":      msg.IntfId,
+					"OnuId":       msg.OnuId,
+					"pktType":     msg.Type,
+					"ServiceName": service.Name,
+				}).Trace("OnuPacketOut Sent on Service Packet channel") // emitted for every packet, so keep it out of Info
+
 			case OnuPacketIn:
 				// NOTE we only receive BBR packets here.
 				// Eapol.HandleNextPacket can handle both BBSim and BBr cases so the call is the same
@@ -647,7 +653,7 @@
 		o.storePortNumber(uint32(msg.Flow.PortNo))
 
 		for _, s := range o.Services {
-			s.HandleAuth(o.PonPort.Olt.OpenoltStream)
+			s.HandleAuth()
 		}
 	} else if msg.Flow.Classifier.EthType == uint32(layers.EthernetTypeIPv4) &&
 		msg.Flow.Classifier.SrcPort == uint32(68) &&
@@ -655,7 +661,7 @@
 		(msg.Flow.Classifier.OPbits == 0 || msg.Flow.Classifier.OPbits == 255) {
 
 		for _, s := range o.Services {
-			s.HandleDhcp(o.PonPort.Olt.OpenoltStream, int(msg.Flow.Classifier.OVid))
+			s.HandleDhcp(int(msg.Flow.Classifier.OVid))
 		}
 	}
 }
diff --git a/internal/bbsim/devices/service_test.go b/internal/bbsim/devices/service_test.go
index 93ea368..2191178 100644
--- a/internal/bbsim/devices/service_test.go
+++ b/internal/bbsim/devices/service_test.go
@@ -22,6 +22,7 @@
 	"github.com/stretchr/testify/assert"
 	"net"
 	"testing"
+	"time"
 )
 
 type mockService struct {
@@ -31,20 +32,20 @@
 	HandlePacketsCallCount int
 }
 
-func (s *mockService) HandleAuth(stream types.Stream) {
+func (s *mockService) HandleAuth() {
 	s.HandleAuthCallCount = s.HandleAuthCallCount + 1
 }
 
-func (s *mockService) HandleDhcp(stream types.Stream, cTag int) {
+func (s *mockService) HandleDhcp(cTag int) {
 	s.HandleDhcpCallCount = s.HandleDhcpCallCount + 1
 }
 
-func (s *mockService) HandlePackets(stream types.Stream) {
+func (s *mockService) HandlePackets() {
 	s.HandlePacketsCallCount = s.HandlePacketsCallCount + 1
 }
 
-func (s *mockService) Initialize() {}
-func (s *mockService) Disable()    {}
+func (s *mockService) Initialize(stream types.Stream) {}
+func (s *mockService) Disable()                       {}
 
 // test the internalState transitions
 func TestService_InternalState(t *testing.T) {
@@ -57,13 +58,24 @@
 	assert.Nil(t, err)
 
 	assert.Empty(t, s.PacketCh)
-	s.Initialize()
+	s.Initialize(&mockStream{})
 
+	// check that channels have been created
 	assert.NotNil(t, s.PacketCh)
+	assert.NotNil(t, s.Channel)
+
+	// set EAPOL and DHCP states to something else
+	s.EapolState.SetState("eap_response_success_received")
+	s.DHCPState.SetState("dhcp_ack_received")
 
 	s.Disable()
+	// make sure the EAPOL and DHCP states have been reset after disable
 	assert.Equal(t, "created", s.EapolState.Current())
 	assert.Equal(t, "created", s.DHCPState.Current())
+
+	// make sure the channel have been closed
+	assert.Nil(t, s.Channel)
+	assert.Nil(t, s.PacketCh)
 }
 
 // make sure that if the service does not need EAPOL we're not sending any packet
@@ -80,8 +92,10 @@
 		Calls:   make(map[int]*openolt.Indication),
 		channel: make(chan int, 10),
 	}
+	s.Initialize(stream)
 
-	s.HandleAuth(stream)
+	s.HandleAuth()
+	time.Sleep(1 * time.Second)
 
 	// if the service does not need EAPOL we don't expect any packet to be generated
 	assert.Equal(t, stream.CallCount, 0)
@@ -103,8 +117,10 @@
 	stream := &mockStream{
 		Calls: make(map[int]*openolt.Indication),
 	}
+	s.Initialize(stream)
 
-	s.HandleAuth(stream)
+	s.HandleAuth()
+	time.Sleep(1 * time.Second)
 
 	// if the service does not need EAPOL we don't expect any packet to be generated
 	assert.Equal(t, stream.CallCount, 1)
@@ -126,8 +142,10 @@
 	stream := &mockStream{
 		Calls: make(map[int]*openolt.Indication),
 	}
+	s.Initialize(stream)
 
-	s.HandleDhcp(stream, 900)
+	s.HandleDhcp(900)
+	time.Sleep(1 * time.Second)
 
 	assert.Equal(t, stream.CallCount, 0)
 
@@ -149,9 +167,11 @@
 	stream := &mockStream{
 		Calls: make(map[int]*openolt.Indication),
 	}
+	s.Initialize(stream)
 
 	// NOTE that the c_tag is different from the one configured in the service
-	s.HandleDhcp(stream, 800)
+	s.HandleDhcp(800)
+	time.Sleep(1 * time.Second)
 
 	assert.Equal(t, stream.CallCount, 0)
 
@@ -172,11 +192,11 @@
 	stream := &mockStream{
 		Calls: make(map[int]*openolt.Indication),
 	}
+	s.Initialize(stream)
 
-	s.HandleDhcp(stream, 900)
+	s.HandleDhcp(900)
+	time.Sleep(1 * time.Second)
 
-	assert.Equal(t, stream.CallCount, 1)
-
-	// state should not change
-	assert.Equal(t, s.DHCPState.Current(), "dhcp_discovery_sent")
+	assert.Equal(t, 1, stream.CallCount)
+	assert.Equal(t, "dhcp_discovery_sent", s.DHCPState.Current())
 }
diff --git a/internal/bbsim/devices/services.go b/internal/bbsim/devices/services.go
index 360bd4a..c8e9461 100644
--- a/internal/bbsim/devices/services.go
+++ b/internal/bbsim/devices/services.go
@@ -31,11 +31,11 @@
 })
 
 type ServiceIf interface {
-	HandlePackets(stream bbsimTypes.Stream)        // start listening on the PacketCh
-	HandleAuth(stream bbsimTypes.Stream)           // Sends the EapoStart packet
-	HandleDhcp(stream bbsimTypes.Stream, cTag int) // Sends the DHCPDiscover packet
+	HandlePackets()      // start listening on the PacketCh
+	HandleAuth()         // Sends the EapolStart packet (asynchronously, through the Service Channel)
+	HandleDhcp(cTag int) // Sends the DHCPDiscover packet (asynchronously, through the Service Channel)
 
-	Initialize()
+	Initialize(stream bbsimTypes.Stream) // stores the stream, creates the channels and starts HandlePackets/HandleChannel
 	Disable()
 }
 
@@ -61,7 +61,9 @@
 	InternalState *fsm.FSM
 	EapolState    *fsm.FSM
 	DHCPState     *fsm.FSM
-	PacketCh      chan OnuPacketMessage
+	Channel       chan Message          // StartEAPOL/StartDHCP requests, consumed by HandleChannel
+	PacketCh      chan OnuPacketMessage // EAPOL/DHCP packets, consumed by HandlePackets
+	Stream        bbsimTypes.Stream     // the gRPC stream to communicate with the adapter, set in the enter_initialized callback
 }
 
 func NewService(name string, hwAddress net.HardwareAddr, onu *Onu, cTag int, sTag int,
@@ -97,7 +99,19 @@
 				service.logStateChange("InternalState", e.Src, e.Dst)
 			},
 			"enter_initialized": func(e *fsm.Event) {
+				var stream bbsimTypes.Stream
+				if len(e.Args) > 0 { // Initialize passes the stream as the only event argument
+					stream, _ = e.Args[0].(bbsimTypes.Stream)
+				}
+				if stream == nil { // no argument, or an argument that is not a Stream
+					serviceLogger.Fatal("initialize invoked with wrong arguments")
+				}
+				service.Stream = stream
 				service.PacketCh = make(chan OnuPacketMessage)
+				service.Channel = make(chan Message)
+
+				go service.HandlePackets()
+				go service.HandleChannel()
 			},
 			"enter_disabled": func(e *fsm.Event) {
 				// reset the state machines
@@ -106,6 +120,10 @@
 
 				// stop listening for packets
 				close(service.PacketCh)
+				close(service.Channel)
+
+				service.PacketCh = nil
+				service.Channel = nil
 			},
 		},
 	)
@@ -124,6 +142,13 @@
 			"enter_state": func(e *fsm.Event) {
 				service.logStateChange("EapolState", e.Src, e.Dst)
 			},
+			"before_start_auth": func(e *fsm.Event) {
+				if service.Channel == nil {
+					e.Cancel() // not initialized or disabled: a send on the nil Channel would block forever
+					return
+				}
+				service.Channel <- Message{Type: StartEAPOL}
+			},
 		},
 	)

@@ -142,13 +167,21 @@
 			"enter_state": func(e *fsm.Event) {
 				service.logStateChange("DHCPState", e.Src, e.Dst)
 			},
+			"before_start_dhcp": func(e *fsm.Event) {
+				if service.Channel == nil {
+					e.Cancel() // not initialized or disabled: a send on the nil Channel would block forever
+					return
+				}
+				service.Channel <- Message{Type: StartDHCP}
+			},
 		},
 	)

 	return &service, nil
 }

-func (s *Service) HandleAuth(stream bbsimTypes.Stream) {
+// HandleAuth is used to start EAPOL for a particular Service when the corresponding flow is received
+func (s *Service) HandleAuth() {

 	if !s.NeedsEapol {
 		serviceLogger.WithFields(log.Fields{
@@ -161,8 +194,6 @@
 		return
 	}

-	// TODO check if the EAPOL flow was received before starting auth
-
 	if err := s.EapolState.Event("start_auth"); err != nil {
 		serviceLogger.WithFields(log.Fields{
 			"OnuId":  s.Onu.ID,
@@ -171,21 +202,11 @@
 			"Name":   s.Name,
 			"err":    err.Error(),
 		}).Error("Can't start auth for this Service")
-	} else {
-		if err := s.handleEapolStart(stream); err != nil {
-			serviceLogger.WithFields(log.Fields{
-				"OnuId":  s.Onu.ID,
-				"IntfId": s.Onu.PonPortID,
-				"OnuSn":  s.Onu.Sn(),
-				"Name":   s.Name,
-				"err":    err,
-			}).Error("Error while sending EapolStart packet")
-			_ = s.EapolState.Event("auth_failed")
-		}
 	}
 }

-func (s *Service) HandleDhcp(stream bbsimTypes.Stream, cTag int) {
+// HandleDhcp is used to start DHCP for a particular Service when the corresponding flow is received
+func (s *Service) HandleDhcp(cTag int) {

 	if s.CTag != cTag {
 		serviceLogger.WithFields(log.Fields{
@@ -218,21 +239,11 @@
 			"Name":   s.Name,
 			"err":    err.Error(),
 		}).Error("Can't start DHCP for this Service")
-	} else {
-		if err := s.handleDHCPStart(stream); err != nil {
-			serviceLogger.WithFields(log.Fields{
-				"OnuId":  s.Onu.ID,
-				"IntfId": s.Onu.PonPortID,
-				"OnuSn":  s.Onu.Sn(),
-				"Name":   s.Name,
-				"err":    err,
-			}).Error("Error while sending DHCPDiscovery packet")
-			_ = s.DHCPState.Event("dhcp_failed")
-		}
 	}
 }

-func (s *Service) HandlePackets(stream bbsimTypes.Stream) {
+// HandlePackets listens on PacketCh and passes EAPOL and DHCP packets, together with s.Stream, to the eapol and dhcp responders.
+func (s *Service) HandlePackets() {
 	serviceLogger.WithFields(log.Fields{
 		"OnuId":     s.Onu.ID,
 		"IntfId":    s.Onu.PonPortID,
@@ -261,15 +272,65 @@
 		}).Trace("Received message on Service Packet Channel")

 		if msg.Type == packetHandlers.EAPOL {
-			eapol.HandleNextPacket(msg.OnuId, msg.IntfId, s.GemPort, s.Onu.Sn(), s.Onu.PortNo, s.EapolState, msg.Packet, stream, nil)
+			eapol.HandleNextPacket(msg.OnuId, msg.IntfId, s.GemPort, s.Onu.Sn(), s.Onu.PortNo, s.EapolState, msg.Packet, s.Stream, nil)
 		} else if msg.Type == packetHandlers.DHCP {
-			_ = dhcp.HandleNextPacket(s.Onu.PonPort.Olt.ID, s.Onu.ID, s.Onu.PonPortID, s.Name, s.Onu.Sn(), s.Onu.PortNo, s.CTag, s.GemPort, s.HwAddress, s.DHCPState, msg.Packet, s.UsPonCTagPriority, stream)
+			_ = dhcp.HandleNextPacket(s.Onu.PonPort.Olt.ID, s.Onu.ID, s.Onu.PonPortID, s.Name, s.Onu.Sn(), s.Onu.PortNo, s.CTag, s.GemPort, s.HwAddress, s.DHCPState, msg.Packet, s.UsPonCTagPriority, s.Stream)
 		}
 	}
 }

-func (s *Service) Initialize() {
-	if err := s.InternalState.Event("initialized"); err != nil {
+// HandleChannel consumes the requests queued on Channel by the before_start_auth and
+// before_start_dhcp callbacks and sends the corresponding EapolStart or DHCPDiscover packet.
+// It returns once Channel is closed, which happens when the Service is disabled.
+func (s *Service) HandleChannel() {
+	serviceLogger.WithFields(log.Fields{
+		"OnuId":     s.Onu.ID,
+		"IntfId":    s.Onu.PonPortID,
+		"OnuSn":     s.Onu.Sn(),
+		"GemPortId": s.GemPort,
+		"Name":      s.Name,
+	}).Debug("Listening on Service Channel")
+
+	defer func() {
+		serviceLogger.WithFields(log.Fields{
+			"OnuId":     s.Onu.ID,
+			"IntfId":    s.Onu.PonPortID,
+			"OnuSn":     s.Onu.Sn(),
+			"GemPortId": s.GemPort,
+			"Name":      s.Name,
+		}).Debug("Done Listening on Service Channel")
+	}()
+	for msg := range s.Channel {
+		if msg.Type == StartEAPOL {
+			if err := s.handleEapolStart(s.Stream); err != nil {
+				serviceLogger.WithFields(log.Fields{
+					"OnuId":  s.Onu.ID,
+					"IntfId": s.Onu.PonPortID,
+					"OnuSn":  s.Onu.Sn(),
+					"Name":   s.Name,
+					"err":    err,
+				}).Error("Error while sending EapolStart packet")
+				_ = s.EapolState.Event("auth_failed")
+			}
+		} else if msg.Type == StartDHCP {
+			if err := s.handleDHCPStart(s.Stream); err != nil {
+				serviceLogger.WithFields(log.Fields{
+					"OnuId":  s.Onu.ID,
+					"IntfId": s.Onu.PonPortID,
+					"OnuSn":  s.Onu.Sn(),
+					"Name":   s.Name,
+					"err":    err,
+				}).Error("Error while sending DHCPDiscovery packet")
+				_ = s.DHCPState.Event("dhcp_failed")
+			}
+		}
+	}
+}
+
+// Initialize fires the "initialized" event with the stream used to reach the adapter;
+// enter_initialized stores it, creates the channels and starts HandlePackets and HandleChannel.
+func (s *Service) Initialize(stream bbsimTypes.Stream) {
+	if err := s.InternalState.Event("initialized", stream); err != nil {
 		serviceLogger.WithFields(log.Fields{
 			"OnuId":  s.Onu.ID,
 			"IntfId": s.Onu.PonPortID,
@@ -293,14 +354,14 @@
 }

 func (s *Service) handleEapolStart(stream bbsimTypes.Stream) error {
-
+	// TODO fail Auth if it does not succeed in 30 seconds
 	serviceLogger.WithFields(log.Fields{
 		"OnuId":   s.Onu.ID,
 		"IntfId":  s.Onu.PonPortID,
 		"OnuSn":   s.Onu.Sn(),
 		"GemPort": s.GemPort,
 		"Name":    s.Name,
-	}).Debugf("handleEapolStart")
+	}).Trace("handleEapolStart")

 	if err := eapol.SendEapStart(s.Onu.ID, s.Onu.PonPortID, s.Onu.Sn(), s.Onu.PortNo,
 		s.HwAddress, s.GemPort, s.EapolState, stream); err != nil {
@@ -317,7 +378,7 @@
 }

 func (s *Service) handleDHCPStart(stream bbsimTypes.Stream) error {
-
+	// TODO fail DHCP if it does not succeed in 30 seconds
 	serviceLogger.WithFields(log.Fields{
 		"OnuId":     s.Onu.ID,
 		"IntfId":    s.Onu.PonPortID,
diff --git a/internal/bbsim/responders/dhcp/dhcp.go b/internal/bbsim/responders/dhcp/dhcp.go
index 9f8b035..5fc8a64 100644
--- a/internal/bbsim/responders/dhcp/dhcp.go
+++ b/internal/bbsim/responders/dhcp/dhcp.go
@@ -380,6 +380,7 @@
 			"OnuSn":       serialNumber,
 			"ServiceName": serviceName,
 		}).Errorf("Error while transitioning ONU State %v", err)
+		return err
 	}
 	return nil
 }