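Refactor EAPOL responder

Rename the responder types (responder -> eapResponder, peerInstance ->
eapClientInstance, key -> clientKey) and prefix the state constants with
EAP_; the states now start at 1 instead of 0. RunEapolResponder's two
channels both carry *byteMsg (previously *EAPPkt and *EAPByte).

transitState no longer sends packets itself: it parses the received
bytes and returns the next state together with the reply bytes. The new
respondMessage helper, written against the clientInstance interface,
sends the reply through the free function sendBytes (formerly
peerInstance.sendPkt). Imports are sorted and the file is reformatted.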

Change-Id: I423585fcf689f7b10ac1143ba9f2ff7c2eb51f26
diff --git a/core/eapol.go b/core/eapol.go
index e40da47..8dc7e7b 100644
--- a/core/eapol.go
+++ b/core/eapol.go
@@ -18,78 +18,79 @@
 
 import (
 	"context"
-	"gerrit.opencord.org/voltha-bbsim/common/logger"
-	"net"
-	"errors"
+	"crypto/md5"
 	"encoding/hex"
+	"errors"
+	"fmt"
+	"gerrit.opencord.org/voltha-bbsim/common/logger"
 	"github.com/google/gopacket"
 	"github.com/google/gopacket/layers"
-	"crypto/md5"
+	"net"
 	"sync"
 )
 
-type eapState int
+type clientState int
+
 const (
-	START eapState = iota	//TODO: This state definition should support 802.1X
-	RESPID
-	RESPCHA
-	SUCCESS
+	EAP_START clientState = iota + 1 //TODO: This state definition should support 802.1X
+	EAP_RESPID
+	EAP_RESPCHA
+	EAP_SUCCESS
 )
 
-
-
-type responder struct {
-	peers   map [key] *peerInstance
-	eapolIn chan *EAPByte
+type eapResponder struct {
+	clients map[clientKey]*eapClientInstance
+	eapolIn chan *byteMsg
 }
 
-type peerInstance struct{
-	key      key
+type clientInstance interface {
+	transitState(cur clientState, recvbytes []byte) (next clientState, sendbytes []byte, err error)
+	getState() clientState
+	getKey() clientKey
+}
+
+type eapClientInstance struct {
+	key      clientKey
 	srcaddr  *net.HardwareAddr
 	version  uint8
 	curId    uint8
-	curState eapState
+	curState clientState
 }
 
-type key struct {
+type clientKey struct {
 	intfid uint32
 	onuid  uint32
 }
 
-var resp *responder
+var resp *eapResponder
 var once sync.Once
-func getResponder () *responder {
-	once.Do(func(){
-		resp = &responder{peers: make(map[key] *peerInstance), eapolIn: nil}
+
+func getEAPResponder() *eapResponder {
+	once.Do(func() {
+		resp = &eapResponder{clients: make(map[clientKey]*eapClientInstance), eapolIn: nil}
 	})
 	return resp
 }
 
-func RunEapolResponder(ctx context.Context, eapolOut chan *EAPPkt, eapolIn chan *EAPByte, errch chan error) {
-	responder := getResponder()
+func RunEapolResponder(ctx context.Context, eapolOut chan *byteMsg, eapolIn chan *byteMsg, errch chan error) {
+	responder := getEAPResponder()
 	responder.eapolIn = eapolIn
-	peers := responder.peers
-	
+
 	go func() {
 		logger.Debug("EAPOL response process starts")
 		defer logger.Debug("EAPOL response process was done")
 		for {
 			select {
 			case msg := <- eapolOut:
-				logger.Debug("Received eapol from eapolOut")
-				intfid := msg.IntfId
-				onuid := msg.OnuId
-
-				if peer, ok := peers[key{intfid: intfid, onuid:onuid}]; ok {
-					logger.Debug("Key hit intfid:%d onuid: %d", intfid, onuid)
-					curstate := peer.curState
-					nextstate, err := peer.transitState(curstate, eapolIn, msg.Pkt)
-					if err != nil {
-						logger.Error("Failed to transitState: %s", err)
-					}
-					peer.curState = nextstate
+				logger.Debug("Received eapol from eapolOut intfid:%d onuid:%d", msg.IntfId, msg.OnuId)
+				responder := getEAPResponder()
+				clients := responder.clients
+				if c, ok := clients[clientKey{intfid: msg.IntfId, onuid: msg.OnuId}]; ok {
+					logger.Debug("Got client intfid:%d onuid: %d", c.key.intfid, c.key.onuid)
+					nextstate := respondMessage("EAPOL", *c, msg, eapolIn)
+					c.updateState(nextstate)
 				} else {
-					logger.Error("Failed to find eapol peer instance intfid:%d onuid:%d", intfid, onuid)
+					logger.Error("Failed to find eapol client instance intfid:%d onuid:%d", msg.IntfId, msg.OnuId)
 				}
 			case <-ctx.Done():
 				return
@@ -98,95 +99,132 @@
 	}()
 }
 
-func startPeer (intfid uint32, onuid uint32) error {
-	peer := peerInstance{key: key{intfid: intfid, onuid: onuid},
-						srcaddr: &net.HardwareAddr{0x2e, 0x60, 0x70, 0x13, 0x07, byte(onuid)},
-						version: 1,
-						curId: 0,
-						curState: START}
+// respondMessage runs one received message through the client's state machine,
+// forwards any reply bytes to msgInCh, and returns the state to store next.
+func respondMessage(msgtype string, client clientInstance, recvmsg *byteMsg, msgInCh chan *byteMsg) clientState {
+	curstate := client.getState()
+	nextstate, sendbytes, err := client.transitState(curstate, recvmsg.Byte)
 
-	eap := peer.createEAPStart()
-	bytes := peer.createEAPOL(eap)
-	resp := getResponder()
+	if err != nil {
+		logger.Error("Failed to transitState in %s: %s", msgtype, err)
+	}
+
+	if sendbytes != nil {
+		key := client.getKey()
+		if err := sendBytes(key, sendbytes, msgInCh); err != nil {
+			logger.Error("Failed to sendBytes in %s: %s", msgtype, err)
+		}
+	} else {
+		logger.Debug("sendbytes is nil")
+	}
+	return nextstate
+}
+
+func startEAPClient(intfid uint32, onuid uint32) error {
+	client := eapClientInstance{key: clientKey{intfid: intfid, onuid: onuid},
+		srcaddr:  &net.HardwareAddr{0x2e, 0x60, 0x70, 0x13, 0x07, byte(onuid)},
+		version:  1,
+		curId:    0,
+		curState: EAP_START}
+
+	eap := client.createEAPStart()
+	bytes := client.createEAPOL(eap)
+	resp := getEAPResponder()
 	eapolIn := resp.eapolIn
-	if err := peer.sendPkt(bytes, eapolIn); err != nil {
+	if err := sendBytes(clientKey{intfid, onuid}, bytes, eapolIn); err != nil {
 		return errors.New("Failed to send EAPStart")
 	}
 	logger.Debug("Sending EAPStart")
-	logger.Debug(hex.Dump(bytes))
-	//peers[key{intfid: intfid, onuid: onuid}] = &peer
-	resp.peers[key{intfid: intfid, onuid: onuid}] = &peer
+	// Register the client under its key; RunEapolResponder looks it up by (intfid, onuid).
+	resp.clients[clientKey{intfid: intfid, onuid: onuid}] = &client
 	return nil
 }
 
-func (p *peerInstance)transitState(cur eapState, omciIn chan *EAPByte, recvpkt gopacket.Packet) (next eapState, err error) {
-	logger.Debug("currentState:%d", cur)
+func (c eapClientInstance) transitState(cur clientState, recvbytes []byte) (next clientState, respbytes []byte, err error) {
+	recvpkt := gopacket.NewPacket(recvbytes, layers.LayerTypeEthernet, gopacket.Default)
 	eap, err := extractEAP(recvpkt)
 	if err != nil {
-		return cur, nil
+		return cur, nil, nil
 	}
 	if eap.Code == layers.EAPCodeRequest && eap.Type == layers.EAPTypeIdentity {
 		logger.Debug("Received EAP-Request/Identity")
 		logger.Debug(recvpkt.Dump())
-		p.curId = eap.Id
-		if cur == START {
-			reseap := p.createEAPResID()
-			pkt := p.createEAPOL(reseap)
-			logger.Debug("Sending EAP-Response/Identity")
-			if err != p.sendPkt(pkt, omciIn) {
-				return cur, err
-			}
-			return RESPID, nil
+		c.curId = eap.Id
+		if cur == EAP_START {
+			reseap := c.createEAPResID()
+			pkt := c.createEAPOL(reseap)
+			return EAP_RESPID, pkt, nil
 		}
 	} else if eap.Code == layers.EAPCodeRequest && eap.Type == layers.EAPTypeOTP {
 		logger.Debug("Received EAP-Request/Challenge")
 		logger.Debug(recvpkt.Dump())
-		if cur == RESPID {
-			p.curId = eap.Id
-			resdata := getMD5Res (p.curId, eap)
-			resdata = append([]byte{0x10}, resdata ...)
-			reseap := p.createEAPResCha(resdata)
-			pkt := p.createEAPOL(reseap)
-			logger.Debug("Sending EAP-Response/Challenge")
-			if err != p.sendPkt(pkt, omciIn) {
-				return cur, err
-			}
-			return RESPCHA, nil
+		if cur == EAP_RESPID {
+			c.curId = eap.Id
+			senddata := getMD5Data(c.curId, eap)
+			senddata = append([]byte{0x10}, senddata...)
+			sendeap := c.createEAPResCha(senddata)
+			pkt := c.createEAPOL(sendeap)
+			return EAP_RESPCHA, pkt, nil
 		}
 	} else if eap.Code == layers.EAPCodeSuccess && eap.Type == layers.EAPTypeNone {
 		logger.Debug("Received EAP-Success")
 		logger.Debug(recvpkt.Dump())
-		if cur == RESPCHA {
-			return SUCCESS, nil
+		if cur == EAP_RESPCHA {
+			return EAP_SUCCESS, nil, nil
 		}
 	} else {
 		logger.Debug("Received unsupported EAP")
-		return cur, nil
+		return cur, nil, nil
 	}
 	logger.Debug("State transition does not support..current state:%d", cur)
 	logger.Debug(recvpkt.Dump())
-	return cur, nil
+	return cur, nil, nil
 }
 
-func (p *peerInstance) createEAPOL (eap *layers.EAP) []byte {
+func (c eapClientInstance) getState() clientState {
+	return c.curState
+}
+
+func (c *eapClientInstance) updateState(state clientState) {
+	msg := fmt.Sprintf("EAP update state intfid:%d onuid:%d state:%d", c.key.intfid, c.key.onuid, state)
+	logger.Debug(msg)
+	c.curState = state
+}
+
+func (c eapClientInstance) getKey() clientKey {
+	return c.key
+}
+
+// sendBytes queues pkt for key on chIn as a byteMsg; it fails if chIn is nil.
+func sendBytes(key clientKey, pkt []byte, chIn chan *byteMsg) error {
+	if chIn == nil { // a send on a nil channel would block forever
+		return errors.New("sendBytes: channel is not set")
+	}
+	chIn <- &byteMsg{IntfId: key.intfid, OnuId: key.onuid, Byte: pkt}
+	logger.Debug("sendBytes intfid:%d onuid:%d", key.intfid, key.onuid)
+	logger.Debug("%s", hex.Dump(pkt)) // dump may contain '%'; never use it as the format
+	return nil
+}
+
+func (c *eapClientInstance) createEAPOL(eap *layers.EAP) []byte {
 	buffer := gopacket.NewSerializeBuffer()
 	options := gopacket.SerializeOptions{}
 
 	ethernetLayer := &layers.Ethernet{
-		SrcMAC: *p.srcaddr,
-		DstMAC: net.HardwareAddr{0x01, 0x80, 0xC2, 0x00, 0x00, 0x03},
+		SrcMAC:       *c.srcaddr,
+		DstMAC:       net.HardwareAddr{0x01, 0x80, 0xC2, 0x00, 0x00, 0x03},
 		EthernetType: layers.EthernetTypeEAPOL,
 	}
 
-	if eap == nil {	// EAP Start
+	if eap == nil { // EAP Start
 		gopacket.SerializeLayers(buffer, options,
 			ethernetLayer,
-			&layers.EAPOL{Version: p.version, Type:1, Length: 0},
+			&layers.EAPOL{Version: c.version, Type: 1, Length: 0},
 		)
 	} else {
 		gopacket.SerializeLayers(buffer, options,
 			ethernetLayer,
-			&layers.EAPOL{Version: p.version, Type:0, Length: eap.Length},
+			&layers.EAPOL{Version: c.version, Type: 0, Length: eap.Length},
 			eap,
 		)
 	}
@@ -194,51 +232,40 @@
 	return bytes
 }
 
-func (p *peerInstance) createEAPStart () *layers.EAP {
+func (c *eapClientInstance) createEAPStart() *layers.EAP {
 	return nil
 }
 
-func (p *peerInstance) createEAPResID () *layers.EAP {
+func (c *eapClientInstance) createEAPResID() *layers.EAP {
 	eap := layers.EAP{Code: layers.EAPCodeResponse,
-		Id: p.curId,
-		Length: 9,
-		Type: layers.EAPTypeIdentity,
-		TypeData: []byte{0x75, 0x73, 0x65, 0x72 }}
+		Id:       c.curId,
+		Length:   9,
+		Type:     layers.EAPTypeIdentity,
+		TypeData: []byte{0x75, 0x73, 0x65, 0x72}}
 	return &eap
 }
 
-func (p *peerInstance) createEAPResCha (payload []byte) *layers.EAP {
+func (c *eapClientInstance) createEAPResCha(payload []byte) *layers.EAP {
 	eap := layers.EAP{Code: layers.EAPCodeResponse,
-		Id: p.curId, Length: 22,
-		Type: layers.EAPTypeOTP,
+		Id: c.curId, Length: 22,
+		Type:     layers.EAPTypeOTP,
 		TypeData: payload}
 	return &eap
 }
 
-func (p *peerInstance) sendPkt (pkt []byte, omciIn chan *EAPByte) error {
-	// Send our packet
-	msg := EAPByte{IntfId: p.key.intfid,
-					OnuId: p.key.onuid,
-					Byte: pkt}
-	omciIn <- &msg
-	logger.Debug("sendPkt intfid:%d onuid:%d", p.key.intfid, p.key.onuid)
-	logger.Debug(hex.Dump(msg.Byte))
-	return nil
-}
-
-func getMD5Res (id uint8, eap *layers.EAP) []byte {
+func getMD5Data(id uint8, eap *layers.EAP) []byte {
 	i := byte(id)
 	C := []byte(eap.BaseLayer.Contents)[6:]
 	P := []byte{i, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64} //"password"
-	data := md5.Sum(append(P, C ...))
+	data := md5.Sum(append(P, C...)) // MD5 over id, "password", then C
 	ret := make([]byte, 16)
-	for j := 0; j < 16; j ++ {
+	for j := 0; j < 16; j++ {
 		ret[j] = data[j]
 	}
 	return ret
 }
 
-func extractEAPOL (pkt gopacket.Packet) (*layers.EAPOL, error) {
+func extractEAPOL(pkt gopacket.Packet) (*layers.EAPOL, error) {
 	layerEAPOL := pkt.Layer(layers.LayerTypeEAPOL)
 	eapol, _ := layerEAPOL.(*layers.EAPOL)
 	if eapol == nil {
@@ -247,11 +274,11 @@
 	return eapol, nil
 }
 
-func extractEAP (pkt gopacket.Packet) (*layers.EAP, error) {
+func extractEAP(pkt gopacket.Packet) (*layers.EAP, error) {
 	layerEAP := pkt.Layer(layers.LayerTypeEAP)
 	eap, _ := layerEAP.(*layers.EAP)
 	if eap == nil {
 		return nil, errors.New("Cannot extract EAP")
 	}
 	return eap, nil
-}
\ No newline at end of file
+}