CORD-248 Provide host management network connectivity to a VM

Change-Id: I01732c9defe65de227a147c8ad9b63bc4ec18956
diff --git a/src/main/java/org/opencord/cordvtn/impl/CordVtnPipeline.java b/src/main/java/org/opencord/cordvtn/impl/CordVtnPipeline.java
index 2aa9f1f..4e01d2b 100644
--- a/src/main/java/org/opencord/cordvtn/impl/CordVtnPipeline.java
+++ b/src/main/java/org/opencord/cordvtn/impl/CordVtnPipeline.java
@@ -27,7 +27,8 @@
 import org.onlab.packet.IpAddress;
 import org.onlab.packet.TpPort;
 import org.onlab.packet.VlanId;
-import org.onlab.util.ItemNotFoundException;
+import org.onosproject.net.AnnotationKeys;
+import org.onosproject.net.Port;
 import org.opencord.cordvtn.api.Constants;
 import org.opencord.cordvtn.api.CordVtnNode;
 import org.onosproject.core.ApplicationId;
@@ -50,8 +51,11 @@
 import org.onosproject.net.flow.instructions.ExtensionTreatment;
 import org.slf4j.Logger;
 
+import java.util.Optional;
+
 import static com.google.common.base.Preconditions.checkNotNull;
 import static org.onosproject.net.flow.instructions.ExtensionTreatmentType.ExtensionTreatmentTypes.NICIRA_SET_TUNNEL_DST;
+import static org.opencord.cordvtn.api.Constants.DEFAULT_TUNNEL;
 import static org.slf4j.LoggerFactory.getLogger;
 
 /**
@@ -91,6 +95,8 @@
     public static final int VXLAN_UDP_PORT = 4789;
     public static final VlanId VLAN_WAN = VlanId.vlanId((short) 500);
 
+    public static final String PROPERTY_TUNNEL_DST = "tunnelDst";
+
     private ApplicationId appId;
 
     @Activate
@@ -115,19 +121,84 @@
      * Installs table miss rule to a give device.
      *
      * @param node cordvtn node
-     * @param dataPort data plane port number
-     * @param tunnelPort tunnel port number
      */
