Merge "[VOL-4596] Updated dmi version"
diff --git a/internal/bbsim/packetHandlers/packet_tags.go b/internal/bbsim/packetHandlers/packet_tags.go
index 90186a8..fdcfeac 100644
--- a/internal/bbsim/packetHandlers/packet_tags.go
+++ b/internal/bbsim/packetHandlers/packet_tags.go
@@ -145,20 +145,29 @@
 	return dot1q.Priority, nil
 }
 
-// godet inner and outer tag from a packet
-// TODO unit test
+// GetTagsFromPacket returns the outer and inner VLAN IDs carried by pkt.
+// A single-tagged packet yields an inner tag of 0, so it cannot be told apart
+// from a double-tagged packet whose inner VLAN ID is 0. An error is returned
+// when pkt carries no 802.1Q tag, or when popping/reading a tag fails.
 func GetTagsFromPacket(pkt gopacket.Packet) (uint16, uint16, error) {
-	sTag, err := GetVlanTag(pkt)
+	oTag, err := GetVlanTag(pkt)
 	if err != nil {
 		return 0, 0, err
 	}
-	singleTagPkt, err := PopSingleTag(pkt)
+
+	poppedTagPkt, err := PopSingleTag(pkt)
 	if err != nil {
 		return 0, 0, err
 	}
-	cTag, err := GetVlanTag(singleTagPkt)
+
+	if poppedTagPkt.Layer(layers.LayerTypeDot1Q) == nil {
+		// No second tag left after popping the outer one: single tagged.
+		return oTag, 0, nil
+	}
+
+	iTag, err := GetVlanTag(poppedTagPkt)
 	if err != nil {
 		return 0, 0, err
 	}
-	return sTag, cTag, nil
+	return oTag, iTag, nil
 }
diff --git a/internal/bbsim/packetHandlers/packet_tags_test.go b/internal/bbsim/packetHandlers/packet_tags_test.go
index 231cb71..4cd7830 100644
--- a/internal/bbsim/packetHandlers/packet_tags_test.go
+++ b/internal/bbsim/packetHandlers/packet_tags_test.go
@@ -183,3 +183,46 @@
 	assert.Equal(t, vlan, uint16(0))
 	assert.Equal(t, err.Error(), "no-dot1q-layer-in-packet")
 }
