Merge "[VOL-4039] Removing allocId and GemPort from cache when MibReset is received"
diff --git a/internal/bbsim/devices/onu.go b/internal/bbsim/devices/onu.go
index e2ed180..982e9ef 100644
--- a/internal/bbsim/devices/onu.go
+++ b/internal/bbsim/devices/onu.go
@@ -763,9 +763,13 @@
 			"IntfId":       o.PonPortID,
 			"OnuId":        o.ID,
 			"SerialNumber": o.Sn(),
-		}).Debug("received-mib-reset-request-resetting-mds")
+		}).Debug("received-mib-reset-request")
 		if responsePkt, errResp = omcilib.CreateMibResetResponse(msg.OmciMsg.TransactionID); errResp == nil {
 			o.MibDataSync = 0
+
+			// if the MIB reset is successful then remove all the stored AllocIds and GemPorts
+			o.PonPort.removeAllocId(o.SerialNumber)
+			o.PonPort.removeGemPortBySn(o.SerialNumber)
 		}
 	case omci.MibUploadRequestType:
 		responsePkt, _ = omcilib.CreateMibUploadResponse(msg.OmciMsg.TransactionID)
diff --git a/internal/bbsim/devices/onu_omci_test.go b/internal/bbsim/devices/onu_omci_test.go
index 20f8a35..846b61c 100644
--- a/internal/bbsim/devices/onu_omci_test.go
+++ b/internal/bbsim/devices/onu_omci_test.go
@@ -109,6 +109,76 @@
 	return omciPkt
 }
 
