Checking for EAPOL and DHCP if needed when removing flows

Change-Id: I3297eeffe1ca122ee0d4a16ac2eebffdb5b2b0d3
diff --git a/impl/src/main/java/org/opencord/olt/impl/OltFlowService.java b/impl/src/main/java/org/opencord/olt/impl/OltFlowService.java
index 6252144..4d452e6 100644
--- a/impl/src/main/java/org/opencord/olt/impl/OltFlowService.java
+++ b/impl/src/main/java/org/opencord/olt/impl/OltFlowService.java
@@ -616,7 +616,7 @@
             Iterator<UniTagInformation> iter = sub.subscriberAndDeviceInformation.uniTagList().iterator();
             while (iter.hasNext()) {
                 UniTagInformation entry = iter.next();
-                if (areSubscriberFlowsPendingRemoval(sub.port, entry)) {
+                if (areSubscriberFlowsPendingRemoval(sub.port, entry, enableEapol)) {
                     log.info("Subscriber {} still have flows on service {}, postpone default EAPOL installation.",
                             portWithName(sub.port), entry.getServiceName());
                     return false;
@@ -704,13 +704,15 @@
                 status.subscriberFlowsStatus == OltFlowsStatus.PENDING_ADD);
     }
 
-    public boolean areSubscriberFlowsPendingRemoval(Port port, UniTagInformation uti) {
+    public boolean areSubscriberFlowsPendingRemoval(Port port, UniTagInformation uti, boolean enableEapol) {
         OltPortStatus status = getOltPortStatus(port, uti);
         if (log.isTraceEnabled()) {
             log.trace("Status during pending_remove flow check {} for port {} and UniTagInformation {}",
                     status, portWithName(port), uti);
         }
-        return status != null && status.subscriberFlowsStatus == OltFlowsStatus.PENDING_REMOVE;
+        return status != null && (status.subscriberFlowsStatus == OltFlowsStatus.PENDING_REMOVE ||
+                (enableEapol && status.subscriberEapolStatus == OltFlowsStatus.PENDING_REMOVE) ||
+                (uti.getIsDhcpRequired() && status.dhcpStatus == OltFlowsStatus.PENDING_REMOVE));
     }
 
     @Override
@@ -801,14 +803,15 @@
 
         // create a subscriberKey for the EAPOL flow
         ServiceKey sk = new ServiceKey(new AccessDevicePort(sub.port), defaultEapolUniTag);
-
-        // NOTE we only need to keep track of the default EAPOL flow in the
-        // connectpoint status map
+        OltFlowsStatus status = action == FlowOperation.ADD ?
+                OltFlowsStatus.PENDING_ADD : OltFlowsStatus.PENDING_REMOVE;
         if (vlanId.id().equals(EAPOL_DEFAULT_VLAN)) {
-            OltFlowsStatus status = action == FlowOperation.ADD ?
-                    OltFlowsStatus.PENDING_ADD : OltFlowsStatus.PENDING_REMOVE;
-            updateConnectPointStatus(sk, status, OltFlowsStatus.NONE, OltFlowsStatus.NONE, OltFlowsStatus.NONE);
+            updateConnectPointStatus(sk, status, OltFlowsStatus.NONE, OltFlowsStatus.NONE,
+                                     OltFlowsStatus.NONE, OltFlowsStatus.NONE);
 
+        } else {
+            updateConnectPointStatus(sk, OltFlowsStatus.NONE, status, OltFlowsStatus.NONE,
+                                     OltFlowsStatus.NONE, OltFlowsStatus.NONE);
         }
 
         DefaultFilteringObjective.Builder filterBuilder = DefaultFilteringObjective.builder();
@@ -886,7 +889,7 @@
 
                         if (vlanId.id().equals(EAPOL_DEFAULT_VLAN)) {
                             updateConnectPointStatus(sk,
-                                    OltFlowsStatus.ERROR, null, null, null);
+                                                     OltFlowsStatus.ERROR, null, null, null, null);
                         }
                     }
                 });
@@ -923,7 +926,7 @@
                     u.getUpstreamOltBandwidthProfile(),
                     action, u.getPonCTag())) {
                 //
-                log.error("Failed to {} EAPOL with suscriber tags", action);
+                log.error("Failed to {} EAPOL with subscriber tags", action);
                 //TODO this sets it for all services, maybe some services succeeded.
                 success.set(false);
             }
