[VOL-3583] Reject unexpected EAPOL packets

Change-Id: I7c164fb2815e2ab573ab3f466a373cb75c61ac86
diff --git a/internal/bbsim/responders/eapol/eapol.go b/internal/bbsim/responders/eapol/eapol.go
index e341cef..afaee68 100644
--- a/internal/bbsim/responders/eapol/eapol.go
+++ b/internal/bbsim/responders/eapol/eapol.go
@@ -36,6 +36,24 @@
 
 var eapolVersion uint8 = 1
 
+// Names of the EAPOL state machine states (State*) and of the events (Event*) that drive its transitions.
+const (
+	StateCreated                 = "created"
+	StateAuthStarted             = "auth_started"
+	StateStartSent               = "eap_start_sent"
+	StateResponseIdentitySent    = "eap_response_identity_sent"
+	StateResponseChallengeSent   = "eap_response_challenge_sent"
+	StateResponseSuccessReceived = "eap_response_success_received"
+	StateAuthFailed              = "auth_failed"
+
+	EventStartAuth               = "start_auth"
+	EventStartSent               = "eap_start_sent"
+	EventResponseIdentitySent    = "eap_response_identity_sent"
+	EventResponseChallengeSent   = "eap_response_challenge_sent"
+	EventResponseSuccessReceived = "eap_response_success_received"
+	EventAuthFailed              = "auth_failed"
+)
+
 func sendEapolPktIn(msg bbsim.ByteMsg, portNo uint32, gemid uint32, stream bbsim.Stream) error {
 	// FIXME unify sendDHCPPktIn and sendEapolPktIn methods
 
@@ -181,7 +199,7 @@
 }
 
 func updateAuthFailed(onuId uint32, ponPortId uint32, serialNumber string, onuStateMachine *fsm.FSM) error {
-	if err := onuStateMachine.Event("auth_failed"); err != nil {
+	if err := onuStateMachine.Event(EventAuthFailed); err != nil {
 		eapolLogger.WithFields(log.Fields{
 			"OnuId":  onuId,
 			"IntfId": ponPortId,
@@ -253,7 +271,7 @@
 		"GemPortId": gemPort,
 	}).Debug("Sent EapStart packet")
 
-	if err := stateMachine.Event("eap_start_sent"); err != nil {
+	if err := stateMachine.Event(EventStartSent); err != nil {
 		eapolLogger.WithFields(log.Fields{
 			"OnuId":     onuId,
 			"IntfId":    ponPortId,
@@ -327,6 +345,21 @@
 		}).Infof("Sent EAPIdentityRequest packet")
 		return
 	} else if eap.Code == layers.EAPCodeRequest && eap.Type == layers.EAPTypeIdentity {
+		if state := stateMachine.Current(); state != StateStartSent {
+			eapolLogger.WithFields(log.Fields{
+				"OnuId":  onuId,
+				"IntfId": ponPortId,
+				"OnuSn":  serialNumber,
+				"PortNo": portNo,
+				"UniId":  uniId,
+			}).Errorf("Received EAPIdentityRequest packet while in state %q, dropped", state)
+			// Abort the authentication unless it has already completed (succeeded or failed).
+			if state != StateAuthFailed && state != StateResponseSuccessReceived {
+				_ = updateAuthFailed(onuId, ponPortId, serialNumber, stateMachine)
+			}
+			return
+		}
+
 		reseap := createEAPIdentityResponse(eap.Id)
 		pkt := createEAPOLPkt(reseap, serviceId, uniId, onuId, ponPortId, oltId)
 
@@ -337,7 +370,7 @@
 		}
 
 		if err := sendEapolPktIn(msg, portNo, gemPortId, stream); err != nil {
-			_ = stateMachine.Event("auth_failed")
+			_ = updateAuthFailed(onuId, ponPortId, serialNumber, stateMachine)
 			return
 		}
 		eapolLogger.WithFields(log.Fields{
@@ -347,7 +380,7 @@
 			"PortNo": portNo,
 			"UniId":  uniId,
 		}).Debugf("Sent EAPIdentityResponse packet")
-		if err := stateMachine.Event("eap_response_identity_sent"); err != nil {
+		if err := stateMachine.Event(EventResponseIdentitySent); err != nil {
 			eapolLogger.WithFields(log.Fields{
 				"OnuId":  onuId,
 				"IntfId": ponPortId,
@@ -381,6 +414,21 @@
 		}).Infof("Sent EAPChallengeRequest packet")
 		return
 	} else if eap.Code == layers.EAPCodeRequest && eap.Type == layers.EAPTypeOTP {
+		if state := stateMachine.Current(); state != StateResponseIdentitySent {
+			eapolLogger.WithFields(log.Fields{
+				"OnuId":  onuId,
+				"IntfId": ponPortId,
+				"OnuSn":  serialNumber,
+				"PortNo": portNo,
+				"UniId":  uniId,
+			}).Errorf("Received EAPChallengeRequest packet while in state %q, dropped", state)
+			// Abort the authentication unless it has already completed (succeeded or failed).
+			if state != StateAuthFailed && state != StateResponseSuccessReceived {
+				_ = updateAuthFailed(onuId, ponPortId, serialNumber, stateMachine)
+			}
+			return
+		}
+
 		senddata := getMD5Data(eap)
 		senddata = append([]byte{0x10}, senddata...)
 		sendeap := createEAPChallengeResponse(eap.Id, senddata)
@@ -393,7 +441,7 @@
 		}
 
 		if err := sendEapolPktIn(msg, portNo, gemPortId, stream); err != nil {
-			_ = stateMachine.Event("auth_failed")
+			_ = updateAuthFailed(onuId, ponPortId, serialNumber, stateMachine)
 			return
 		}
 		eapolLogger.WithFields(log.Fields{
@@ -403,7 +451,7 @@
 			"PortNo": portNo,
 			"UniId":  uniId,
 		}).Debugf("Sent EAPChallengeResponse packet")
-		if err := stateMachine.Event("eap_response_challenge_sent"); err != nil {
+		if err := stateMachine.Event(EventResponseChallengeSent); err != nil {
 			eapolLogger.WithFields(log.Fields{
 				"OnuId":  onuId,
 				"IntfId": ponPortId,
@@ -442,6 +490,21 @@
 			}).Errorf("Error while transitioning ONU State %v", err)
 		}
 	} else if eap.Code == layers.EAPCodeSuccess && eap.Type == layers.EAPTypeNone {
+		if state := stateMachine.Current(); state != StateResponseChallengeSent {
+			eapolLogger.WithFields(log.Fields{
+				"OnuId":  onuId,
+				"IntfId": ponPortId,
+				"OnuSn":  serialNumber,
+				"PortNo": portNo,
+				"UniId":  uniId,
+			}).Errorf("Received EAP Success packet while in state %q, dropped", state)
+			// Abort the authentication unless it has already completed (succeeded or failed).
+			if state != StateAuthFailed && state != StateResponseSuccessReceived {
+				_ = updateAuthFailed(onuId, ponPortId, serialNumber, stateMachine)
+			}
+			return
+		}
+
 		eapolLogger.WithFields(log.Fields{
 			"OnuId":  onuId,
 			"IntfId": ponPortId,
@@ -449,7 +512,7 @@
 			"PortNo": portNo,
 			"UniId":  uniId,
 		}).Debugf("Received EAPSuccess packet")
-		if err := stateMachine.Event("eap_response_success_received"); err != nil {
+		if err := stateMachine.Event(EventResponseSuccessReceived); err != nil {
 			eapolLogger.WithFields(log.Fields{
 				"OnuId":  onuId,
 				"IntfId": ponPortId,
diff --git a/internal/bbsim/responders/eapol/eapol_test.go b/internal/bbsim/responders/eapol/eapol_test.go
index 21ca86b..d2f861d 100644
--- a/internal/bbsim/responders/eapol/eapol_test.go
+++ b/internal/bbsim/responders/eapol/eapol_test.go
@@ -18,10 +18,14 @@
 
 import (
 	"errors"
+	"fmt"
 	"net"
 	"testing"
 
+	"github.com/google/gopacket"
+	"github.com/google/gopacket/layers"
 	"github.com/looplab/fsm"
+	"github.com/opencord/bbsim/internal/bbsim/types"
 	"github.com/opencord/voltha-protos/v5/go/openolt"
 	"google.golang.org/grpc"
 	"gotest.tools/assert"
@@ -30,13 +34,13 @@
 // MOCKS
 
 var eapolStateMachine = fsm.NewFSM(
-	"auth_started",
+	StateAuthStarted,
 	fsm.Events{
-		{Name: "eap_start_sent", Src: []string{"auth_started"}, Dst: "eap_start_sent"},
-		{Name: "eap_response_identity_sent", Src: []string{"eap_start_sent"}, Dst: "eap_response_identity_sent"},
-		{Name: "eap_response_challenge_sent", Src: []string{"eap_response_identity_sent"}, Dst: "eap_response_challenge_sent"},
-		{Name: "eap_response_success_received", Src: []string{"eap_response_challenge_sent"}, Dst: "eap_response_success_received"},
-		{Name: "auth_failed", Src: []string{"auth_started", "eap_start_sent", "eap_response_identity_sent", "eap_response_challenge_sent"}, Dst: "auth_failed"},
+		{Name: EventStartSent, Src: []string{StateAuthStarted}, Dst: StateStartSent},
+		{Name: EventResponseIdentitySent, Src: []string{StateStartSent}, Dst: StateResponseIdentitySent},
+		{Name: EventResponseChallengeSent, Src: []string{StateResponseIdentitySent}, Dst: StateResponseChallengeSent},
+		{Name: EventResponseSuccessReceived, Src: []string{StateResponseChallengeSent}, Dst: StateResponseSuccessReceived},
+		{Name: EventAuthFailed, Src: []string{StateAuthStarted, StateStartSent, StateResponseIdentitySent, StateResponseChallengeSent}, Dst: StateAuthFailed},
 	},
 	fsm.Callbacks{},
 )
@@ -49,6 +53,8 @@
 var serialNumber string = "BBSM00000001"
 var macAddress = net.HardwareAddr{0x01, 0x80, 0xC2, 0x00, 0x00, 0x03}
 var portNo uint32 = 16
+var serviceId uint32 = 0
+var oltId int = 0
 
 type mockStream struct {
 	grpc.ServerStream
@@ -69,7 +75,7 @@
 // TESTS
 
 func TestSendEapStartSuccess(t *testing.T) {
-	eapolStateMachine.SetState("auth_started")
+	eapolStateMachine.SetState(StateAuthStarted)
 
 	stream := &mockStream{
 		Calls: make(map[int]*openolt.PacketIndication),
@@ -87,13 +93,13 @@
 	assert.Equal(t, stream.Calls[1].IntfType, "pon")
 	assert.Equal(t, stream.Calls[1].GemportId, uint32(gemPortId))
 
-	assert.Equal(t, eapolStateMachine.Current(), "eap_start_sent")
+	assert.Equal(t, eapolStateMachine.Current(), StateStartSent)
 
 }
 
 func TestSendEapStartFailStreamError(t *testing.T) {
 
-	eapolStateMachine.SetState("auth_started")
+	eapolStateMachine.SetState(StateAuthStarted)
 
 	stream := &mockStream{
 		Calls: make(map[int]*openolt.PacketIndication),
@@ -108,7 +114,7 @@
 
 	assert.Equal(t, err.Error(), "fake-error")
 
-	assert.Equal(t, eapolStateMachine.Current(), "auth_failed")
+	assert.Equal(t, eapolStateMachine.Current(), StateAuthFailed)
 }
 
 // TODO test eapol.HandleNextPacket
@@ -119,28 +125,84 @@
 	var ponPortId uint32 = 0
 	var serialNumber string = "BBSM00000001"
 
-	eapolStateMachine.SetState("auth_started")
+	eapolStateMachine.SetState(StateAuthStarted)
 	_ = updateAuthFailed(onuId, ponPortId, serialNumber, eapolStateMachine)
-	assert.Equal(t, eapolStateMachine.Current(), "auth_failed")
+	assert.Equal(t, eapolStateMachine.Current(), StateAuthFailed)
 
-	eapolStateMachine.SetState("eap_start_sent")
+	eapolStateMachine.SetState(StateStartSent)
 	_ = updateAuthFailed(onuId, ponPortId, serialNumber, eapolStateMachine)
-	assert.Equal(t, eapolStateMachine.Current(), "auth_failed")
+	assert.Equal(t, eapolStateMachine.Current(), StateAuthFailed)
 
-	eapolStateMachine.SetState("eap_response_identity_sent")
+	eapolStateMachine.SetState(StateResponseIdentitySent)
 	_ = updateAuthFailed(onuId, ponPortId, serialNumber, eapolStateMachine)
-	assert.Equal(t, eapolStateMachine.Current(), "auth_failed")
+	assert.Equal(t, eapolStateMachine.Current(), StateAuthFailed)
 
-	eapolStateMachine.SetState("eap_response_challenge_sent")
+	eapolStateMachine.SetState(StateResponseChallengeSent)
 	_ = updateAuthFailed(onuId, ponPortId, serialNumber, eapolStateMachine)
-	assert.Equal(t, eapolStateMachine.Current(), "auth_failed")
+	assert.Equal(t, eapolStateMachine.Current(), StateAuthFailed)
 
-	eapolStateMachine.SetState("eap_response_success_received")
+	eapolStateMachine.SetState(StateResponseSuccessReceived)
 	err := updateAuthFailed(onuId, ponPortId, serialNumber, eapolStateMachine)
 	if err == nil {
 		t.Errorf("updateAuthFailed did not return an error")
-		t.Fail()
+		t.FailNow()
 	}
-	assert.Equal(t, err.Error(), "event auth_failed inappropriate in current state eap_response_success_received")
+	assert.Equal(t, err.Error(), fmt.Sprintf("event %s inappropriate in current state %s", EventAuthFailed, StateResponseSuccessReceived))
 
 }
+
+func createTestEAPOLPkt(eap *layers.EAP) gopacket.Packet {
+	bytes := createEAPOLPkt(eap, serviceId, uniId, onuId, ponPortId, oltId)
+	return gopacket.NewPacket(bytes, layers.LayerTypeEthernet, gopacket.Default)
+}
+
+func handleTestEAPOLPkt(pkt gopacket.Packet, stream types.Stream) {
+	HandleNextPacket(onuId, ponPortId, gemPortId, serialNumber, portNo, uniId, serviceId, oltId, eapolStateMachine, pkt, stream, nil)
+}
+
+func TestDropUnexpectedPackets(t *testing.T) {
+	stream := &mockStream{
+		Calls: make(map[int]*openolt.PacketIndication),
+	}
+
+	const eapId uint8 = 1
+
+	// Create one packet of each server-side EAP request/success type
+	identityRequest := createEAPIdentityRequest(eapId)
+	challengeRequest := createEAPChallengeRequest(eapId, []byte{0x10})
+	success := createEAPSuccess(eapId)
+
+	identityPkt := createTestEAPOLPkt(identityRequest)
+	challengePkt := createTestEAPOLPkt(challengeRequest)
+	successPkt := createTestEAPOLPkt(success)
+
+	testStates := map[string]struct {
+		packets          []gopacket.Packet // packets that must be dropped in this state
+		destinationState string            // state expected after the drop
+	}{
+		// All packets should be dropped in state auth_started
+		StateAuthStarted: {[]gopacket.Packet{identityPkt, challengePkt, successPkt}, StateAuthFailed},
+		// Everything but the identity request should be dropped in state eap_start_sent
+		StateStartSent: {[]gopacket.Packet{challengePkt, successPkt}, StateAuthFailed},
+		// Everything but the challenge request should be dropped in state eap_response_identity_sent
+		StateResponseIdentitySent: {[]gopacket.Packet{identityPkt, successPkt}, StateAuthFailed},
+		// Everything but the success packet should be dropped in state eap_response_challenge_sent
+		StateResponseChallengeSent: {[]gopacket.Packet{identityPkt, challengePkt}, StateAuthFailed},
+		// All packets should be dropped, without a transition, in state eap_response_success_received
+		StateResponseSuccessReceived: {[]gopacket.Packet{identityPkt, challengePkt, successPkt}, StateResponseSuccessReceived},
+		// All packets should be dropped, without a transition, in state auth_failed
+		StateAuthFailed: {[]gopacket.Packet{identityPkt, challengePkt, successPkt}, StateAuthFailed},
+	}
+
+	for s, info := range testStates {
+		for _, p := range info.packets {
+			eapolStateMachine.SetState(s)
+			handleTestEAPOLPkt(p, stream)
+
+			// No response should be sent
+			assert.Equal(t, stream.CallCount, 0, "a response was sent for a packet received in state %s", s)
+			// The state machine should end up in the expected destination state
+			assert.Equal(t, eapolStateMachine.Current(), info.destinationState, "unexpected state after a packet received in state %s", s)
+		}
+	}
+}