Merge "[VOL-3942] Correctly validate flows that have ReplicateFlow set to true"
diff --git a/internal/bbsim/devices/olt.go b/internal/bbsim/devices/olt.go
index b2480e7..ad0a8d0 100644
--- a/internal/bbsim/devices/olt.go
+++ b/internal/bbsim/devices/olt.go
@@ -1091,7 +1091,7 @@
 			return nil, err
 		}
 
-		o.storeGemPortId(flow)
+		o.storeGemPortIdByFlow(flow)
 		o.storeAllocId(flow)
 
 		msg := types.Message{
@@ -1659,7 +1659,17 @@
 	}
 }
 
-func (o *OltDevice) storeGemPortId(flow *openolt.Flow) {
+// storeGemPortId marks gemId on the given PON/ONU/UNI as referenced by flowId. It takes no lock: callers (storeGemPortIdByFlow) hold GemPortIDsLock.
+func (o *OltDevice) storeGemPortId(ponId uint32, onuId uint32, portNo uint32, gemId int32, flowId uint64) {
+	if _, ok := o.GemPortIDs[ponId][onuId][portNo]; !ok {
+		o.GemPortIDs[ponId][onuId][portNo] = map[int32]map[uint64]bool{gemId: {}}
+	} else if _, ok := o.GemPortIDs[ponId][onuId][portNo][gemId]; !ok {
+		o.GemPortIDs[ponId][onuId][portNo][gemId] = map[uint64]bool{}
+	}
+	o.GemPortIDs[ponId][onuId][portNo][gemId][flowId] = true
+}
+
+func (o *OltDevice) storeGemPortIdByFlow(flow *openolt.Flow) {
 	o.GemPortIDsLock.Lock()
 	defer o.GemPortIDsLock.Unlock()
 
@@ -1670,13 +1680,14 @@
 		"GemportId": flow.GemportId,
 	}).Trace("storing-gem-port-id-via-flow")
 
-	if _, ok := o.GemPortIDs[uint32(flow.AccessIntfId)][uint32(flow.OnuId)][flow.PortNo]; !ok {
-		o.GemPortIDs[uint32(flow.AccessIntfId)][uint32(flow.OnuId)][flow.PortNo] = make(map[int32]map[uint64]bool)
+	if flow.ReplicateFlow {
+		for _, gem := range flow.PbitToGemport {
+			o.storeGemPortId(uint32(flow.AccessIntfId), uint32(flow.OnuId), flow.PortNo, int32(gem), flow.FlowId)
+		}
+	} else {
+		o.storeGemPortId(uint32(flow.AccessIntfId), uint32(flow.OnuId), flow.PortNo, flow.GemportId, flow.FlowId)
 	}
-	if _, ok := o.GemPortIDs[uint32(flow.AccessIntfId)][uint32(flow.OnuId)][flow.PortNo][flow.GemportId]; !ok {
-		o.GemPortIDs[uint32(flow.AccessIntfId)][uint32(flow.OnuId)][flow.PortNo][flow.GemportId] = make(map[uint64]bool)
-	}
-	o.GemPortIDs[uint32(flow.AccessIntfId)][uint32(flow.OnuId)][flow.PortNo][flow.GemportId][flow.FlowId] = true
+
 }
 
 func (o *OltDevice) freeGemPortId(flow *openolt.Flow) {
@@ -1736,8 +1747,16 @@
 		}
 		for uniId, uni := range onu {
 			for gem := range uni {
-				if gem == flow.GemportId {
-					return fmt.Errorf("gem-%d-already-in-use-on-uni-%d-onu-%d", gem, uniId, onuId)
+				if flow.ReplicateFlow {
+					for _, flowGem := range flow.PbitToGemport {
+						if gem == int32(flowGem) {
+							return fmt.Errorf("gem-%d-already-in-use-on-uni-%d-onu-%d-replicated-flow-%d", gem, uniId, onuId, flow.FlowId)
+						}
+					}
+				} else {
+					if gem == flow.GemportId {
+						return fmt.Errorf("gem-%d-already-in-use-on-uni-%d-onu-%d-flow-%d", gem, uniId, onuId, flow.FlowId)
+					}
 				}
 			}
 		}
@@ -1753,7 +1772,7 @@
 		for uniId, uni := range onu {
 			for allocId := range uni {
 				if allocId == flow.AllocId {
-					return fmt.Errorf("allocId-%d-already-in-use-on-uni-%d-onu-%d", allocId, uniId, onuId)
+					return fmt.Errorf("allocId-%d-already-in-use-on-uni-%d-onu-%d-flow-%d", allocId, uniId, onuId, flow.FlowId)
 				}
 			}
 		}
diff --git a/internal/bbsim/devices/olt_test.go b/internal/bbsim/devices/olt_test.go
index 994902e..4de6609 100644
--- a/internal/bbsim/devices/olt_test.go
+++ b/internal/bbsim/devices/olt_test.go
@@ -18,15 +18,16 @@
 
 import (
 	"context"
+	"fmt"
 	"github.com/looplab/fsm"
 	"github.com/opencord/bbsim/internal/bbsim/types"
 	bbsim "github.com/opencord/bbsim/internal/bbsim/types"
 	"github.com/opencord/bbsim/internal/common"
+	"github.com/stretchr/testify/assert"
 	"net"
 	"testing"
 
 	"github.com/opencord/voltha-protos/v4/go/openolt"
-	"gotest.tools/assert"
 )
 
 func createMockOlt(numPon int, numOnu int, services []ServiceIf) *OltDevice {
@@ -191,7 +192,7 @@
 	mac := net.HardwareAddr{0x2e, 0x60, byte(olt.ID), byte(3), byte(6), byte(1)}
 	s, err := olt.FindServiceByMacAddress(mac)
 
-	assert.NilError(t, err)
+	assert.NoError(t, err)
 
 	service := s.(*Service)
 
@@ -278,7 +279,7 @@
 		GemportId:    gem1,
 	}
 
-	olt.storeGemPortId(flow1)
+	olt.storeGemPortIdByFlow(flow1)
 	assert.Equal(t, len(olt.GemPortIDs[pon][onu][uni]), 1)       // we have 1 gem port
 	assert.Equal(t, len(olt.GemPortIDs[pon][onu][uni][gem1]), 1) // and one flow referencing it
 
@@ -291,7 +292,7 @@
 		GemportId:    gem1,
 	}
 