@@ -1118,7 +1121,7 @@
             ServiceKey sk = new ServiceKey(new AccessDevicePort(port), uti);
             OltFlowsStatus status = action.equals(FlowOperation.ADD) ?
                     OltFlowsStatus.PENDING_ADD : OltFlowsStatus.PENDING_REMOVE;
-            updateConnectPointStatus(sk, null, status, null, null);
+            updateConnectPointStatus(sk, null, null, status, null, null);
 
             // upstream flows
             MeterId usMeterId = oltMeterService
@@ -1148,7 +1151,7 @@
 
         OltFlowsStatus status = action.equals(FlowOperation.ADD) ?
                 OltFlowsStatus.PENDING_ADD : OltFlowsStatus.PENDING_REMOVE;
-        updateConnectPointStatus(sk, null, null, status, null);
+        updateConnectPointStatus(sk, null, null, null, status, null);
 
         DefaultFilteringObjective.Builder builder = DefaultFilteringObjective.builder();
         TrafficTreatment.Builder treatmentBuilder = DefaultTrafficTreatment.builder();
@@ -1202,7 +1205,7 @@
                         portWithName(port),
                         action,
                         error);
-                updateConnectPointStatus(sk, null, null, OltFlowsStatus.ERROR, null);
+                updateConnectPointStatus(sk, null, null, null, OltFlowsStatus.ERROR, null);
             }
         });
         flowObjectiveService.filter(deviceId, dhcpUpstream);
@@ -1379,7 +1382,7 @@
             public void onError(Objective objective, ObjectiveError error) {
                 log.error("Upstream Data plane filter for {} failed {} because {}.",
                         sk, action, error);
-                updateConnectPointStatus(sk, null, OltFlowsStatus.ERROR, null, null);
+                updateConnectPointStatus(sk, null, null, OltFlowsStatus.ERROR, null, null);
             }
         };
 
@@ -1469,8 +1472,8 @@
             @Override
             public void onError(Objective objective, ObjectiveError error) {
                 log.info("Downstream Data plane filter for {} failed {} because {}.",
-                        sk, action, error);
-                updateConnectPointStatus(sk, null, OltFlowsStatus.ERROR, null, null);
+                         sk, action, error);
+                updateConnectPointStatus(sk, null, null, OltFlowsStatus.ERROR, null, null);
             }
         };
 
