VOL-4697: Fixes for rolling update case
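
When an adapter re-registers with a different software version (a rolling
upgrade or downgrade), the core now stops the gRPC connection to the old
adapter instance, persists the updated adapter descriptor, and marks the
endpoint as undergoing a rolling update. The adapter restart handler then
waits (up to 60 seconds) for the old instance's rx health stream to close
before starting the connection towards the new adapter instance and
triggering the reconcile. Re-registrations with an unchanged version keep
the existing reset-and-reconcile behaviour.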

Change-Id: I4c529ed8ec90013be0dd953ba4b2bf5708872e63
diff --git a/VERSION b/VERSION
index ff365e0..0aec50e 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-3.1.3
+3.1.4
diff --git a/rw_core/core/adapter/agent.go b/rw_core/core/adapter/agent.go
index e7c3d2b..0a2b632 100644
--- a/rw_core/core/adapter/agent.go
+++ b/rw_core/core/adapter/agent.go
@@ -95,6 +95,7 @@
 	if err != nil {
 		return nil, err
 	}
+
 	c, ok := client.(adapter_service.AdapterServiceClient)
 	if ok {
 		return c, nil
diff --git a/rw_core/core/adapter/manager.go b/rw_core/core/adapter/manager.go
index f5a6cac..9a1ff1c 100644
--- a/rw_core/core/adapter/manager.go
+++ b/rw_core/core/adapter/manager.go
@@ -53,6 +53,10 @@
 	lockAdapterEndPointsMap sync.RWMutex
 	liveProbeInterval       time.Duration
 	coreEndpoint            string
+	rollingUpdateMap        map[string]bool      // endpoints with a rolling update in progress
+	rollingUpdateLock       sync.RWMutex
+	rxStreamCloseChMap      map[string]chan bool // per-endpoint channels signalling that the old adapter's rx stream has closed
+	rxStreamCloseChLock     sync.RWMutex
 }
 
 // SetAdapterRestartedCallback is used to set the callback that needs to be invoked on an adapter restart
@@ -68,14 +72,16 @@
 	liveProbeInterval time.Duration,
 ) *Manager {
 	return &Manager{
-		adapterDbProxy:    dbPath.Proxy("adapters"),
-		deviceTypeDbProxy: dbPath.Proxy("device_types"),
-		deviceTypes:       make(map[string]*voltha.DeviceType),
-		adapterAgents:     make(map[string]*agent),
-		adapterEndpoints:  make(map[Endpoint]*agent),
-		endpointMgr:       NewEndpointManager(backend),
-		liveProbeInterval: liveProbeInterval,
-		coreEndpoint:      coreEndpoint,
+		adapterDbProxy:     dbPath.Proxy("adapters"),
+		deviceTypeDbProxy:  dbPath.Proxy("device_types"),
+		deviceTypes:        make(map[string]*voltha.DeviceType),
+		adapterAgents:      make(map[string]*agent),
+		adapterEndpoints:   make(map[Endpoint]*agent),
+		endpointMgr:        NewEndpointManager(backend),
+		liveProbeInterval:  liveProbeInterval,
+		coreEndpoint:       coreEndpoint,
+		rollingUpdateMap:   make(map[string]bool),
+		rxStreamCloseChMap: make(map[string]chan bool),
 	}
 }
 
@@ -196,6 +202,38 @@
 	return nil
 }
 
+func (aMgr *Manager) updateAdapter(ctx context.Context, adapter *voltha.Adapter, saveToDb bool) error {
+	aMgr.lockAdapterAgentsMap.Lock()
+	defer aMgr.lockAdapterAgentsMap.Unlock()
+	logger.Debugw(ctx, "updating-adapter", log.Fields{"adapterId": adapter.Id, "vendor": adapter.Vendor,
+		"currentReplica": adapter.CurrentReplica, "totalReplicas": adapter.TotalReplicas, "endpoint": adapter.Endpoint,
+		"version": adapter.Version})
+	if _, exist := aMgr.adapterAgents[adapter.Id]; !exist {
+		logger.Errorw(ctx, "adapter-does-not-exist", log.Fields{"adapterName": adapter.Id})
+		return fmt.Errorf("adapter-%s-does-not-exist", adapter.Id)
+	}
+	if saveToDb {
+		// Update the adapter in the KV store
+		if err := aMgr.adapterDbProxy.Set(log.WithSpanFromContext(context.Background(), ctx), adapter.Id, adapter); err != nil {
+			logger.Errorw(ctx, "failed-to-update-adapter", log.Fields{"adapterId": adapter.Id, "vendor": adapter.Vendor,
+				"currentReplica": adapter.CurrentReplica, "totalReplicas": adapter.TotalReplicas,
+				"endpoint": adapter.Endpoint,
+				"version": adapter.Version})
+			return err
+		}
+		logger.Debugw(ctx, "adapter-updated-to-KV-Store", log.Fields{"adapterId": adapter.Id, "vendor": adapter.Vendor,
+			"currentReplica": adapter.CurrentReplica, "totalReplicas": adapter.TotalReplicas, "endpoint": adapter.Endpoint,
+			"version": adapter.Version})
+	}
+	clonedAdapter := (proto.Clone(adapter)).(*voltha.Adapter)
+	// Use a muted adapter restart handler, which is invoked by the corresponding gRPC client on an adapter restart.
+	// This handler just logs the restart event.  The actual recovery following an adapter restart
+	// is performed when the adapter re-registers itself.
+	aMgr.adapterAgents[adapter.Id] = newAdapterAgent(aMgr.coreEndpoint, clonedAdapter, aMgr.mutedAdapterRestartedHandler, aMgr.liveProbeInterval)
+	aMgr.adapterEndpoints[Endpoint(adapter.Endpoint)] = aMgr.adapterAgents[adapter.Id]
+	return nil
+}
+
 func (aMgr *Manager) addDeviceTypes(ctx context.Context, deviceTypes *voltha.DeviceTypes, saveToDb bool) error {
 	if deviceTypes == nil {
 		return fmt.Errorf("no-device-type")
@@ -304,16 +342,27 @@
 	}
 
 	if adpt, _ := aMgr.getAdapter(ctx, adapter.Id); adpt != nil {
-		//	Already registered - Adapter may have restarted.  Trigger the reconcile process for that adapter
-		logger.Warnw(ctx, "adapter-restarted", log.Fields{"adapter": adpt.Id, "endpoint": adpt.Endpoint})
-
-		// First reset the adapter connection
 		agt, err := aMgr.getAgent(ctx, adpt.Id)
 		if err != nil {
 			logger.Errorw(ctx, "no-adapter-agent", log.Fields{"error": err})
 			return nil, err
 		}
-		agt.resetConnection(ctx)
+		if adapter.Version != adpt.Version {
+			// Rolling update scenario - this could be either an upgrade or a downgrade
+			logger.Infow(ctx, "rolling-update",
+				log.Fields{"adapter": adpt.Id, "endpoint": adpt.Endpoint, "old-version": adpt.Version, "new-version": adapter.Version})
+			// Stop the gRPC connection to the old adapter
+			agt.stop(ctx)
+			if err = aMgr.updateAdapter(ctx, adapter, true); err != nil {
+				return nil, err
+			}
+			aMgr.SetRollingUpdate(ctx, adapter.Endpoint, true)
+		} else {
+			//	Adapter already registered and the version is unchanged - the adapter may have restarted.
+			//	Trigger the reconcile process for that adapter
+			logger.Warnw(ctx, "adapter-restarted", log.Fields{"adapter": adpt.Id, "endpoint": adpt.Endpoint})
+			agt.resetConnection(ctx)
+		}
 
 		go func() {
 			err := aMgr.onAdapterRestart(log.WithSpanFromContext(context.Background(), ctx), adpt.Endpoint)
@@ -355,6 +404,23 @@
 	return &empty.Empty{}, nil
 }
 
+func (aMgr *Manager) StartAdapterWithEndpoint(ctx context.Context, endpoint string) error {
+	aMgr.lockAdapterAgentsMap.RLock()
+	defer aMgr.lockAdapterAgentsMap.RUnlock()
+	subCtx := log.WithSpanFromContext(context.Background(), ctx)
+	for _, adapterAgent := range aMgr.adapterAgents {
+		if adapterAgent.adapter.Endpoint == endpoint {
+			if err := adapterAgent.start(subCtx); err != nil {
+				logger.Errorw(subCtx, "failed-to-start-adapter", log.Fields{"error": err})
+				return err
+			}
+			return nil
+		}
+	}
+	logger.Errorw(ctx, "adapter-agent-not-found-for-endpoint", log.Fields{"endpoint": endpoint})
+	return fmt.Errorf("adapter-agent-not-found-for-endpoint-%s", endpoint)
+}
+
 func (aMgr *Manager) GetAdapterTypeByVendorID(vendorID string) (string, error) {
 	aMgr.lockDeviceTypesMap.RLock()
 	defer aMgr.lockDeviceTypesMap.RUnlock()
@@ -421,6 +487,77 @@
 	return result, nil
 }
 
+func (aMgr *Manager) GetRollingUpdate(ctx context.Context, endpoint string) (bool, bool) {
+	aMgr.rollingUpdateLock.RLock()
+	defer aMgr.rollingUpdateLock.RUnlock()
+	val, ok := aMgr.rollingUpdateMap[endpoint]
+	return val, ok
+}
+
+func (aMgr *Manager) SetRollingUpdate(ctx context.Context, endpoint string, status bool) {
+	aMgr.rollingUpdateLock.Lock()
+	defer aMgr.rollingUpdateLock.Unlock()
+	if res, ok := aMgr.rollingUpdateMap[endpoint]; ok {
+		logger.Warnw(ctx, "possible duplicate rolling update - overwriting", log.Fields{"old-status": res, "endpoint": endpoint})
+	}
+	aMgr.rollingUpdateMap[endpoint] = status
+}
+
+func (aMgr *Manager) DeleteRollingUpdate(ctx context.Context, endpoint string) {
+	aMgr.rollingUpdateLock.Lock()
+	defer aMgr.rollingUpdateLock.Unlock()
+	delete(aMgr.rollingUpdateMap, endpoint)
+}
+
+func (aMgr *Manager) RegisterOnRxStreamCloseChMap(ctx context.Context, endpoint string) {
+	aMgr.rxStreamCloseChLock.Lock()
+	defer aMgr.rxStreamCloseChLock.Unlock()
+	if _, ok := aMgr.rxStreamCloseChMap[endpoint]; ok {
+		logger.Warnw(ctx, "duplicate entry on rxStreamCloseChMap - overwriting", log.Fields{"endpoint": endpoint})
+		// First close the old channel
+		close(aMgr.rxStreamCloseChMap[endpoint])
+	}
+	aMgr.rxStreamCloseChMap[endpoint] = make(chan bool, 1)
+}
+
+func (aMgr *Manager) SignalOnRxStreamCloseCh(ctx context.Context, endpoint string) {
+	var closeCh chan bool
+	ok := false
+	aMgr.rxStreamCloseChLock.RLock()
+	if closeCh, ok = aMgr.rxStreamCloseChMap[endpoint]; !ok {
+		logger.Infow(ctx, "no entry on rxStreamCloseChMap", log.Fields{"endpoint": endpoint})
+		aMgr.rxStreamCloseChLock.RUnlock()
+		return
+	}
+	aMgr.rxStreamCloseChLock.RUnlock()
+
+	// Signal that the rx stream for this endpoint has closed (the channel is buffered, so this send does not block)
+	closeCh <- true
+
+	aMgr.rxStreamCloseChLock.Lock()
+	defer aMgr.rxStreamCloseChLock.Unlock()
+	delete(aMgr.rxStreamCloseChMap, endpoint)
+}
+
+func (aMgr *Manager) WaitOnRxStreamCloseCh(ctx context.Context, endpoint string) {
+	var closeCh chan bool
+	ok := false
+	aMgr.rxStreamCloseChLock.RLock()
+	if closeCh, ok = aMgr.rxStreamCloseChMap[endpoint]; !ok {
+		logger.Warnw(ctx, "no entry on rxStreamCloseChMap", log.Fields{"endpoint": endpoint})
+		aMgr.rxStreamCloseChLock.RUnlock()
+		return
+	}
+	aMgr.rxStreamCloseChLock.RUnlock()
+
+	select {
+	case <-closeCh:
+		logger.Infow(ctx, "rx stream closed for endpoint", log.Fields{"endpoint": endpoint})
+	case <-time.After(60 * time.Second):
+		logger.Warnw(ctx, "timeout waiting for rx stream close", log.Fields{"endpoint": endpoint})
+	}
+}
+
 func (aMgr *Manager) getAgent(ctx context.Context, adapterID string) (*agent, error) {
 	aMgr.lockAdapterAgentsMap.RLock()
 	defer aMgr.lockAdapterAgentsMap.RUnlock()
diff --git a/rw_core/core/device/manager.go b/rw_core/core/device/manager.go
index c1148b5..b48e603 100755
--- a/rw_core/core/device/manager.go
+++ b/rw_core/core/device/manager.go
@@ -383,7 +383,8 @@
 // adapterRestarted is invoked whenever an adapter is restarted
 func (dMgr *Manager) adapterRestarted(ctx context.Context, adapter *voltha.Adapter) error {
 	logger.Debugw(ctx, "adapter-restarted", log.Fields{"adapter-id": adapter.Id, "vendor": adapter.Vendor,
-		"current-replica": adapter.CurrentReplica, "total-replicas": adapter.TotalReplicas, "restarted-endpoint": adapter.Endpoint})
+		"current-replica": adapter.CurrentReplica, "total-replicas": adapter.TotalReplicas,
+		"restarted-endpoint": adapter.Endpoint, "current-version": adapter.Version})
 
 	numberOfDevicesToReconcile := 0
 	dMgr.deviceAgents.Range(func(key, value interface{}) bool {
@@ -856,6 +857,17 @@
 func (dMgr *Manager) adapterRestartedHandler(ctx context.Context, endpoint string) error {
 	// Get the adapter corresponding to that endpoint
 	if a, _ := dMgr.adapterMgr.GetAdapterWithEndpoint(ctx, endpoint); a != nil {
+		if rollingUpdate, _ := dMgr.adapterMgr.GetRollingUpdate(ctx, endpoint); rollingUpdate {
+			dMgr.adapterMgr.RegisterOnRxStreamCloseChMap(ctx, endpoint)
+			// Blocking call: wait for the old adapter's rx stream to close.
+			// That is the signal that the old adapter instance is completely down.
+			dMgr.adapterMgr.WaitOnRxStreamCloseCh(ctx, endpoint)
+			dMgr.adapterMgr.DeleteRollingUpdate(ctx, endpoint)
+			// In the rolling update case, start the connection towards the new adapter instance now
+			if err := dMgr.adapterMgr.StartAdapterWithEndpoint(ctx, endpoint); err != nil {
+				return err
+			}
+		}
 		return dMgr.adapterRestarted(ctx, a)
 	}
 	logger.Errorw(ctx, "restarted-adapter-not-found", log.Fields{"endpoint": endpoint})
diff --git a/rw_core/core/device/manager_sbi.go b/rw_core/core/device/manager_sbi.go
index 7a17d6d..c0e1b6b 100644
--- a/rw_core/core/device/manager_sbi.go
+++ b/rw_core/core/device/manager_sbi.go
@@ -548,12 +548,14 @@
 		tempClient, err = stream.Recv()
 		if err != nil {
 			logger.Warnw(ctx, "received-stream-error", log.Fields{"remote-client": remoteClient, "error": err})
+			dMgr.adapterMgr.SignalOnRxStreamCloseCh(ctx, remoteClient.Endpoint)
 			break loop
 		}
 		// Send a response back
 		err = stream.Send(&health.HealthStatus{State: health.HealthStatus_HEALTHY})
 		if err != nil {
 			logger.Warnw(ctx, "sending-stream-error", log.Fields{"remote-client": remoteClient, "error": err})
+			dMgr.adapterMgr.SignalOnRxStreamCloseCh(ctx, remoteClient.Endpoint)
 			break loop
 		}
 
diff --git a/rw_core/main.go b/rw_core/main.go
index 8a9675d..715285d 100644
--- a/rw_core/main.go
+++ b/rw_core/main.go
@@ -140,5 +140,6 @@
 	core.Stop(shutdownCtx)
 
 	elapsed := time.Since(start)
+
 	logger.Infow(ctx, "rw-core-run-time", log.Fields{"core": instanceID, "time": elapsed / time.Second})
 }