-	olt.storeGemPortId(flow2)
+	olt.storeGemPortIdByFlow(flow2)
 	assert.Equal(t, len(olt.GemPortIDs[pon][onu][uni]), 1)       // we have 1 gem port
 	assert.Equal(t, len(olt.GemPortIDs[pon][onu][uni][gem1]), 2) // and two flows referencing it
 
@@ -304,12 +305,46 @@
 		GemportId:    1025,
 	}
 
-	olt.storeGemPortId(flow3)
+	olt.storeGemPortIdByFlow(flow3)
 	assert.Equal(t, len(olt.GemPortIDs[pon][onu][uni]), 2)       // we have 2 gem ports
 	assert.Equal(t, len(olt.GemPortIDs[pon][onu][uni][gem1]), 2) // two flows referencing the first one
 	assert.Equal(t, len(olt.GemPortIDs[pon][onu][uni][gem2]), 1) // and one flow referencing the second one
 }
 
+func Test_Olt_storeGemPortIdReplicatedFlow(t *testing.T) {
+	const (
+		pon    = 1
+		onu    = 1
+		uni    = 16
+		gem1   = 1024
+		gem2   = 1025
+		flowId = 1
+	)
+
+	olt := createMockOlt(2, 2, []ServiceIf{}) // numPon, numOnu
+
+	// a replicated flow leaves GemportId unset: each gem in PbitToGemport has to be stored instead
+	flow1 := &openolt.Flow{
+		AccessIntfId:  pon,
+		OnuId:         onu,
+		PortNo:        uni,
+		FlowId:        flowId,
+		GemportId:     0,
+		ReplicateFlow: true,
+		PbitToGemport: map[uint32]uint32{0: gem1, 1: gem2},
+	}
+
+	olt.storeGemPortIdByFlow(flow1)
+	gems := olt.GemPortIDs[pon][onu][uni]
+	assert.Len(t, gems, 2)                // both gem ports of the flow are stored
+	assert.NotContains(t, gems, int32(0)) // and the unused GemportId is not
+	// each gem port is referenced by exactly one flow, and that flow is flow1
+	for _, gem := range []int32{gem1, gem2} {
+		assert.Len(t, gems[gem], 1)
+		assert.True(t, gems[gem][flowId])
+	}
+}
+
 func Test_Olt_freeGemPortId(t *testing.T) {
 	const (
 		pon   = 1
@@ -369,6 +404,38 @@
 	assert.Equal(t, gem2exists, true)
 }
 
+func Test_Olt_freeGemPortIdReplicatedflow(t *testing.T) {
+	const (
+		pon   = 1
+		onu   = 1
+		uni   = 16
+		gem1  = 1024
+		gem2  = 1025
+		flow1 = 1
+	)
+
+	numPon := 2
+	numOnu := 2
+
+	olt := createMockOlt(numPon, numOnu, []ServiceIf{})
+
+	// flow1 was a replicated flow: it is the only flow referencing both gem ports on the UNI
+	olt.GemPortIDs[pon][onu][uni] = map[int32]map[uint64]bool{
+		gem1: {flow1: true},
+		gem2: {flow1: true},
+	}
+
+	// only FlowId is set: freeGemPortId has to find every gem referenced by the flow from its id alone
+	flowMultiGem := &openolt.Flow{
+		FlowId: flow1,
+	}
+
+	olt.freeGemPortId(flowMultiGem)
+
+	// flow1 was the last user of both gem ports, so no gem port may be left on the UNI
+	assert.Empty(t, olt.GemPortIDs[pon][onu][uni])
+}
+
 func Test_Olt_validateFlow(t *testing.T) {
 
 	const (
@@ -412,7 +479,7 @@
 	}
 
 	err := olt.validateFlow(validGemFlow)
-	assert.NilError(t, err)
+	assert.NoError(t, err)
 
 	// a GemPortID can NOT be referenced across different ONUs on the same PON
 	invalidGemFlow := &openolt.Flow{
@@ -430,7 +497,7 @@
 		GemportId:    usedGemIdPon0,
 	}
 	err = olt.validateFlow(invalidGemDifferentPonFlow)
-	assert.NilError(t, err)
+	assert.NoError(t, err)
 
 	// an allocId can be referenced across multiple flows on the same ONU
 	validAllocFlow := &openolt.Flow{
@@ -439,7 +506,7 @@
 		AllocId:      usedAllocIdPon0,
 	}
 	err = olt.validateFlow(validAllocFlow)
-	assert.NilError(t, err)
+	assert.NoError(t, err)
 
 	// an allocId can NOT be referenced across different ONUs on the same PON
 	invalidAllocFlow := &openolt.Flow{
@@ -457,7 +524,65 @@
 		AllocId:      usedAllocIdPon0,
 	}
 	err = olt.validateFlow(invalidAllocDifferentPonFlow)
-	assert.NilError(t, err)
+	assert.NoError(t, err)
+}
+
+func Test_Olt_validateReplicatedFlow(t *testing.T) {
+
+	const (
+		pon0            = 0
+		onu0            = 0
+		onu1            = 1
+		uniPort         = 0
+		usedGemId1      = 1024
+		usedGemId2      = 1025
+		usedAllocIdPon0 = 1
+		flowId          = 1
+	)
+
+	numPon := 1
+	numOnu := 1
+
+	olt := createMockOlt(numPon, numOnu, []ServiceIf{})
+
+	// both the gemports referenced in this flow are already allocated
+	olt.GemPortIDs[pon0][onu0][uniPort] = make(map[int32]map[uint64]bool)
+	olt.GemPortIDs[pon0][onu0][uniPort][usedGemId1] = make(map[uint64]bool)
+	olt.GemPortIDs[pon0][onu0][uniPort][usedGemId1][flowId] = true
+	olt.GemPortIDs[pon0][onu0][uniPort][usedGemId2] = make(map[uint64]bool)
+	olt.GemPortIDs[pon0][onu0][uniPort][usedGemId2][flowId] = true
+
+	olt.AllocIDs[pon0][onu0][uniPort] = make(map[int32]map[uint64]bool)
+	olt.AllocIDs[pon0][onu0][uniPort][usedAllocIdPon0] = make(map[uint64]bool)
+	olt.AllocIDs[pon0][onu0][uniPort][usedAllocIdPon0][flowId] = true
+
+	pbitToGemPortMap := make(map[uint32]uint32)
+	pbitToGemPortMap[0] = usedGemId1
+	pbitToGemPortMap[1] = usedGemId2
+
+	// this flow should fail validation as the gems are already allocated to onu0
+	invalidGemFlow := &openolt.Flow{
+		AccessIntfId:  pon0,
+		OnuId:         onu1,
+		PortNo:        uniPort,
+		GemportId:     0,
+		ReplicateFlow: true,
+		PbitToGemport: pbitToGemPortMap,
+	}
+
+	err := olt.validateFlow(invalidGemFlow)
+	if !assert.Error(t, err) {
+		// err.Error() below would panic on a nil error
+		return
+	}
+
+	// PbitToGemport is a map, so either of the two gem ports can be found first and determine the error message
+	const errorTemplate = "gem-%d-already-in-use-on-uni-%d-onu-%d-replicated-flow-%d"
+	expectedErrors := []string{
+		fmt.Sprintf(errorTemplate, usedGemId1, uniPort, onu0, invalidGemFlow.FlowId),
+		fmt.Sprintf(errorTemplate, usedGemId2, uniPort, onu0, invalidGemFlow.FlowId),
+	}
+	assert.Contains(t, expectedErrors, err.Error())
 }
 
 func Test_Olt_OmciMsgOut(t *testing.T) {
@@ -500,13 +625,13 @@
 		Pkt:    makeOmciSetRequest(t),
 	}
 	_, err = olt.OmciMsgOut(ctx, msg)
-	assert.NilError(t, err)
+	assert.NoError(t, err)
 	assert.Equal(t, len(onu.Channel), 0) // check that no messages have been sent
 
 	// test that the ONU receives a valid packet
 	onu.InternalState.SetState(OnuStateEnabled)
 	_, err = olt.OmciMsgOut(ctx, msg)
-	assert.NilError(t, err)
+	assert.NoError(t, err)
 	assert.Equal(t, len(onu.Channel), 1) // check that one message have been sent
 
 }