+
+// TestGetTagsFromPacket covers untagged, single-tagged and double-tagged input.
+func TestGetTagsFromPacket(t *testing.T) {
+	rawBytes := []byte{10, 20, 30}
+	srcMac := net.HardwareAddr{0xff, 0xff, 0xff, 0xff, byte(1), byte(1)}
+	dstMac := net.HardwareAddr{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
+
+	ethernetLayer := &layers.Ethernet{
+		SrcMAC:       srcMac,
+		DstMAC:       dstMac,
+		EthernetType: 0x800,
+	}
+
+	buffer := gopacket.NewSerializeBuffer()
+	err := gopacket.SerializeLayers(
+		buffer,
+		gopacket.SerializeOptions{
+			FixLengths: false,
+		},
+		ethernetLayer,
+		gopacket.Payload(rawBytes),
+	)
+	assert.NilError(t, err)
+	untaggedPkt := gopacket.NewPacket(buffer.Bytes(), layers.LayerTypeEthernet, gopacket.Default)
+	singleTaggedPkt, err := packetHandlers.PushSingleTag(111, untaggedPkt, 0)
+	assert.NilError(t, err)
+	doubleTaggedPkt, err := packetHandlers.PushDoubleTag(111, 222, untaggedPkt, 0)
+	assert.NilError(t, err)
+
+	// assert.Error fails cleanly (rather than panicking on err.Error()) if err is nil.
+	_, _, err = packetHandlers.GetTagsFromPacket(untaggedPkt)
+	assert.Error(t, err, "no-dot1q-layer-in-packet")
+
+	oTag, iTag, err := packetHandlers.GetTagsFromPacket(singleTaggedPkt)
+	assert.NilError(t, err)
+	assert.Equal(t, uint16(111), oTag)
+	assert.Equal(t, uint16(0), iTag)
+
+	oTag, iTag, err = packetHandlers.GetTagsFromPacket(doubleTaggedPkt)
+	assert.NilError(t, err)
+	assert.Equal(t, uint16(111), oTag)
+	assert.Equal(t, uint16(222), iTag)
+}
diff --git a/internal/bbsim/responders/dhcp/dhcp_server.go b/internal/bbsim/responders/dhcp/dhcp_server.go
index 7e7f361..5ba3fa9 100644
--- a/internal/bbsim/responders/dhcp/dhcp_server.go
+++ b/internal/bbsim/responders/dhcp/dhcp_server.go
@@ -20,11 +20,12 @@
 	"encoding/hex"
 	"errors"
 	"fmt"
+	"net"
+
 	"github.com/google/gopacket"
 	"github.com/google/gopacket/layers"
 	"github.com/opencord/bbsim/internal/bbsim/packetHandlers"
 	log "github.com/sirupsen/logrus"
-	"net"
 )
 
 type DHCPServerIf interface {
@@ -167,7 +168,7 @@
 // get a Discover packet an return a valid Offer
 func (s *DHCPServer) handleDiscover(pkt gopacket.Packet) (gopacket.Packet, error) {
 
-	sTag, cTag, err := packetHandlers.GetTagsFromPacket(pkt)
+	oTag, iTag, err := packetHandlers.GetTagsFromPacket(pkt)
 	if err != nil {
 		return nil, err
 	}
@@ -193,8 +194,8 @@
 	}
 
 	dhcpLogger.WithFields(log.Fields{
-		"sTag":      sTag,
-		"cTag":      cTag,
+		"oTag":      oTag,
+		"iTag":      iTag,
 		"clientMac": clientMac,
 		"txId":      txId,
 		"hostname":  string(hostname),
@@ -224,15 +225,25 @@
 		return nil, err
 	}
 
-	taggedResponsePkt, err := packetHandlers.PushDoubleTag(int(sTag), int(cTag), responsePkt, 0)
+	taggedResponsePkt, err := pushVlanTags(responsePkt, oTag, iTag)
 	if err != nil {
 		return nil, err
 	}
 	return taggedResponsePkt, nil
 }
 
+// pushVlanTags tags pkt the same way the request it answers was tagged: with
+// oTag as outer and iTag as inner tag when iTag is non-zero, otherwise with
+// oTag alone (GetTagsFromPacket reports a missing inner tag as 0).
+func pushVlanTags(pkt gopacket.Packet, oTag, iTag uint16) (gopacket.Packet, error) {
+	if iTag != 0 {
+		return packetHandlers.PushDoubleTag(int(oTag), int(iTag), pkt, 0)
+	}
+	return packetHandlers.PushSingleTag(int(oTag), pkt, 0)
+}
+
 func (s *DHCPServer) handleRequest(pkt gopacket.Packet) (gopacket.Packet, error) {
-	sTag, cTag, err := packetHandlers.GetTagsFromPacket(pkt)
+	oTag, iTag, err := packetHandlers.GetTagsFromPacket(pkt)
 	if err != nil {
 		return nil, err
 	}
@@ -258,8 +269,8 @@
 	}
 
 	dhcpLogger.WithFields(log.Fields{
-		"sTag":      sTag,
-		"cTag":      cTag,
+		"oTag":      oTag,
+		"iTag":      iTag,
 		"clientMac": clientMac,
 		"txId":      txId,
 		"hostname":  string(hostname),
@@ -290,10 +301,10 @@
 		return nil, err
 	}
 
-	taggedResponsePkt, err := packetHandlers.PushDoubleTag(int(sTag), int(cTag), responsePkt, 0)
+	taggedResponsePkt, err := pushVlanTags(responsePkt, oTag, iTag)
 	if err != nil {
 		return nil, err
 	}
 	return taggedResponsePkt, nil
 }