[VOL-4022] RW-Core Changes For ONU SW Upgrade
New Download/Activate/Retrieve APIs
Change-Id: I2d8a0ec7d8967fd76a261a108f743e75f84c98e9
diff --git a/rw_core/core/device/manager.go b/rw_core/core/device/manager.go
index 9d9f933..43bd2be 100755
--- a/rw_core/core/device/manager.go
+++ b/rw_core/core/device/manager.go
@@ -1643,3 +1643,321 @@
}
return agent.getTransientState(), nil
}
+
+// DownloadImageToDevice asks every device listed in the request to download the
+// given image. The request is fanned out to one goroutine per device and the
+// per-device image states are collated into a single response.
+//
+// A device whose agent cannot be found, whose agent call fails, or whose
+// response fails validation is logged and contributes no state to the result.
+// An InvalidArgument error is returned when the request fails validation, and
+// an Aborted error when ctx is done before every device has answered.
+func (dMgr *Manager) DownloadImageToDevice(ctx context.Context, request *voltha.DeviceImageDownloadRequest) (*voltha.DeviceImageResponse, error) {
+	if err := dMgr.validateImageDownloadRequest(request); err != nil {
+		return nil, err
+	}
+
+	ctx = utils.WithRPCMetadataContext(ctx, "DownloadImageToDevice")
+
+	// One buffer slot per device: each goroutine sends exactly once, so a send
+	// never blocks even if the collector has already returned on ctx.Done().
+	respCh := make(chan []*voltha.DeviceImageState, len(request.GetDeviceId()))
+
+	for index, deviceID := range request.DeviceId {
+		// Build a dedicated request for each device. Sharing a single request
+		// object between goroutines while rewriting its DeviceId on every
+		// iteration is a data race and can make several goroutines send the
+		// same (last) device ID.
+		downloadReq := &voltha.DeviceImageDownloadRequest{
+			// Slice out only this device's ID; the capped capacity keeps an
+			// append on the sub-slice from overwriting the caller's list.
+			DeviceId:          request.DeviceId[index : index+1 : index+1],
+			Image:             request.Image,
+			ActivateOnSuccess: request.ActivateOnSuccess,
+			CommitOnSuccess:   request.CommitOnSuccess,
+		}
+
+		go func(deviceID string, req *voltha.DeviceImageDownloadRequest, ch chan []*voltha.DeviceImageState) {
+			agent := dMgr.getDeviceAgent(ctx, deviceID)
+			if agent == nil {
+				logger.Errorw(ctx, "Not-found", log.Fields{"device-id": deviceID})
+				ch <- nil
+				return
+			}
+
+			resp, err := agent.downloadImageToDevice(ctx, req)
+			if err != nil {
+				logger.Errorw(ctx, "download-image-to-device-failed", log.Fields{"device-id": deviceID, "error": err})
+				ch <- nil
+				return
+			}
+
+			if err = dMgr.validateDeviceImageResponse(resp); err != nil {
+				logger.Errorw(ctx, "download-image-to-device-failed", log.Fields{"device-id": deviceID, "error": err})
+				ch <- nil
+				return
+			}
+			ch <- resp.GetDeviceImageStates()
+		}(deviceID.GetId(), downloadReq, respCh)
+	}
+
+	return dMgr.waitForAllResponses(ctx, "download-image-to-device", respCh, len(request.GetDeviceId()))
+}
+
+// GetImageStatus queries the image status on every device listed in the
+// request. The request is fanned out to one goroutine per device and the
+// per-device image states are collated into a single response.
+//
+// A device whose agent cannot be found, whose agent call fails, or whose
+// response fails validation is logged and contributes no state to the result.
+// An InvalidArgument error is returned when the request fails validation, and
+// an Aborted error when ctx is done before every device has answered.
+func (dMgr *Manager) GetImageStatus(ctx context.Context, request *voltha.DeviceImageRequest) (*voltha.DeviceImageResponse, error) {
+	if err := dMgr.validateImageRequest(request); err != nil {
+		return nil, err
+	}
+
+	ctx = utils.WithRPCMetadataContext(ctx, "GetImageStatus")
+
+	// One buffer slot per device: each goroutine sends exactly once, so a send
+	// never blocks even if the collector has already returned on ctx.Done().
+	respCh := make(chan []*voltha.DeviceImageState, len(request.GetDeviceId()))
+
+	for index, deviceID := range request.DeviceId {
+		// Build a dedicated request for each device. Sharing a single request
+		// object between goroutines while rewriting its DeviceId on every
+		// iteration is a data race and can make several goroutines query the
+		// same (last) device ID.
+		imageStatusReq := &voltha.DeviceImageRequest{
+			// Slice out only this device's ID; the capped capacity keeps an
+			// append on the sub-slice from overwriting the caller's list.
+			DeviceId:        request.DeviceId[index : index+1 : index+1],
+			Version:         request.Version,
+			CommitOnSuccess: request.CommitOnSuccess,
+		}
+
+		go func(deviceID string, req *voltha.DeviceImageRequest, ch chan []*voltha.DeviceImageState) {
+			agent := dMgr.getDeviceAgent(ctx, deviceID)
+			if agent == nil {
+				logger.Errorw(ctx, "Not-found", log.Fields{"device-id": deviceID})
+				ch <- nil
+				return
+			}
+
+			resp, err := agent.getImageStatus(ctx, req)
+			if err != nil {
+				logger.Errorw(ctx, "get-image-status-failed", log.Fields{"device-id": deviceID, "error": err})
+				ch <- nil
+				return
+			}
+
+			if err = dMgr.validateDeviceImageResponse(resp); err != nil {
+				logger.Errorw(ctx, "get-image-status-failed", log.Fields{"device-id": deviceID, "error": err})
+				ch <- nil
+				return
+			}
+			ch <- resp.GetDeviceImageStates()
+		}(deviceID.GetId(), imageStatusReq, respCh)
+	}
+
+	return dMgr.waitForAllResponses(ctx, "get-image-status", respCh, len(request.GetDeviceId()))
+}
+
+// AbortImageUpgradeToDevice asks every device listed in the request to abort
+// its image upgrade. The request is fanned out to one goroutine per device and
+// the per-device image states are collated into a single response.
+//
+// A device whose agent cannot be found, whose agent call fails, or whose
+// response fails validation is logged and contributes no state to the result.
+// An InvalidArgument error is returned when the request fails validation, and
+// an Aborted error when ctx is done before every device has answered.
+func (dMgr *Manager) AbortImageUpgradeToDevice(ctx context.Context, request *voltha.DeviceImageRequest) (*voltha.DeviceImageResponse, error) {
+	if err := dMgr.validateImageRequest(request); err != nil {
+		return nil, err
+	}
+
+	ctx = utils.WithRPCMetadataContext(ctx, "AbortImageUpgradeToDevice")
+
+	// One buffer slot per device: each goroutine sends exactly once, so a send
+	// never blocks even if the collector has already returned on ctx.Done().
+	respCh := make(chan []*voltha.DeviceImageState, len(request.GetDeviceId()))
+
+	for index, deviceID := range request.DeviceId {
+		// Build a dedicated request for each device. Sharing a single request
+		// object between goroutines while rewriting its DeviceId on every
+		// iteration is a data race and can make several goroutines abort the
+		// same (last) device ID.
+		abortImageReq := &voltha.DeviceImageRequest{
+			// Slice out only this device's ID; the capped capacity keeps an
+			// append on the sub-slice from overwriting the caller's list.
+			DeviceId:        request.DeviceId[index : index+1 : index+1],
+			Version:         request.Version,
+			CommitOnSuccess: request.CommitOnSuccess,
+		}
+
+		go func(deviceID string, req *voltha.DeviceImageRequest, ch chan []*voltha.DeviceImageState) {
+			agent := dMgr.getDeviceAgent(ctx, deviceID)
+			if agent == nil {
+				logger.Errorw(ctx, "Not-found", log.Fields{"device-id": deviceID})
+				ch <- nil
+				return
+			}
+
+			resp, err := agent.abortImageUpgradeToDevice(ctx, req)
+			if err != nil {
+				logger.Errorw(ctx, "abort-image-upgrade-to-device-failed", log.Fields{"device-id": deviceID, "error": err})
+				ch <- nil
+				return
+			}
+
+			if err = dMgr.validateDeviceImageResponse(resp); err != nil {
+				logger.Errorw(ctx, "abort-image-upgrade-to-device-failed", log.Fields{"device-id": deviceID, "error": err})
+				ch <- nil
+				return
+			}
+			ch <- resp.GetDeviceImageStates()
+		}(deviceID.GetId(), abortImageReq, respCh)
+	}
+
+	return dMgr.waitForAllResponses(ctx, "abort-image-upgrade-to-device", respCh, len(request.GetDeviceId()))
+}
+
+// GetOnuImages returns the ONU images reported by the agent of the device
+// identified by id.
+//
+// It returns InvalidArgument when id is nil or carries an empty Id, NotFound
+// when no agent is available for that device, and otherwise whatever error the
+// agent's getOnuImages call produces.
+func (dMgr *Manager) GetOnuImages(ctx context.Context, id *common.ID) (*voltha.OnuImages, error) {
+	if id == nil || id.Id == "" {
+		return nil, status.Errorf(codes.InvalidArgument, "empty device id")
+	}
+	deviceID := id.Id
+
+	// Tag the context and the tracing span before doing any device work.
+	ctx = utils.WithRPCMetadataContext(ctx, "GetOnuImages")
+	log.EnrichSpan(ctx, log.Fields{"device-id": deviceID})
+
+	agent := dMgr.getDeviceAgent(ctx, deviceID)
+	if agent == nil {
+		return nil, status.Errorf(codes.NotFound, "%s", deviceID)
+	}
+
+	images, err := agent.getOnuImages(ctx, id)
+	if err != nil {
+		return nil, err
+	}
+
+	logger.Debugw(ctx, "get-onu-images-result", log.Fields{"onu-image": images.Items})
+	return images, nil
+}
+
+// ActivateImage asks every device listed in the request to activate the
+// requested image version. The request is fanned out to one goroutine per
+// device and the per-device image states are collated into a single response.
+//
+// A device whose agent cannot be found, whose agent call fails, or whose
+// response fails validation is logged and contributes no state to the result.
+// An InvalidArgument error is returned when the request fails validation, and
+// an Aborted error when ctx is done before every device has answered.
+func (dMgr *Manager) ActivateImage(ctx context.Context, request *voltha.DeviceImageRequest) (*voltha.DeviceImageResponse, error) {
+	if err := dMgr.validateImageRequest(request); err != nil {
+		return nil, err
+	}
+
+	ctx = utils.WithRPCMetadataContext(ctx, "ActivateImage")
+
+	// One buffer slot per device: each goroutine sends exactly once, so a send
+	// never blocks even if the collector has already returned on ctx.Done().
+	respCh := make(chan []*voltha.DeviceImageState, len(request.GetDeviceId()))
+
+	for index, deviceID := range request.DeviceId {
+		// Build a dedicated request for each device. Sharing a single request
+		// object between goroutines while rewriting its DeviceId on every
+		// iteration is a data race and can make several goroutines activate
+		// the same (last) device ID.
+		activateImageReq := &voltha.DeviceImageRequest{
+			// Slice out only this device's ID; the capped capacity keeps an
+			// append on the sub-slice from overwriting the caller's list.
+			DeviceId:        request.DeviceId[index : index+1 : index+1],
+			Version:         request.Version,
+			CommitOnSuccess: request.CommitOnSuccess,
+		}
+
+		go func(deviceID string, req *voltha.DeviceImageRequest, ch chan []*voltha.DeviceImageState) {
+			agent := dMgr.getDeviceAgent(ctx, deviceID)
+			if agent == nil {
+				logger.Errorw(ctx, "Not-found", log.Fields{"device-id": deviceID})
+				ch <- nil
+				return
+			}
+
+			resp, err := agent.activateImageOnDevice(ctx, req)
+			if err != nil {
+				logger.Errorw(ctx, "activate-image-failed", log.Fields{"device-id": deviceID, "error": err})
+				ch <- nil
+				return
+			}
+
+			if err = dMgr.validateDeviceImageResponse(resp); err != nil {
+				logger.Errorw(ctx, "activate-image-failed", log.Fields{"device-id": deviceID, "error": err})
+				ch <- nil
+				return
+			}
+
+			ch <- resp.GetDeviceImageStates()
+		}(deviceID.GetId(), activateImageReq, respCh)
+	}
+
+	return dMgr.waitForAllResponses(ctx, "activate-image", respCh, len(request.GetDeviceId()))
+}
+
+// CommitImage asks every device listed in the request to commit the requested
+// image version. The request is fanned out to one goroutine per device and the
+// per-device image states are collated into a single response.
+//
+// A device whose agent cannot be found, whose agent call fails, or whose
+// response fails validation is logged and contributes no state to the result.
+// An InvalidArgument error is returned when the request fails validation, and
+// an Aborted error when ctx is done before every device has answered.
+func (dMgr *Manager) CommitImage(ctx context.Context, request *voltha.DeviceImageRequest) (*voltha.DeviceImageResponse, error) {
+	if err := dMgr.validateImageRequest(request); err != nil {
+		return nil, err
+	}
+
+	ctx = utils.WithRPCMetadataContext(ctx, "CommitImage")
+
+	// One buffer slot per device: each goroutine sends exactly once, so a send
+	// never blocks even if the collector has already returned on ctx.Done().
+	respCh := make(chan []*voltha.DeviceImageState, len(request.GetDeviceId()))
+
+	for index, deviceID := range request.DeviceId {
+		// Build a dedicated request for each device. Sharing a single request
+		// object between goroutines while rewriting its DeviceId on every
+		// iteration is a data race and can make several goroutines commit on
+		// the same (last) device ID.
+		commitImageReq := &voltha.DeviceImageRequest{
+			// Slice out only this device's ID; the capped capacity keeps an
+			// append on the sub-slice from overwriting the caller's list.
+			DeviceId:        request.DeviceId[index : index+1 : index+1],
+			Version:         request.Version,
+			CommitOnSuccess: request.CommitOnSuccess,
+		}
+
+		go func(deviceID string, req *voltha.DeviceImageRequest, ch chan []*voltha.DeviceImageState) {
+			agent := dMgr.getDeviceAgent(ctx, deviceID)
+			if agent == nil {
+				logger.Errorw(ctx, "Not-found", log.Fields{"device-id": deviceID})
+				ch <- nil
+				return
+			}
+
+			resp, err := agent.commitImage(ctx, req)
+			if err != nil {
+				logger.Errorw(ctx, "commit-image-failed", log.Fields{"device-id": deviceID, "error": err})
+				ch <- nil
+				return
+			}
+
+			if err = dMgr.validateDeviceImageResponse(resp); err != nil {
+				// Errorw, not Errorf: the fields are structured key/values, not
+				// arguments to a format string.
+				logger.Errorw(ctx, "commit-image-failed", log.Fields{"device-id": deviceID, "error": err})
+				ch <- nil
+				return
+			}
+			ch <- resp.GetDeviceImageStates()
+		}(deviceID.GetId(), commitImageReq, respCh)
+	}
+
+	return dMgr.waitForAllResponses(ctx, "commit-image", respCh, len(request.GetDeviceId()))
+}
+
+// validateImageDownloadRequest checks that an image download request can be
+// processed: the request and its Image must be non-nil, the device list must be
+// non-empty, and no entry of the device list may be nil. Any violation yields
+// an InvalidArgument status error.
+func (dMgr *Manager) validateImageDownloadRequest(request *voltha.DeviceImageDownloadRequest) error {
+	if request == nil {
+		return status.Errorf(codes.InvalidArgument, "invalid argument")
+	}
+	if request.Image == nil || len(request.DeviceId) == 0 {
+		return status.Errorf(codes.InvalidArgument, "invalid argument")
+	}
+
+	// Every listed device must carry an ID object.
+	for i := range request.DeviceId {
+		if request.DeviceId[i] == nil {
+			return status.Errorf(codes.InvalidArgument, "id is nil")
+		}
+	}
+	return nil
+}
+
+// validateImageRequest checks that an image request can be processed: the
+// request must be non-nil, its device list must be non-empty, and no entry of
+// the device list may be nil. Any violation yields an InvalidArgument status
+// error; a nil first entry is reported as "invalid argument", a nil later entry
+// as "id is nil".
+func (dMgr *Manager) validateImageRequest(request *voltha.DeviceImageRequest) error {
+	if request == nil {
+		return status.Errorf(codes.InvalidArgument, "invalid argument")
+	}
+
+	ids := request.DeviceId
+	if len(ids) == 0 || ids[0] == nil {
+		return status.Errorf(codes.InvalidArgument, "invalid argument")
+	}
+
+	// The first entry is known to be non-nil here; check the remaining ones.
+	for _, id := range ids[1:] {
+		if id == nil {
+			return status.Errorf(codes.InvalidArgument, "id is nil")
+		}
+	}
+
+	return nil
+}
+
+// validateDeviceImageResponse checks that a device image response is usable:
+// it must be non-nil and its first device image state must exist and be
+// non-nil. Otherwise an Internal status error is returned. Only the first
+// state is inspected.
+func (dMgr *Manager) validateDeviceImageResponse(response *voltha.DeviceImageResponse) error {
+	if response == nil {
+		return status.Errorf(codes.Internal, "invalid-response-from-adapter")
+	}
+
+	if states := response.GetDeviceImageStates(); len(states) == 0 || states[0] == nil {
+		return status.Errorf(codes.Internal, "invalid-response-from-adapter")
+	}
+
+	return nil
+}
+
+// waitForAllResponses collects expectedResps messages from respCh and merges
+// every non-empty one into a single DeviceImageResponse.
+//
+// opName is used only to build log and error messages. An empty or nil message
+// counts towards expectedResps but adds no state to the result. The function
+// returns:
+//   - the collated response and a nil error once expectedResps messages have
+//     been received (immediately, with an empty response, when expectedResps
+//     is not positive);
+//   - the states collected so far and an Aborted error if respCh is closed
+//     before that;
+//   - a nil response and an Aborted error if ctx is done before that.
+func (dMgr *Manager) waitForAllResponses(ctx context.Context, opName string, respCh chan []*voltha.DeviceImageState, expectedResps int) (*voltha.DeviceImageResponse, error) {
+	response := &voltha.DeviceImageResponse{}
+
+	// Nothing to wait for. Without this guard the loop could only end through
+	// ctx.Done() or a closed channel.
+	if expectedResps <= 0 {
+		return response, nil
+	}
+
+	for respCount := 0; respCount < expectedResps; {
+		select {
+		case resp, ok := <-respCh:
+			if !ok {
+				logger.Errorw(ctx, opName+"-failed", log.Fields{"error": "channel-closed"})
+				return response, status.Errorf(codes.Aborted, "channel-closed")
+			}
+
+			// Check the length rather than nil so that indexing resp[0] is
+			// always safe, even for a non-nil empty slice.
+			if len(resp) > 0 {
+				logger.Debugw(ctx, opName+"-result", log.Fields{"image-state": resp[0].GetImageState(), "device-id": resp[0].GetDeviceId()})
+				response.DeviceImageStates = append(response.DeviceImageStates, resp...)
+			}
+
+			respCount++
+		case <-ctx.Done():
+			// Pass opName as an argument so it is never interpreted as part of
+			// the format string.
+			return nil, status.Errorf(codes.Aborted, "%s-failed-%s", opName, ctx.Err())
+		}
+	}
+
+	// All expected responses received; send back the collated response.
+	return response, nil
+}