-    public void initPipeline(CordVtnNode node, PortNumber dataPort, PortNumber tunnelPort) {
+    public void initPipeline(CordVtnNode node) {
         checkNotNull(node);
 
-        processTableZero(node.integrationBridgeId(), dataPort, node.dataIp().ip());
-        processInPortTable(node.integrationBridgeId(), tunnelPort, dataPort);
-        processAccessTypeTable(node.integrationBridgeId(), dataPort);
-        processVlanTable(node.integrationBridgeId(), dataPort);
+        Optional<PortNumber> dataPort = getPortNumber(node.integrationBridgeId(), node.dataIface());
+        Optional<PortNumber> tunnelPort = getPortNumber(node.integrationBridgeId(), DEFAULT_TUNNEL);
+        if (!dataPort.isPresent() || !tunnelPort.isPresent()) {
+            log.warn("Node {} is not in COMPLETE state", node.integrationBridgeId());
+            return;
+        }
+
+        // hostMgmtIface is optional; hostMgmtPort stays empty when it is unset or its
+        // port is not found, and processInPortTable then installs no host mgmt rule.
+        Optional<PortNumber> hostMgmtPort = node.hostMgmtIface()
+                .flatMap(iface -> getPortNumber(node.integrationBridgeId(), iface));
+
+        processTableZero(node.integrationBridgeId(),
+                         dataPort.get(),
+                         node.dataIp().ip(),
+                         node.localMgmtIp().ip());
+
+        processInPortTable(node.integrationBridgeId(),
+                           tunnelPort.get(),
+                           dataPort.get(),
+                           hostMgmtPort);
+
+        processAccessTypeTable(node.integrationBridgeId(), dataPort.get());
+        processVlanTable(node.integrationBridgeId(), dataPort.get());
     }
 
-    private void processTableZero(DeviceId deviceId, PortNumber dataPort, IpAddress dataIp) {
+    private void processTableZero(DeviceId deviceId, PortNumber dataPort, IpAddress dataIp,
+                                  IpAddress localMgmtIp) {
+        vxlanShuttleRule(deviceId, dataPort, dataIp);
+        localManagementBaseRule(deviceId, localMgmtIp.getIp4Address()); // assumes IPv4 localMgmtIp
+
+        // take all vlan tagged packets to the VLAN table
+        TrafficSelector selector = DefaultTrafficSelector.builder()
+                .matchVlanId(VlanId.ANY)
+                .build();
+
+        TrafficTreatment treatment = DefaultTrafficTreatment.builder()
+                .transition(TABLE_VLAN)
+                .build();
+
+        FlowRule flowRule = DefaultFlowRule.builder()
+                .fromApp(appId)
+                .withSelector(selector)
+                .withTreatment(treatment)
+                .withPriority(PRIORITY_MANAGEMENT)
+                .forDevice(deviceId)
+                .forTable(TABLE_ZERO)
+                .makePermanent()
+                .build();
+
+        processFlowRule(true, flowRule);
+
+        // take all other packets (empty selector, i.e. catch-all) to the in-port table
+        selector = DefaultTrafficSelector.builder()
+                .build();
+
+        treatment = DefaultTrafficTreatment.builder()
+                .transition(TABLE_IN_PORT)
+                .build();
+
+        flowRule = DefaultFlowRule.builder()
+                .fromApp(appId)
+                .withSelector(selector)
+                .withTreatment(treatment)
+                .withPriority(PRIORITY_ZERO)
+                .forDevice(deviceId)
+                .forTable(TABLE_ZERO)
+                .makePermanent()
+                .build();
+
+        processFlowRule(true, flowRule);
+    }
+
+    private void vxlanShuttleRule(DeviceId deviceId, PortNumber dataPort, IpAddress dataIp) {
         // take vxlan packet out onto the physical port
         TrafficSelector selector = DefaultTrafficSelector.builder()
                 .matchInPort(PortNumber.LOCAL)
@@ -218,52 +289,98 @@
                 .build();
 
         processFlowRule(true, flowRule);
+    }
 
-        // take all else to the next table
-        selector = DefaultTrafficSelector.builder()
+    private void localManagementBaseRule(DeviceId deviceId, Ip4Address localMgmtIp) {
+        TrafficSelector selector = DefaultTrafficSelector.builder()
+                .matchEthType(Ethernet.TYPE_ARP)
+                .matchArpTpa(localMgmtIp) // ARP targeting localMgmtIp -> LOCAL port
                 .build();
 
-        treatment = DefaultTrafficTreatment.builder()
-                .transition(TABLE_IN_PORT)
+        TrafficTreatment treatment = DefaultTrafficTreatment.builder()
+                .setOutput(PortNumber.LOCAL)
                 .build();
 
-        flowRule = DefaultFlowRule.builder()
+        FlowRule flowRule = DefaultFlowRule.builder()
                 .fromApp(appId)
                 .withSelector(selector)
                 .withTreatment(treatment)
-                .withPriority(PRIORITY_ZERO)
+                .withPriority(PRIORITY_MANAGEMENT)
                 .forDevice(deviceId)
-                .forTable(TABLE_ZERO)
+                .forTable(TABLE_ZERO)
                 .makePermanent()
                 .build();
 
         processFlowRule(true, flowRule);
 
-        // take all vlan tagged packet to the VLAN table
         selector = DefaultTrafficSelector.builder()
-                .matchVlanId(VlanId.ANY)
+                .matchInPort(PortNumber.LOCAL)
+                .matchEthType(Ethernet.TYPE_IPV4)
+                .matchIPSrc(localMgmtIp.toIpPrefix()) // IPv4 from LOCAL port -> TABLE_DST_IP
                 .build();
 
         treatment = DefaultTrafficTreatment.builder()
-                .transition(TABLE_VLAN)
+                .transition(TABLE_DST_IP)
                 .build();
 
         flowRule = DefaultFlowRule.builder()
                 .fromApp(appId)
                 .withSelector(selector)
                 .withTreatment(treatment)
-                .withPriority(PRIORITY_MANAGEMENT)
+                .withPriority(PRIORITY_MANAGEMENT)
                 .forDevice(deviceId)
-                .forTable(TABLE_ZERO)
+                .forTable(TABLE_ZERO)
+                .makePermanent()
+                .build();
+
+        processFlowRule(true, flowRule);
+
+        selector = DefaultTrafficSelector.builder()
+                .matchEthType(Ethernet.TYPE_IPV4)
+                .matchIPDst(localMgmtIp.toIpPrefix()) // IPv4 to localMgmtIp -> LOCAL port
+                .build();
+
+        treatment = DefaultTrafficTreatment.builder()
+                .setOutput(PortNumber.LOCAL)
+                .build();
+
+        flowRule = DefaultFlowRule.builder()
+                .fromApp(appId)
+                .withSelector(selector)
+                .withTreatment(treatment)
+                .withPriority(PRIORITY_MANAGEMENT)
+                .forDevice(deviceId)
+                .forTable(TABLE_ZERO)
+                .makePermanent()
+                .build();
+
+        processFlowRule(true, flowRule);
+
+        selector = DefaultTrafficSelector.builder()
+                .matchInPort(PortNumber.LOCAL)
+                .matchEthType(Ethernet.TYPE_ARP)
+                .matchArpSpa(localMgmtIp) // ARP sent from LOCAL port -> controller
+                .build();
+
+        treatment = DefaultTrafficTreatment.builder()
+                .setOutput(PortNumber.CONTROLLER)
+                .build();
+
+        flowRule = DefaultFlowRule.builder()
+                .fromApp(appId)
+                .withSelector(selector)
+                .withTreatment(treatment)
+                .withPriority(PRIORITY_MANAGEMENT)
+                .forDevice(deviceId)
+                .forTable(TABLE_ZERO)
                 .makePermanent()
                 .build();
 
         processFlowRule(true, flowRule);
     }
 
-    private void processInPortTable(DeviceId deviceId, PortNumber tunnelPort, PortNumber dataPort) {
-        checkNotNull(tunnelPort);
-
+    private void processInPortTable(DeviceId deviceId, PortNumber tunnelPort, PortNumber dataPort,
+                                    Optional<PortNumber> hostMgmtPort) {
         TrafficSelector selector = DefaultTrafficSelector.builder()
                 .matchInPort(tunnelPort)
                 .build();
@@ -303,6 +420,28 @@
                 .build();
 
         processFlowRule(true, flowRule);
+
+        if (hostMgmtPort.isPresent()) {
+            selector = DefaultTrafficSelector.builder()
+                    .matchInPort(hostMgmtPort.get())
+                    .build();
+
+            treatment = DefaultTrafficTreatment.builder()
+                    .transition(TABLE_DST_IP)
+                    .build();
+
+            flowRule = DefaultFlowRule.builder()
+                    .fromApp(appId)
+                    .withSelector(selector)
+                    .withTreatment(treatment)
+                    .withPriority(PRIORITY_DEFAULT)
+                    .forDevice(deviceId)
+                    .forTable(TABLE_IN_PORT)
+                    .makePermanent()
+                    .build();
+
+            processFlowRule(true, flowRule);
+        }
     }
 
     private void processAccessTypeTable(DeviceId deviceId, PortNumber dataPort) {
@@ -384,23 +523,30 @@
     }
 
     public ExtensionTreatment tunnelDstTreatment(DeviceId deviceId, Ip4Address remoteIp) {
-        try {
-            Device device = deviceService.getDevice(deviceId);
-            if (!device.is(ExtensionTreatmentResolver.class)) {
-                log.error("The extension treatment is not supported");
-                return null;
-
-            }
-
-            ExtensionTreatmentResolver resolver = device.as(ExtensionTreatmentResolver.class);
-            ExtensionTreatment treatment =
-                    resolver.getExtensionInstruction(NICIRA_SET_TUNNEL_DST.type());
-            treatment.setPropertyValue("tunnelDst", remoteIp);
-            return treatment;
-        } catch (ItemNotFoundException | UnsupportedOperationException |
-                ExtensionPropertyException e) {
-            log.error("Failed to get extension instruction {}", deviceId);
+        Device device = deviceService.getDevice(deviceId);
+        if (device == null || !device.is(ExtensionTreatmentResolver.class)) {
+            log.error("Device {} is unknown or does not support extension treatment", deviceId);
             return null;
         }
+        // Both the instruction lookup and the property set can fail; log and return null
+        ExtensionTreatmentResolver resolver = device.as(ExtensionTreatmentResolver.class);
+        try {
+            ExtensionTreatment treatment = resolver.getExtensionInstruction(NICIRA_SET_TUNNEL_DST.type());
+            treatment.setPropertyValue(PROPERTY_TUNNEL_DST, remoteIp);
+            return treatment;
+        } catch (UnsupportedOperationException | ExtensionPropertyException e) {
+            log.warn("Failed to get tunnelDst extension treatment for {}", deviceId, e);
+            return null;
+        }
+    }
+
+    private Optional<PortNumber> getPortNumber(DeviceId deviceId, String portName) {
+        // Finds an enabled port named portName. annotations().value() may be null for
+        // unnamed ports, so compare from the portName side to avoid an NPE.
+        return deviceService.getPorts(deviceId).stream()
+                .filter(Port::isEnabled)
+                .filter(p -> portName.equals(p.annotations().value(AnnotationKeys.PORT_NAME)))
+                .map(Port::number)
+                .findAny();
     }
 }