@@ -1578,11 +1581,13 @@
     }
 
     protected void updateConnectPointStatus(ServiceKey key, OltFlowsStatus eapolStatus,
+                                            OltFlowsStatus subscriberEapolStatus,
                                             OltFlowsStatus subscriberFlowsStatus, OltFlowsStatus dhcpStatus,
                                             OltFlowsStatus pppoeStatus) {
         if (log.isTraceEnabled()) {
-            log.trace("Updating cpStatus {} with values: eapolFlow={}, subscriberFlows={}, dhcpFlow={}",
-                    key, eapolStatus, subscriberFlowsStatus, dhcpStatus);
+            log.trace("Updating cpStatus {} with values: eapolFlow={}, " +
+                              "subscriberEapolStatus={}, subscriberFlows={}, dhcpFlow={}",
+                      key, eapolStatus, subscriberEapolStatus, subscriberFlowsStatus, dhcpStatus);
         }
         try {
             cpStatusWriteLock.lock();
@@ -1610,6 +1615,7 @@
 
                 status = new OltPortStatus(
                         eapolStatus != null ? eapolStatus : OltFlowsStatus.NONE,
+                        subscriberEapolStatus != null ? subscriberEapolStatus : OltFlowsStatus.NONE,
                         subscriberFlowsStatus != null ? subscriberFlowsStatus : OltFlowsStatus.NONE,
                         dhcpStatus != null ? dhcpStatus : OltFlowsStatus.NONE,
                         pppoeStatus != null ? pppoeStatus : OltFlowsStatus.NONE
@@ -1698,7 +1704,16 @@
                 if (log.isTraceEnabled()) {
                     log.trace("update defaultEapolStatus {} on {}", status, sk);
                 }
-                updateConnectPointStatus(sk, status, null, null, null);
+                updateConnectPointStatus(sk, status, null, null, null, null);
+            } else if (isSubscriberEapolFlow(flowRule)) {
+                ServiceKey sk = getSubscriberKeyFromFlowRule(flowRule, port);
+                if (sk == null) {
+                    return;
+                }
+                if (log.isTraceEnabled()) {
+                    log.trace("update subscriberEapolStatus {} on {}", status, sk);
+                }
+                updateConnectPointStatus(sk, null, status, null, status, null);
             } else if (isDhcpFlow(flowRule)) {
                 ServiceKey sk = getSubscriberKeyFromFlowRule(flowRule, port);
                 if (sk == null) {
@@ -1707,7 +1722,7 @@
                 if (log.isTraceEnabled()) {
                     log.trace("update dhcpStatus {} on {}", status, sk);
                 }
-                updateConnectPointStatus(sk, null, null, status, null);
+                updateConnectPointStatus(sk, null, null, null, status, null);
             } else if (isPppoeFlow(flowRule)) {
                 ServiceKey sk = getSubscriberKeyFromFlowRule(flowRule, port);
                 if (sk == null) {
@@ -1716,7 +1731,7 @@
                 if (log.isTraceEnabled()) {
                     log.trace("update pppoeStatus {} on {}", status, sk);
                 }
-                updateConnectPointStatus(sk, null, null, null, status);
+                updateConnectPointStatus(sk, null, null, null, null, status);
             } else if (isDataFlow(flowRule)) {
                 PortNumber number = getPortNumberFromFlowRule(flowRule);
                 if (number == null) {
@@ -1735,7 +1750,7 @@
                 if (log.isTraceEnabled()) {
                     log.trace("update dataplaneStatus {} on {}", status, sk);
                 }
-                updateConnectPointStatus(sk, null, status, null, null);
+                updateConnectPointStatus(sk, null, null, status, null, null);
             }
         }
 
@@ -1809,6 +1824,31 @@
             return flowRule.selector().getCriterion(Criterion.Type.VLAN_VID) != null;
         }
 
+        private boolean isSubscriberEapolFlow(FlowRule flowRule) {
+            EthTypeCriterion c = (EthTypeCriterion) flowRule.selector().getCriterion(Criterion.Type.ETH_TYPE);
+            if (c == null) {
+                return false;
+            }
+            if (c.ethType().equals(EthType.EtherType.EAPOL.ethType())) {
+                AtomicBoolean isSubscriber = new AtomicBoolean(false);
+                flowRule.treatment().allInstructions().forEach(instruction -> {
+                    if (instruction.type() == L2MODIFICATION) {
+                        L2ModificationInstruction modificationInstruction = (L2ModificationInstruction) instruction;
+                        if (modificationInstruction.subtype() == L2ModificationInstruction.L2SubType.VLAN_ID) {
+                            L2ModificationInstruction.ModVlanIdInstruction vlanInstruction =
+                                    (L2ModificationInstruction.ModVlanIdInstruction) modificationInstruction;
+                            if (!vlanInstruction.vlanId().id().equals(EAPOL_DEFAULT_VLAN)) {
+                                isSubscriber.set(true);
+                                return;
+                            }
+                        }
+                    }
+                });
+                return isSubscriber.get();
+            }
+            return false;
+        }
+
         private Port getCpFromFlowRule(FlowRule flowRule) {
             DeviceId deviceId = flowRule.deviceId();
             PortNumber inPort = getPortNumberFromFlowRule(flowRule);
@@ -1846,6 +1886,11 @@
                 L2ModificationInstruction.ModVlanIdInstruction instruction =
                         (L2ModificationInstruction.ModVlanIdInstruction) flowRule.treatment().immediate().get(1);
                 flowVlan = instruction.vlanId();
+            } else if (isSubscriberEapolFlow(flowRule)) {
+                // we need to make a special case for EAPOL as in the ATT workflow EAPOL flows don't match on tags
+                L2ModificationInstruction.ModVlanIdInstruction instruction =
+                        (L2ModificationInstruction.ModVlanIdInstruction) flowRule.treatment().immediate().get(2);
+                flowVlan = instruction.vlanId();
             } else {
                 // for now we assume that if it's not DHCP it's dataplane (or at least tagged)
                 VlanIdCriterion vlanIdCriterion =
diff --git a/impl/src/main/java/org/opencord/olt/impl/OltPortStatus.java b/impl/src/main/java/org/opencord/olt/impl/OltPortStatus.java
index 32550b6..3ef348a 100644
--- a/impl/src/main/java/org/opencord/olt/impl/OltPortStatus.java
+++ b/impl/src/main/java/org/opencord/olt/impl/OltPortStatus.java
@@ -24,6 +24,7 @@
 public class OltPortStatus {
     // TODO consider adding a lastUpdated field, it may help with debugging
     public OltFlowService.OltFlowsStatus defaultEapolStatus;
+    public OltFlowService.OltFlowsStatus subscriberEapolStatus;
     public OltFlowService.OltFlowsStatus subscriberFlowsStatus;
     // NOTE we need to keep track of the DHCP status as that is installed before the other flows
     // if macLearning is enabled (DHCP is needed to learn the MacAddress from the host)
@@ -31,10 +32,12 @@
     public OltFlowService.OltFlowsStatus pppoeStatus;
 
     public OltPortStatus(OltFlowService.OltFlowsStatus defaultEapolStatus,
+                         OltFlowService.OltFlowsStatus subscriberEapolStatus,
                          OltFlowService.OltFlowsStatus subscriberFlowsStatus,
                          OltFlowService.OltFlowsStatus dhcpStatus,
                          OltFlowService.OltFlowsStatus pppoeStatus) {
         this.defaultEapolStatus = defaultEapolStatus;
+        this.subscriberEapolStatus = subscriberEapolStatus;
         this.subscriberFlowsStatus = subscriberFlowsStatus;
         this.dhcpStatus = dhcpStatus;
         this.pppoeStatus = pppoeStatus;
@@ -50,19 +53,22 @@
         }
         OltPortStatus that = (OltPortStatus) o;
         return defaultEapolStatus == that.defaultEapolStatus
+                && subscriberEapolStatus == that.subscriberEapolStatus
                 && subscriberFlowsStatus == that.subscriberFlowsStatus
                 && dhcpStatus == that.dhcpStatus;
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(defaultEapolStatus, subscriberFlowsStatus, dhcpStatus);
+        return Objects.hash(defaultEapolStatus, subscriberEapolStatus,
+                            subscriberFlowsStatus, dhcpStatus);
     }
 
     @Override
     public String toString() {
         final StringBuilder sb = new StringBuilder("OltPortStatus{");
         sb.append("defaultEapolStatus=").append(defaultEapolStatus);
+        sb.append(", subscriberEapolStatus=").append(subscriberEapolStatus);
         sb.append(", subscriberFlowsStatus=").append(subscriberFlowsStatus);
         sb.append(", dhcpStatus=").append(dhcpStatus);
         sb.append('}');
diff --git a/impl/src/test/java/org/opencord/olt/impl/OltFlowServiceTest.java b/impl/src/test/java/org/opencord/olt/impl/OltFlowServiceTest.java
index ca90867..baa880d 100644
--- a/impl/src/test/java/org/opencord/olt/impl/OltFlowServiceTest.java
+++ b/impl/src/test/java/org/opencord/olt/impl/OltFlowServiceTest.java
@@ -164,22 +164,22 @@
         // cpStatus map for the test
         component.cpStatus = component.storageService.
                 <ServiceKey, OltPortStatus>consistentMapBuilder().build().asJavaMap();
-        OltPortStatus cp1Status = new OltPortStatus(PENDING_ADD, NONE, NONE, NONE);
+        OltPortStatus cp1Status = new OltPortStatus(PENDING_ADD, NONE, NONE, NONE, NONE);
         component.cpStatus.put(sk1, cp1Status);
 
         //check that we only update the provided value
-        component.updateConnectPointStatus(sk1, ADDED, null, null, null);
+        component.updateConnectPointStatus(sk1, ADDED, null, null, null, null);
         OltPortStatus updated = component.cpStatus.get(sk1);
         Assert.assertEquals(ADDED, updated.defaultEapolStatus);
         Assert.assertEquals(NONE, updated.subscriberFlowsStatus);
         Assert.assertEquals(NONE, updated.dhcpStatus);
 
         // check that it creates an entry if it does not exist
-        component.updateConnectPointStatus(sk2, PENDING_ADD, NONE, NONE, NONE);
+        component.updateConnectPointStatus(sk2, PENDING_ADD, NONE, NONE, NONE, NONE);
         Assert.assertNotNull(component.cpStatus.get(sk2));
 
         // check that if we create a new entry with null values they're converted to NONE
-        component.updateConnectPointStatus(sk3, null, null, null, null);
+        component.updateConnectPointStatus(sk3, null, null, null, null, null);
         updated = component.cpStatus.get(sk3);
         Assert.assertEquals(NONE, updated.defaultEapolStatus);
         Assert.assertEquals(NONE, updated.subscriberFlowsStatus);
@@ -208,12 +208,12 @@
                 <ServiceKey, OltPortStatus>consistentMapBuilder().build().asJavaMap();
 
         // check that an entry is not created if the only status is pending remove
-        component.updateConnectPointStatus(sk1, null, null, PENDING_REMOVE, null);
+        component.updateConnectPointStatus(sk1, null, null, null, PENDING_REMOVE, null);
         OltPortStatus entry = component.cpStatus.get(sk1);
         Assert.assertNull(entry);
 
         // check that an entry is not created if the only status is ERROR
-        component.updateConnectPointStatus(sk1, null, null, ERROR, null);
+        component.updateConnectPointStatus(sk1, null, null, null, ERROR, null);
         entry = component.cpStatus.get(sk1);
         Assert.assertNull(entry);
     }
@@ -238,6 +238,7 @@
                 OltFlowService.OltFlowsStatus.ADDED,
                 NONE,
                 null,
+                null,
                 null
         );
 
@@ -245,6 +246,7 @@
                 REMOVED,
                 NONE,
                 null,
+                null,
                 null
         );
 
@@ -276,12 +278,14 @@
                 ADDED,
                 NONE,
                 NONE,
+                NONE,
                 NONE
         );
 
         OltPortStatus withDhcp = new OltPortStatus(
                 REMOVED,
                 NONE,
+                NONE,
                 ADDED,
                 NONE
         );
@@ -290,6 +294,7 @@
                 REMOVED,
                 ADDED,
                 ADDED,
+                ADDED,
                 NONE
         );
 
