[VOL-3762] Updates to allow image download and update for the ONU

Change-Id: I0869307e3ef534c1d506b961d61a1ec6f5e13c2e
diff --git a/rw_core/core/device/agent.go b/rw_core/core/device/agent.go
index b6f2b2f..c122574 100755
--- a/rw_core/core/device/agent.go
+++ b/rw_core/core/device/agent.go
@@ -674,6 +674,7 @@
 	cloned.MacAddress = device.MacAddress
 	cloned.Vlan = device.Vlan
 	cloned.Reason = device.Reason
+	cloned.ImageDownloads = device.ImageDownloads
 	return agent.updateDeviceAndReleaseLock(ctx, cloned)
 }
 
diff --git a/rw_core/core/device/agent_image.go b/rw_core/core/device/agent_image.go
index f63faa2..acebaca 100644
--- a/rw_core/core/device/agent_image.go
+++ b/rw_core/core/device/agent_image.go
@@ -18,6 +18,7 @@
 
 import (
 	"context"
+	"github.com/opencord/voltha-protos/v4/go/common"
 
 	"github.com/gogo/protobuf/proto"
 	"github.com/golang/protobuf/ptypes"
@@ -27,21 +28,25 @@
 	"google.golang.org/grpc/status"
 )
 
-func (agent *Agent) downloadImage(ctx context.Context, img *voltha.ImageDownload) (*voltha.OperationResp, error) {
+func (agent *Agent) downloadImage(ctx context.Context, img *voltha.ImageDownload) (*common.OperationResp, error) {
 	if err := agent.requestQueue.WaitForGreenLight(ctx); err != nil {
 		return nil, err
 	}
 	logger.Debugw(ctx, "downloadImage", log.Fields{"device-id": agent.deviceID})
 
-	device := agent.cloneDeviceWithoutLock()
-	if device.AdminState != voltha.AdminState_ENABLED {
-		agent.requestQueue.RequestComplete()
-		return nil, status.Errorf(codes.FailedPrecondition, "device-id:%s, expected-admin-state:%s", agent.deviceID, voltha.AdminState_ENABLED)
+	if agent.device.Root {
+		return nil, status.Errorf(codes.FailedPrecondition, "device-id:%s, is an OLT. Image update "+
+			"not supported by VOLTHA. Use Device Manager or other means", agent.deviceID)
 	}
-	if device.AdminState != voltha.AdminState_ENABLED {
-		logger.Debugw(ctx, "device-not-enabled", log.Fields{"device-id": agent.deviceID})
-		agent.requestQueue.RequestComplete()
-		return nil, status.Errorf(codes.FailedPrecondition, "deviceId:%s, expected-admin-state:%s", agent.deviceID, voltha.AdminState_ENABLED)
+
+	device := agent.cloneDeviceWithoutLock()
+	if device.ImageDownloads != nil {
+		for _, image := range device.ImageDownloads {
+			if image.DownloadState == voltha.ImageDownload_DOWNLOAD_REQUESTED {
+				return nil, status.Errorf(codes.FailedPrecondition, "device-id:%s, already downloading image:%s",
+					agent.deviceID, image.Name)
+			}
+		}
 	}
 
 	// Save the image
@@ -57,47 +62,43 @@
 
 	// Send the request to the adapter
 	subCtx, cancel := context.WithTimeout(log.WithSpanFromContext(context.Background(), ctx), agent.defaultTimeout)
-	ch, err := agent.adapterProxy.DownloadImage(ctx, cloned, clonedImg)
+	ch, err := agent.adapterProxy.DownloadImage(subCtx, cloned, clonedImg)
 	if err != nil {
 		cancel()
 		return nil, err
 	}
-	go agent.waitForAdapterResponse(subCtx, cancel, "downloadImage", ch, agent.onSuccess, agent.onFailure)
+	go agent.waitForAdapterResponse(subCtx, cancel, "downloadImage", ch, agent.onImageSuccess, agent.onImageFailure)
 
-	return &voltha.OperationResp{Code: voltha.OperationResp_OPERATION_SUCCESS}, nil
+	return &common.OperationResp{Code: voltha.OperationResp_OPERATION_SUCCESS}, nil
 }
 
-// isImageRegistered is a helper method to figure out if an image is already registered
-func isImageRegistered(img *voltha.ImageDownload, device *voltha.Device) bool {
-	for _, image := range device.ImageDownloads {
+// getImage returns the registered image matching img's Id and Name together with its index, or an error if none matches
+func getImage(img *voltha.ImageDownload, device *voltha.Device) (*voltha.ImageDownload, int, error) {
+	for pos, image := range device.ImageDownloads {
 		if image.Id == img.Id && image.Name == img.Name {
-			return true
+			return image, pos, nil
 		}
 	}
-	return false
+	return nil, -1, status.Errorf(codes.FailedPrecondition, "device-id:%s, image-not-registered:%s",
+		device.Id, img.Name)
 }
 
-func (agent *Agent) cancelImageDownload(ctx context.Context, img *voltha.ImageDownload) (*voltha.OperationResp, error) {
+func (agent *Agent) cancelImageDownload(ctx context.Context, img *voltha.ImageDownload) (*common.OperationResp, error) {
 	if err := agent.requestQueue.WaitForGreenLight(ctx); err != nil {
 		return nil, err
 	}
 	logger.Debugw(ctx, "cancelImageDownload", log.Fields{"device-id": agent.deviceID})
 
-	// Verify whether the Image is in the list of image being downloaded
-	device := agent.getDeviceReadOnlyWithoutLock()
-	if !isImageRegistered(img, device) {
-		agent.requestQueue.RequestComplete()
-		return nil, status.Errorf(codes.FailedPrecondition, "device-id:%s, image-not-registered:%s", agent.deviceID, img.Name)
-	}
-
 	// Update image download state
 	cloned := agent.cloneDeviceWithoutLock()
-	for _, image := range cloned.ImageDownloads {
-		if image.Id == img.Id && image.Name == img.Name {
-			image.DownloadState = voltha.ImageDownload_DOWNLOAD_CANCELLED
-		}
+	_, index, err := getImage(img, cloned)
+	if err != nil {
+		agent.requestQueue.RequestComplete()
+		return nil, err
 	}
 
+	cloned.ImageDownloads[index].DownloadState = voltha.ImageDownload_DOWNLOAD_CANCELLED
+
 	if cloned.AdminState != voltha.AdminState_DOWNLOADING_IMAGE {
 		agent.requestQueue.RequestComplete()
 	} else {
@@ -112,7 +113,8 @@
 			cancel()
 			return nil, err
 		}
-		go agent.waitForAdapterResponse(subCtx, cancel, "cancelImageDownload", ch, agent.onSuccess, agent.onFailure)
+		go agent.waitForAdapterResponse(subCtx, cancel, "cancelImageDownload", ch, agent.onImageSuccess,
+			agent.onImageFailure)
 	}
 	return &voltha.OperationResp{Code: voltha.OperationResp_OPERATION_SUCCESS}, nil
 }
@@ -123,25 +125,33 @@
 	}
 	logger.Debugw(ctx, "activateImage", log.Fields{"device-id": agent.deviceID})
 
-	// Verify whether the Image is in the list of image being downloaded
-	device := agent.getDeviceReadOnlyWithoutLock()
-	if !isImageRegistered(img, device) {
+	// Update image download state
+	cloned := agent.cloneDeviceWithoutLock()
+	image, index, err := getImage(img, cloned)
+	if err != nil {
+		agent.requestQueue.RequestComplete()
+		return nil, err
+	}
+
+	if err != nil {
 		agent.requestQueue.RequestComplete()
 		return nil, status.Errorf(codes.FailedPrecondition, "device-id:%s, image-not-registered:%s", agent.deviceID, img.Name)
 	}
-	if device.AdminState == voltha.AdminState_DOWNLOADING_IMAGE {
+
+	if image.DownloadState != voltha.ImageDownload_DOWNLOAD_SUCCEEDED {
+		agent.requestQueue.RequestComplete()
+		return nil, status.Errorf(codes.FailedPrecondition, "device-id:%s, device-has-not-downloaded-image:%s", agent.deviceID, img.Name)
+	}
+
+	//TODO does this need to be removed ?
+	if cloned.AdminState == voltha.AdminState_DOWNLOADING_IMAGE {
 		agent.requestQueue.RequestComplete()
 		return nil, status.Errorf(codes.FailedPrecondition, "device-id:%s, device-in-downloading-state:%s", agent.deviceID, img.Name)
 	}
 
-	// Update image download state
-	cloned := agent.cloneDeviceWithoutLock()
-	for _, image := range cloned.ImageDownloads {
-		if image.Id == img.Id && image.Name == img.Name {
-			image.ImageState = voltha.ImageDownload_IMAGE_ACTIVATING
-		}
-	}
-	// Set the device to downloading_image
+	// Save the image
+	cloned.ImageDownloads[index].ImageState = voltha.ImageDownload_IMAGE_ACTIVATING
+
 	cloned.AdminState = voltha.AdminState_DOWNLOADING_IMAGE
 	if err := agent.updateDeviceAndReleaseLock(ctx, cloned); err != nil {
 		return nil, err
@@ -153,7 +163,7 @@
 		cancel()
 		return nil, err
 	}
-	go agent.waitForAdapterResponse(subCtx, cancel, "activateImageUpdate", ch, agent.onSuccess, agent.onFailure)
+	go agent.waitForAdapterResponse(subCtx, cancel, "activateImageUpdate", ch, agent.onImageSuccess, agent.onFailure)
 
 	// The status of the AdminState will be changed following the update_download_status response from the adapter
 	// The image name will also be removed from the device list
@@ -166,23 +176,19 @@
 	}
 	logger.Debugw(ctx, "revertImage", log.Fields{"device-id": agent.deviceID})
 
-	// Verify whether the Image is in the list of image being downloaded
-	device := agent.getDeviceReadOnlyWithoutLock()
-	if !isImageRegistered(img, device) {
+	// Update image download state
+	cloned := agent.cloneDeviceWithoutLock()
+	_, index, err := getImage(img, cloned)
+	if err != nil {
 		agent.requestQueue.RequestComplete()
 		return nil, status.Errorf(codes.FailedPrecondition, "deviceId:%s, image-not-registered:%s", agent.deviceID, img.Name)
 	}
-	if device.AdminState != voltha.AdminState_ENABLED {
+	if cloned.AdminState != voltha.AdminState_ENABLED {
 		agent.requestQueue.RequestComplete()
 		return nil, status.Errorf(codes.FailedPrecondition, "deviceId:%s, device-not-enabled-state:%s", agent.deviceID, img.Name)
 	}
-	// Update image download state
-	cloned := agent.cloneDeviceWithoutLock()
-	for _, image := range cloned.ImageDownloads {
-		if image.Id == img.Id && image.Name == img.Name {
-			image.ImageState = voltha.ImageDownload_IMAGE_REVERTING
-		}
-	}
+
+	cloned.ImageDownloads[index].ImageState = voltha.ImageDownload_IMAGE_REVERTING
 
 	if err := agent.updateDeviceAndReleaseLock(ctx, cloned); err != nil {
 		return nil, err
@@ -278,3 +284,107 @@
 	}
 	return &voltha.ImageDownloads{Items: device.ImageDownloads}, nil
 }
+
+// onImageFailure is the adapter-failure callback for image operations: it marks the in-flight image as
+// DOWNLOAD_FAILED (download) or IMAGE_INACTIVE (activation) and moves the device back to ENABLED.
+// Once the request queue has been obtained it is released on every exit path.
+func (agent *Agent) onImageFailure(ctx context.Context, rpc string, response interface{}, reqArgs ...interface{}) {
+	if err := agent.requestQueue.WaitForGreenLight(ctx); err != nil {
+		logger.Errorw(ctx, "can't obtain lock", log.Fields{"rpc": rpc, "device-id": agent.deviceID, "error": err, "args": reqArgs})
+		return
+	}
+	// TODO: Post failure message onto kafka
+	res, ok := response.(error)
+	if !ok {
+		// Nothing gets persisted on this path, so the queue has to be released explicitly.
+		agent.requestQueue.RequestComplete()
+		logger.Errorw(ctx, "rpc-failed-invalid-error", log.Fields{"rpc": rpc, "device-id": agent.deviceID, "args": reqArgs})
+		return
+	}
+	logger.Errorw(ctx, "rpc-failed", log.Fields{"rpc": rpc, "device-id": agent.deviceID, "error": res, "args": reqArgs})
+	// Locate the image this response belongs to; when several match, the last one wins.
+	//TODO base this on IMAGE ID when created
+	cloned := agent.cloneDeviceWithoutLock()
+	index := -1
+	for pos, image := range cloned.ImageDownloads {
+		if image.DownloadState == voltha.ImageDownload_DOWNLOAD_REQUESTED ||
+			image.ImageState == voltha.ImageDownload_IMAGE_ACTIVATING {
+			index = pos
+		}
+	}
+	if index < 0 {
+		// Release the queue before bailing out; no device update will do it on this path.
+		agent.requestQueue.RequestComplete()
+		logger.Errorw(ctx, "can't find image", log.Fields{"rpc": rpc, "device-id": agent.deviceID, "args": reqArgs})
+		return
+	}
+
+	// Flag the failure on the cloned entry and move it to the end of the list.
+	imageFailed := cloned.ImageDownloads[index]
+	if imageFailed.DownloadState == voltha.ImageDownload_DOWNLOAD_REQUESTED {
+		imageFailed.DownloadState = voltha.ImageDownload_DOWNLOAD_FAILED
+	} else if imageFailed.ImageState == voltha.ImageDownload_IMAGE_ACTIVATING {
+		imageFailed.ImageState = voltha.ImageDownload_IMAGE_INACTIVE
+	}
+	cloned.ImageDownloads = append(removeImage(cloned.ImageDownloads, index), imageFailed)
+	//Enabled is the only state we can go back to.
+	cloned.AdminState = voltha.AdminState_ENABLED
+	if err := agent.updateDeviceAndReleaseLock(ctx, cloned); err != nil {
+		logger.Errorw(ctx, "failed-enable-device-after-image-failure",
+			log.Fields{"rpc": rpc, "device-id": agent.deviceID, "error": err, "args": reqArgs})
+	}
+}
+
+// onImageSuccess is the adapter-success callback for image operations: it marks the in-flight image as
+// DOWNLOAD_SUCCEEDED (download) or IMAGE_ACTIVE (activation) and moves the device back to ENABLED.
+// rpc, response and reqArgs are only used for logging. Once the request queue has been obtained it is
+// released on every exit path.
+func (agent *Agent) onImageSuccess(ctx context.Context, rpc string, response interface{}, reqArgs ...interface{}) {
+	if err := agent.requestQueue.WaitForGreenLight(ctx); err != nil {
+		logger.Errorw(ctx, "can't obtain lock", log.Fields{"rpc": rpc, "device-id": agent.deviceID, "error": err, "args": reqArgs})
+		return
+	}
+	// A successful RPC is not an error condition, so it is reported at debug level.
+	logger.Debugw(ctx, "rpc-successful", log.Fields{"rpc": rpc, "device-id": agent.deviceID, "response": response, "args": reqArgs})
+
+	// Locate the image this response belongs to; when several match, the last one wins.
+	//TODO base this on IMAGE ID when created
+	cloned := agent.cloneDeviceWithoutLock()
+	index := -1
+	for pos, image := range cloned.ImageDownloads {
+		if image.DownloadState == voltha.ImageDownload_DOWNLOAD_REQUESTED ||
+			image.ImageState == voltha.ImageDownload_IMAGE_ACTIVATING {
+			index = pos
+		}
+	}
+
+	if index < 0 {
+		// Release the queue before bailing out; no device update will do it on this path.
+		agent.requestQueue.RequestComplete()
+		logger.Errorw(ctx, "can't find image", log.Fields{"rpc": rpc, "device-id": agent.deviceID, "args": reqArgs})
+		return
+	}
+
+	// Record the outcome on the cloned entry and move it to the end of the list.
+	imageSucceeded := cloned.ImageDownloads[index]
+	if imageSucceeded.DownloadState == voltha.ImageDownload_DOWNLOAD_REQUESTED {
+		imageSucceeded.DownloadState = voltha.ImageDownload_DOWNLOAD_SUCCEEDED
+	} else if imageSucceeded.ImageState == voltha.ImageDownload_IMAGE_ACTIVATING {
+		imageSucceeded.ImageState = voltha.ImageDownload_IMAGE_ACTIVE
+	}
+	cloned.ImageDownloads = append(removeImage(cloned.ImageDownloads, index), imageSucceeded)
+
+	//Enabled is the only state we can go back to.
+	cloned.AdminState = voltha.AdminState_ENABLED
+	// Persist the change; updateDeviceAndReleaseLock is relied upon to release the request queue.
+	if err := agent.updateDeviceAndReleaseLock(ctx, cloned); err != nil {
+		logger.Errorw(ctx, "failed-enable-device-after-image-download-success",
+			log.Fields{"rpc": rpc, "device-id": agent.deviceID, "error": err, "response": response, "args": reqArgs})
+	}
+}
+
+// removeImage drops s[i] by overwriting it with the last element; s is modified in place and order is not kept.
+func removeImage(s []*voltha.ImageDownload, i int) []*voltha.ImageDownload {
+	s[i] = s[len(s)-1]
+	return s[:len(s)-1]
+}
diff --git a/rw_core/core/device/state/transitions.go b/rw_core/core/device/state/transitions.go
index e65c396..eb8f033 100644
--- a/rw_core/core/device/state/transitions.go
+++ b/rw_core/core/device/state/transitions.go
@@ -232,12 +232,6 @@
 		transition{
 			deviceType:    any,
 			previousState: deviceState{Admin: voltha.AdminState_ENABLED, Connection: voltha.ConnectStatus_UNKNOWN, Operational: voltha.OperStatus_UNKNOWN, Transient: voltha.DeviceTransientState_NONE},
-			currentState:  deviceState{Admin: voltha.AdminState_DOWNLOADING_IMAGE, Connection: voltha.ConnectStatus_UNKNOWN, Operational: voltha.OperStatus_UNKNOWN, Transient: voltha.DeviceTransientState_NONE},
-			handlers:      []transitionHandler{dMgr.NotifyInvalidTransition}})
-	transitionMap.transitions = append(transitionMap.transitions,
-		transition{
-			deviceType:    any,
-			previousState: deviceState{Admin: voltha.AdminState_ENABLED, Connection: voltha.ConnectStatus_UNKNOWN, Operational: voltha.OperStatus_UNKNOWN, Transient: voltha.DeviceTransientState_NONE},
 			currentState:  deviceState{Admin: voltha.AdminState_UNKNOWN, Connection: voltha.ConnectStatus_UNKNOWN, Operational: voltha.OperStatus_UNKNOWN, Transient: voltha.DeviceTransientState_NONE},
 			handlers:      []transitionHandler{dMgr.NotifyInvalidTransition}})
 	transitionMap.transitions = append(transitionMap.transitions,
@@ -270,12 +264,6 @@
 			previousState: deviceState{Admin: voltha.AdminState_DISABLED, Connection: voltha.ConnectStatus_UNKNOWN, Operational: voltha.OperStatus_UNKNOWN, Transient: voltha.DeviceTransientState_NONE},
 			currentState:  deviceState{Admin: voltha.AdminState_PREPROVISIONED, Connection: voltha.ConnectStatus_UNKNOWN, Operational: voltha.OperStatus_UNKNOWN, Transient: voltha.DeviceTransientState_NONE},
 			handlers:      []transitionHandler{dMgr.NotifyInvalidTransition}})
-	transitionMap.transitions = append(transitionMap.transitions,
-		transition{
-			deviceType:    any,
-			previousState: deviceState{Admin: voltha.AdminState_DOWNLOADING_IMAGE, Connection: voltha.ConnectStatus_UNKNOWN, Operational: voltha.OperStatus_UNKNOWN, Transient: voltha.DeviceTransientState_NONE},
-			currentState:  deviceState{Admin: voltha.AdminState_DISABLED, Connection: voltha.ConnectStatus_UNKNOWN, Operational: voltha.OperStatus_UNKNOWN, Transient: voltha.DeviceTransientState_NONE},
-			handlers:      []transitionHandler{dMgr.NotifyInvalidTransition}})
 
 	return &transitionMap
 }
diff --git a/rw_core/core/device/state/transitions_test.go b/rw_core/core/device/state/transitions_test.go
index a085f97..4ad3f49 100644
--- a/rw_core/core/device/state/transitions_test.go
+++ b/rw_core/core/device/state/transitions_test.go
@@ -315,20 +315,12 @@
 	assertInvalidTransition(t, device, previousDevice)
 
 	previousDevice = getDevice(true, voltha.AdminState_ENABLED, voltha.ConnectStatus_UNKNOWN, voltha.OperStatus_UNKNOWN)
-	device = getDevice(true, voltha.AdminState_DOWNLOADING_IMAGE, voltha.ConnectStatus_UNKNOWN, voltha.OperStatus_UNKNOWN)
-	assertInvalidTransition(t, device, previousDevice)
-
-	previousDevice = getDevice(true, voltha.AdminState_ENABLED, voltha.ConnectStatus_UNKNOWN, voltha.OperStatus_UNKNOWN)
 	device = getDevice(true, voltha.AdminState_UNKNOWN, voltha.ConnectStatus_UNKNOWN, voltha.OperStatus_UNKNOWN)
 	assertInvalidTransition(t, device, previousDevice)
 
 	previousDevice = getDevice(true, voltha.AdminState_DISABLED, voltha.ConnectStatus_UNKNOWN, voltha.OperStatus_UNKNOWN)
 	device = getDevice(true, voltha.AdminState_PREPROVISIONED, voltha.ConnectStatus_UNKNOWN, voltha.OperStatus_UNKNOWN)
 	assertInvalidTransition(t, device, previousDevice)
-
-	previousDevice = getDevice(true, voltha.AdminState_DOWNLOADING_IMAGE, voltha.ConnectStatus_UNKNOWN, voltha.OperStatus_UNKNOWN)
-	device = getDevice(true, voltha.AdminState_DISABLED, voltha.ConnectStatus_UNKNOWN, voltha.OperStatus_UNKNOWN)
-	assertInvalidTransition(t, device, previousDevice)
 }
 
 func TestNoOpTransitions(t *testing.T) {