+// makeOmciStartSoftwareDownloadRequest returns a hex-encoded StartSoftwareDownloadRequest for the Software Image ME (ImageSize/WindowSize 31, one circuit pack), failing t on any error.
+func makeOmciStartSoftwareDownloadRequest(t *testing.T) []byte {
+	t.Helper()
+	omciReq := &omci.StartSoftwareDownloadRequest{
+		MeBasePacket:         omci.MeBasePacket{EntityClass: me.SoftwareImageClassID},
+		ImageSize:            31,
+		NumberOfCircuitPacks: 1,
+		WindowSize:           31,
+		CircuitPacks:         []uint16{0},
+	}
+	omciPkt, err := omcilib.Serialize(omci.StartSoftwareDownloadRequestType, omciReq, 66)
+	if err == nil {
+		omciPkt, err = omcilib.HexEncode(omciPkt)
+	}
+	if err != nil {
+		t.Fatal(err.Error())
+	}
+	return omciPkt
+}
+
+// makeOmciEndSoftwareDownloadRequest returns a hex-encoded EndSoftwareDownloadRequest for a single Software Image instance (0), failing t on any error.
+func makeOmciEndSoftwareDownloadRequest(t *testing.T) []byte {
+	t.Helper()
+	omciReq := &omci.EndSoftwareDownloadRequest{
+		MeBasePacket:      omci.MeBasePacket{EntityClass: me.SoftwareImageClassID},
+		NumberOfInstances: 1,
+		ImageInstances:    []uint16{0},
+	}
+	omciPkt, err := omcilib.Serialize(omci.EndSoftwareDownloadRequestType, omciReq, 66)
+	if err == nil {
+		omciPkt, err = omcilib.HexEncode(omciPkt)
+	}
+	if err != nil {
+		t.Fatal(err.Error())
+	}
+	return omciPkt
+}
+
+// makeOmciActivateSoftwareRequest returns a hex-encoded ActivateSoftwareRequest for the Software Image ME, failing t on any error.
+func makeOmciActivateSoftwareRequest(t *testing.T) []byte {
+	t.Helper()
+	omciReq := &omci.ActivateSoftwareRequest{
+		MeBasePacket: omci.MeBasePacket{EntityClass: me.SoftwareImageClassID},
+	}
+	omciPkt, err := omcilib.Serialize(omci.ActivateSoftwareRequestType, omciReq, 66)
+	if err == nil {
+		omciPkt, err = omcilib.HexEncode(omciPkt)
+	}
+	if err != nil {
+		t.Fatal(err.Error())
+	}
+	return omciPkt
+}
+
+// makeOmciCommitSoftwareRequest returns a hex-encoded CommitSoftwareRequest for the Software Image ME, failing t on any error.
+func makeOmciCommitSoftwareRequest(t *testing.T) []byte {
+	t.Helper()
+	omciReq := &omci.CommitSoftwareRequest{
+		MeBasePacket: omci.MeBasePacket{EntityClass: me.SoftwareImageClassID},
+	}
+	omciPkt, err := omcilib.Serialize(omci.CommitSoftwareRequestType, omciReq, 66)
+	if err == nil {
+		omciPkt, err = omcilib.HexEncode(omciPkt)
+	}
+	if err != nil {
+		t.Fatal(err.Error())
+	}
+	return omciPkt
+}
+
 func makeOmciMessage(t *testing.T, onu *Onu, pkt []byte) bbsim.OmciMessage {
 	omciPkt, omciMsg, err := omcilib.ParseOpenOltOmciPacket(pkt)
 	if err != nil {
@@ -151,7 +221,7 @@
 }
 
 func Test_MibDataSyncIncrease(t *testing.T) {
-	onu := createMockOnu(1, 1)
+	onu := createTestOnu()
 
 	assert.Equal(t, onu.MibDataSync, uint8(0))
 
@@ -171,11 +241,24 @@
 	onu.handleOmciRequest(makeOmciMessage(t, onu, makeOmciDeleteRequest(t)), stream)
 	assert.Equal(t, onu.MibDataSync, uint8(3))
 
-	// TODO once supported MDS should increase for:
-	// - Start software download
-	// - End software download
-	// - Activate software
-	// - Commit software
+	// Start software download
+	onu.InternalState.SetState(OnuStateEnabled)
+	onu.handleOmciRequest(makeOmciMessage(t, onu, makeOmciStartSoftwareDownloadRequest(t)), stream)
+	assert.Equal(t, onu.MibDataSync, uint8(4))
+
+	// End software download
+	onu.ImageSoftwareReceivedSections = 1 // we fake that we have received the one download section we expect
+	onu.InternalState.SetState(OnuStateImageDownloadInProgress)
+	onu.handleOmciRequest(makeOmciMessage(t, onu, makeOmciEndSoftwareDownloadRequest(t)), stream)
+	assert.Equal(t, onu.MibDataSync, uint8(5))
+
+	// Activate software
+	onu.handleOmciRequest(makeOmciMessage(t, onu, makeOmciActivateSoftwareRequest(t)), stream)
+	assert.Equal(t, onu.MibDataSync, uint8(6))
+
+	// Commit software
+	onu.handleOmciRequest(makeOmciMessage(t, onu, makeOmciCommitSoftwareRequest(t)), stream)
+	assert.Equal(t, onu.MibDataSync, uint8(7))
 }
 
 func Test_MibDataSyncReset(t *testing.T) {
@@ -187,9 +270,19 @@
 		Calls: make(map[int]*openolt.Indication),
 	}
 
-	// send a MibReset and check that MDS has reset to 0
+	// create a GemPort and an AllocId for this ONU
+	onu.PonPort.storeGemPort(1024, onu.SerialNumber)
+	onu.PonPort.storeAllocId(1024, onu.SerialNumber)
+
+	// send a MibReset
 	onu.handleOmciRequest(makeOmciMessage(t, onu, makeOmciMibResetRequest(t)), stream)
+
+	// check that MDS has reset to 0
 	assert.Equal(t, onu.MibDataSync, uint8(0))
+
+	// check that GemPort and AllocId have been removed
+	assert.Equal(t, len(onu.PonPort.AllocatedGemPorts), 0)
+	assert.Equal(t, len(onu.PonPort.AllocatedAllocIds), 0)
 }
 
 func Test_MibDataSyncRotation(t *testing.T) {