@@ -845,7 +850,7 @@
 
         // first test that when we remove the EAPOL flow we return false so that the
         // subscriber is not removed from the queue
-        doReturn(true).when(oltFlowService).areSubscriberFlowsPendingRemoval(any(), any());
+        doReturn(true).when(oltFlowService).areSubscriberFlowsPendingRemoval(any(), any(), eq(true));
         boolean res = oltFlowService.removeSubscriberFlows(sub, DEFAULT_BP_ID_DEFAULT, DEFAULT_MCAST_SERVICE_NAME);
         verify(oltFlowService, times(1))
                 .handleSubscriberDhcpFlows(deviceId, port, OltFlowService.FlowOperation.REMOVE, si);
@@ -863,7 +868,7 @@
 
         // then test that if the tagged EAPOL is not there we install the default EAPOL
         // and return true so we remove the subscriber from the queue
-        doReturn(false).when(oltFlowService).areSubscriberFlowsPendingRemoval(any(), any());
+        doReturn(false).when(oltFlowService).areSubscriberFlowsPendingRemoval(any(), any(), eq(true));
         doReturn(port).when(oltFlowService.deviceService).getPort(deviceId, port.number());
         res = oltFlowService.removeSubscriberFlows(sub, DEFAULT_BP_ID_DEFAULT, DEFAULT_MCAST_SERVICE_NAME);
         verify(oltFlowService, times(1))
@@ -916,7 +921,7 @@
         // cpStatus map for the test
         component.cpStatus = component.storageService.
                 <ServiceKey, OltPortStatus>consistentMapBuilder().build().asJavaMap();
-        OltPortStatus cp1Status = new OltPortStatus(NONE, PENDING_REMOVE, NONE, NONE);
+        OltPortStatus cp1Status = new OltPortStatus(NONE, NONE, PENDING_REMOVE, NONE, NONE);
         component.cpStatus.put(sk1, cp1Status);
 
         FlowRuleEvent event = new FlowRuleEvent(FlowRuleEvent.Type.RULE_REMOVED, flowRule);