[VOL-4022] RW-Core Changes For ONU SW Upgrade
New Download/Activate/Retrieve APIs

Change-Id: I2d8a0ec7d8967fd76a261a108f743e75f84c98e9
diff --git a/rw_core/core/device/manager.go b/rw_core/core/device/manager.go
index 9d9f933..43bd2be 100755
--- a/rw_core/core/device/manager.go
+++ b/rw_core/core/device/manager.go
@@ -1643,3 +1643,321 @@
 	}
 	return agent.getTransientState(), nil
 }
+
+func (dMgr *Manager) DownloadImageToDevice(ctx context.Context, request *voltha.DeviceImageDownloadRequest) (*voltha.DeviceImageResponse, error) {
+	if err := dMgr.validateImageDownloadRequest(request); err != nil {
+		return nil, err
+	}
+
+	ctx = utils.WithRPCMetadataContext(ctx, "DownloadImageToDevice")
+	respCh := make(chan []*voltha.DeviceImageState, len(request.GetDeviceId()))
+
+	downloadReq := &voltha.DeviceImageDownloadRequest{
+		Image:             request.Image,
+		ActivateOnSuccess: request.ActivateOnSuccess,
+		CommitOnSuccess:   request.CommitOnSuccess,
+	}
+
+	for index, deviceID := range request.DeviceId {
+		//slice-out only single deviceID from the request
+		downloadReq.DeviceId = request.DeviceId[index : index+1]
+
+		go func(deviceID string, req *voltha.DeviceImageDownloadRequest, ch chan []*voltha.DeviceImageState) {
+			agent := dMgr.getDeviceAgent(ctx, deviceID)
+			if agent == nil {
+				logger.Errorw(ctx, "Not-found", log.Fields{"device-id": deviceID})
+				ch <- nil
+				return
+			}
+
+			resp, err := agent.downloadImageToDevice(ctx, req)
+			if err != nil {
+				logger.Errorw(ctx, "download-image-to-device-failed", log.Fields{"device-id": deviceID, "error": err})
+				ch <- nil
+				return
+			}
+
+			err = dMgr.validateDeviceImageResponse(resp)
+			if err != nil {
+				logger.Errorw(ctx, "download-image-to-device-failed", log.Fields{"device-id": deviceID, "error": err})
+				ch <- nil
+				return
+			}
+			ch <- resp.GetDeviceImageStates()
+		}(deviceID.GetId(), downloadReq, respCh)
+
+	}
+
+	return dMgr.waitForAllResponses(ctx, "download-image-to-device", respCh, len(request.GetDeviceId()))
+}
+
+func (dMgr *Manager) GetImageStatus(ctx context.Context, request *voltha.DeviceImageRequest) (*voltha.DeviceImageResponse, error) {
+	if err := dMgr.validateImageRequest(request); err != nil {
+		return nil, err
+	}
+
+	ctx = utils.WithRPCMetadataContext(ctx, "GetImageStatus")
+
+	imageStatusReq := &voltha.DeviceImageRequest{
+		Version:         request.Version,
+		CommitOnSuccess: request.CommitOnSuccess,
+	}
+
+	respCh := make(chan []*voltha.DeviceImageState, len(request.GetDeviceId()))
+	for index, deviceID := range request.DeviceId {
+		//slice-out only single deviceID from the request
+		imageStatusReq.DeviceId = request.DeviceId[index : index+1]
+
+		go func(deviceID string, req *voltha.DeviceImageRequest, ch chan []*voltha.DeviceImageState) {
+			agent := dMgr.getDeviceAgent(ctx, deviceID)
+			if agent == nil {
+				logger.Errorw(ctx, "Not-found", log.Fields{"device-id": deviceID})
+				ch <- nil
+				return
+			}
+
+			resp, err := agent.getImageStatus(ctx, req)
+			if err != nil {
+				logger.Errorw(ctx, "get-image-status-failed", log.Fields{"device-id": deviceID, "error": err})
+				ch <- nil
+				return
+			}
+
+			err = dMgr.validateDeviceImageResponse(resp)
+			if err != nil {
+				logger.Errorw(ctx, "get-image-status-failed", log.Fields{"device-id": deviceID, "error": err})
+				ch <- nil
+				return
+			}
+			ch <- resp.GetDeviceImageStates()
+		}(deviceID.GetId(), imageStatusReq, respCh)
+
+	}
+
+	return dMgr.waitForAllResponses(ctx, "get-image-status", respCh, len(request.GetDeviceId()))
+}
+
+func (dMgr *Manager) AbortImageUpgradeToDevice(ctx context.Context, request *voltha.DeviceImageRequest) (*voltha.DeviceImageResponse, error) {
+	if err := dMgr.validateImageRequest(request); err != nil {
+		return nil, err
+	}
+
+	ctx = utils.WithRPCMetadataContext(ctx, "AbortImageUpgradeToDevice")
+	respCh := make(chan []*voltha.DeviceImageState, len(request.GetDeviceId()))
+
+	abortImageReq := &voltha.DeviceImageRequest{
+		Version:         request.Version,
+		CommitOnSuccess: request.CommitOnSuccess,
+	}
+
+	for index, deviceID := range request.DeviceId {
+		//slice-out only single deviceID from the request
+		abortImageReq.DeviceId = request.DeviceId[index : index+1]
+
+		go func(deviceID string, req *voltha.DeviceImageRequest, ch chan []*voltha.DeviceImageState) {
+			agent := dMgr.getDeviceAgent(ctx, deviceID)
+			if agent == nil {
+				logger.Errorw(ctx, "Not-found", log.Fields{"device-id": deviceID})
+				ch <- nil
+				return
+			}
+
+			resp, err := agent.abortImageUpgradeToDevice(ctx, req)
+			if err != nil {
+				logger.Errorw(ctx, "abort-image-upgrade-to-device-failed", log.Fields{"device-id": deviceID, "error": err})
+				ch <- nil
+				return
+			}
+
+			err = dMgr.validateDeviceImageResponse(resp)
+			if err != nil {
+				logger.Errorw(ctx, "abort-image-upgrade-to-device-failed", log.Fields{"device-id": deviceID, "error": err})
+				ch <- nil
+				return
+			}
+			ch <- resp.GetDeviceImageStates()
+		}(deviceID.GetId(), abortImageReq, respCh)
+
+	}
+
+	return dMgr.waitForAllResponses(ctx, "abort-image-upgrade-to-device", respCh, len(request.GetDeviceId()))
+}
+
+func (dMgr *Manager) GetOnuImages(ctx context.Context, id *common.ID) (*voltha.OnuImages, error) {
+	if id == nil || id.Id == "" {
+		return nil, status.Errorf(codes.InvalidArgument, "empty device id")
+	}
+
+	ctx = utils.WithRPCMetadataContext(ctx, "GetOnuImages")
+	log.EnrichSpan(ctx, log.Fields{"device-id": id.Id})
+	agent := dMgr.getDeviceAgent(ctx, id.Id)
+	if agent == nil {
+		return nil, status.Errorf(codes.NotFound, "%s", id.Id)
+	}
+
+	resp, err := agent.getOnuImages(ctx, id)
+	if err != nil {
+		return nil, err
+	}
+
+	logger.Debugw(ctx, "get-onu-images-result", log.Fields{"onu-image": resp.Items})
+
+	return resp, nil
+}
+
+func (dMgr *Manager) ActivateImage(ctx context.Context, request *voltha.DeviceImageRequest) (*voltha.DeviceImageResponse, error) {
+	if err := dMgr.validateImageRequest(request); err != nil {
+		return nil, err
+	}
+
+	ctx = utils.WithRPCMetadataContext(ctx, "ActivateImage")
+	respCh := make(chan []*voltha.DeviceImageState, len(request.GetDeviceId()))
+
+	activateImageReq := &voltha.DeviceImageRequest{
+		Version:         request.Version,
+		CommitOnSuccess: request.CommitOnSuccess,
+	}
+
+	for index, deviceID := range request.DeviceId {
+		//slice-out only single deviceID from the request
+		activateImageReq.DeviceId = request.DeviceId[index : index+1]
+
+		go func(deviceID string, req *voltha.DeviceImageRequest, ch chan []*voltha.DeviceImageState) {
+			agent := dMgr.getDeviceAgent(ctx, deviceID)
+			if agent == nil {
+				logger.Errorw(ctx, "Not-found", log.Fields{"device-id": deviceID})
+				ch <- nil
+				return
+			}
+
+			resp, err := agent.activateImageOnDevice(ctx, req)
+			if err != nil {
+				logger.Errorw(ctx, "activate-image-failed", log.Fields{"device-id": deviceID, "error": err})
+				ch <- nil
+				return
+			}
+
+			err = dMgr.validateDeviceImageResponse(resp)
+			if err != nil {
+				logger.Errorw(ctx, "activate-image-failed", log.Fields{"device-id": deviceID, "error": err})
+				ch <- nil
+				return
+			}
+
+			ch <- resp.GetDeviceImageStates()
+		}(deviceID.GetId(), activateImageReq, respCh)
+
+	}
+
+	return dMgr.waitForAllResponses(ctx, "activate-image", respCh, len(request.GetDeviceId()))
+}
+
+func (dMgr *Manager) CommitImage(ctx context.Context, request *voltha.DeviceImageRequest) (*voltha.DeviceImageResponse, error) {
+	if err := dMgr.validateImageRequest(request); err != nil {
+		return nil, err
+	}
+
+	ctx = utils.WithRPCMetadataContext(ctx, "CommitImage")
+	respCh := make(chan []*voltha.DeviceImageState, len(request.GetDeviceId()))
+
+	commitImageReq := &voltha.DeviceImageRequest{
+		Version:         request.Version,
+		CommitOnSuccess: request.CommitOnSuccess,
+	}
+
+	for index, deviceID := range request.DeviceId {
+		//slice-out only single deviceID from the request
+		commitImageReq.DeviceId = request.DeviceId[index : index+1]
+
+		go func(deviceID string, req *voltha.DeviceImageRequest, ch chan []*voltha.DeviceImageState) {
+			agent := dMgr.getDeviceAgent(ctx, deviceID)
+			if agent == nil {
+				logger.Errorw(ctx, "Not-found", log.Fields{"device-id": deviceID})
+				ch <- nil
+				return
+			}
+
+			resp, err := agent.commitImage(ctx, req)
+			if err != nil {
+				logger.Errorw(ctx, "commit-image-failed", log.Fields{"device-id": deviceID, "error": err})
+				ch <- nil
+				return
+			}
+
+			err = dMgr.validateDeviceImageResponse(resp)
+			if err != nil {
+				logger.Errorf(ctx, "commit-image-failed", log.Fields{"device-id": deviceID, "error": err})
+				ch <- nil
+				return
+			}
+			ch <- resp.GetDeviceImageStates()
+		}(deviceID.GetId(), commitImageReq, respCh)
+
+	}
+
+	return dMgr.waitForAllResponses(ctx, "commit-image", respCh, len(request.GetDeviceId()))
+}
+
+func (dMgr *Manager) validateImageDownloadRequest(request *voltha.DeviceImageDownloadRequest) error {
+	if request == nil || request.Image == nil || len(request.DeviceId) == 0 {
+		return status.Errorf(codes.InvalidArgument, "invalid argument")
+	}
+
+	for _, deviceID := range request.DeviceId {
+		if deviceID == nil {
+			return status.Errorf(codes.InvalidArgument, "id is nil")
+		}
+	}
+	return nil
+}
+
+func (dMgr *Manager) validateImageRequest(request *voltha.DeviceImageRequest) error {
+	if request == nil || len(request.DeviceId) == 0 || request.DeviceId[0] == nil {
+		return status.Errorf(codes.InvalidArgument, "invalid argument")
+	}
+
+	for _, deviceID := range request.DeviceId {
+		if deviceID == nil {
+			return status.Errorf(codes.InvalidArgument, "id is nil")
+		}
+	}
+
+	return nil
+}
+
+func (dMgr *Manager) validateDeviceImageResponse(response *voltha.DeviceImageResponse) error {
+	if response == nil || len(response.GetDeviceImageStates()) == 0 || response.GetDeviceImageStates()[0] == nil {
+		return status.Errorf(codes.Internal, "invalid-response-from-adapter")
+	}
+
+	return nil
+}
+
+func (dMgr *Manager) waitForAllResponses(ctx context.Context, opName string, respCh chan []*voltha.DeviceImageState, expectedResps int) (*voltha.DeviceImageResponse, error) {
+	response := &voltha.DeviceImageResponse{}
+	respCount := 0
+	for {
+		select {
+		case resp, ok := <-respCh:
+			if !ok {
+				logger.Errorw(ctx, opName+"-failed", log.Fields{"error": "channel-closed"})
+				return response, status.Errorf(codes.Aborted, "channel-closed")
+			}
+
+			if resp != nil {
+				logger.Debugw(ctx, opName+"-result", log.Fields{"image-state": resp[0].GetImageState(), "device-id": resp[0].GetDeviceId()})
+				response.DeviceImageStates = append(response.DeviceImageStates, resp...)
+			}
+
+			respCount++
+
+			//check whether all responses received, if so, sent back the collated response
+			if respCount == expectedResps {
+				return response, nil
+			}
+			continue
+		case <-ctx.Done():
+			return nil, status.Errorf(codes.Aborted, opName+"-failed-%s", ctx.Err())
+		}
+	}
+}