diff --git a/controllers/object_controls.go b/controllers/object_controls.go index 3c068c4c83..828fd59052 100644 --- a/controllers/object_controls.go +++ b/controllers/object_controls.go @@ -1064,19 +1064,19 @@ func TransformDriver(obj *appsv1.DaemonSet, config *gpuv1.ClusterPolicySpec, n C // Set the computed digest in driver-manager initContainer driverManagerContainer := findContainerByName(obj.Spec.Template.Spec.InitContainers, "k8s-driver-manager") if driverManagerContainer != nil { - setContainerEnv(driverManagerContainer, "DRIVER_CONFIG_DIGEST", configDigest) + setContainerEnv(driverManagerContainer, driverconfig.DriverConfigDigestEnvName, configDigest) } // Set the computed digest in nvidia-driver container driverContainer := findContainerByName(obj.Spec.Template.Spec.Containers, "nvidia-driver-ctr") if driverContainer != nil { - setContainerEnv(driverContainer, "DRIVER_CONFIG_DIGEST", configDigest) + setContainerEnv(driverContainer, driverconfig.DriverConfigDigestEnvName, configDigest) } // Used by dtk-build-driver to determine if fast path should be used (skip rebuild) driverToolkitContainer := findContainerByName(obj.Spec.Template.Spec.Containers, "openshift-driver-toolkit-ctr") if driverToolkitContainer != nil { - setContainerEnv(driverToolkitContainer, "DRIVER_CONFIG_DIGEST", configDigest) + setContainerEnv(driverToolkitContainer, driverconfig.DriverConfigDigestEnvName, configDigest) } // set hostNetwork for driver if specified diff --git a/controllers/upgrade_controller.go b/controllers/upgrade_controller.go index e82e2de352..4ffe0e2640 100644 --- a/controllers/upgrade_controller.go +++ b/controllers/upgrade_controller.go @@ -47,6 +47,7 @@ import ( gpuv1 "github.com/NVIDIA/gpu-operator/api/nvidia/v1" nvidiav1alpha1 "github.com/NVIDIA/gpu-operator/api/nvidia/v1alpha1" + driverconfig "github.com/NVIDIA/gpu-operator/internal/config" ) // UpgradeReconciler reconciles Driver Daemon Sets for upgrade @@ -231,10 +232,38 @@ func (r *UpgradeReconciler) removeNodeUpgradeStateLabels(ctx context.Context) er return nil } +// driverPodRestartOnly is the upgrade controller's RestartOnlyPredicate: it allows an +// out-of-sync driver pod to be restarted in place when the running pod and the desired +// DaemonSet have the same DRIVER_CONFIG_DIGEST, i.e. the install-relevant config is +// unchanged (e.g. only a helm.sh/chart label changed). If either digest is missing, it +// returns false and the node takes the full upgrade flow. +func (r *UpgradeReconciler) driverPodRestartOnly(_ context.Context, pod *corev1.Pod, ds *appsv1.DaemonSet) (bool, error) { + if pod == nil || ds == nil { + return false, nil + } + desired := driverconfig.DriverConfigDigestFromPodSpec(&ds.Spec.Template.Spec) + running := driverconfig.DriverConfigDigestFromPodSpec(&pod.Spec) + if desired == "" || running == "" { + r.Log.V(consts.LogLevelDebug).Info("driver config digest missing; taking full upgrade flow", + "pod", pod.Name, "daemonset", ds.Name, "desiredDigest", desired, "runningDigest", running) + return false, nil + } + restartOnly := desired == running + r.Log.V(consts.LogLevelDebug).Info("evaluated driver config digest for restart-only routing", + "pod", pod.Name, "daemonset", ds.Name, + "desiredDigest", desired, "runningDigest", running, "restartOnly", restartOnly) + return restartOnly, nil +} + // SetupWithManager sets up the controller with the Manager. // //nolint:dupl func (r *UpgradeReconciler) SetupWithManager(ctx context.Context, mgr ctrl.Manager) error { + // Route digest-unchanged driver pod-template changes to a restart-only upgrade. + if r.StateManager != nil { + r.StateManager = r.StateManager.WithRestartOnlyPredicate(r.driverPodRestartOnly) + } + // Create a new controller c, err := controller.New("upgrade-controller", mgr, controller.Options{Reconciler: r, MaxConcurrentReconciles: 1, RateLimiter: workqueue.NewTypedItemExponentialFailureRateLimiter[reconcile.Request](minDelayCR, maxDelayCR)}) diff --git a/controllers/upgrade_controller_test.go b/controllers/upgrade_controller_test.go index 3b72da082f..3251790e3c 100644 --- a/controllers/upgrade_controller_test.go +++ b/controllers/upgrade_controller_test.go @@ -17,11 +17,17 @@ package controllers import ( + "context" "fmt" "testing" upgrade_v1alpha1 "github.com/NVIDIA/k8s-operator-libs/api/upgrade/v1alpha1" + "github.com/go-logr/logr" "github.com/stretchr/testify/assert" + appsv1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" + + driverconfig "github.com/NVIDIA/gpu-operator/internal/config" ) func TestSetDrainSpecPodSelector(t *testing.T) { @@ -69,3 +75,44 @@ func TestSetDrainSpecPodSelector(t *testing.T) { }) } } + +func TestDriverPodRestartOnly(t *testing.T) { + driverPod := func(digest string) *corev1.Pod { + return &corev1.Pod{Spec: corev1.PodSpec{Containers: []corev1.Container{{ + Name: "nvidia-driver-ctr", + Env: []corev1.EnvVar{{Name: driverconfig.DriverConfigDigestEnvName, Value: digest}}, + }}}} + } + driverDS := func(digest string) *appsv1.DaemonSet { + return &appsv1.DaemonSet{Spec: appsv1.DaemonSetSpec{Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{Containers: []corev1.Container{{ + Name: "nvidia-driver-ctr", + Env: []corev1.EnvVar{{Name: driverconfig.DriverConfigDigestEnvName, Value: digest}}, + }}}, + }}} + } + + r := &UpgradeReconciler{Log: logr.Discard()} + ctx := context.Background() + + tests := []struct { + name string + pod *corev1.Pod + ds *appsv1.DaemonSet + wantRestart bool + }{ + {name: "equal digests -> restart-only", pod: driverPod("same"), ds: driverDS("same"), wantRestart: true}, + {name: "differing digests -> full upgrade", pod: driverPod("old"), ds: driverDS("new"), wantRestart: false}, + {name: "missing digest on pod -> full upgrade", pod: driverPod(""), ds: driverDS("new"), wantRestart: false}, + {name: "missing digest on daemonset -> full upgrade", pod: driverPod("old"), ds: driverDS(""), wantRestart: false}, + {name: "nil pod -> full upgrade", pod: nil, ds: driverDS("x"), wantRestart: false}, + {name: "nil daemonset -> full upgrade", pod: driverPod("x"), ds: nil, wantRestart: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := r.driverPodRestartOnly(ctx, tt.pod, tt.ds) + assert.NoError(t, err) + assert.Equal(t, tt.wantRestart, got) + }) + } +} diff --git a/internal/config/driver_config_digest.go b/internal/config/driver_config_digest.go index 565b14ed45..5a5f1241a5 100644 --- a/internal/config/driver_config_digest.go +++ b/internal/config/driver_config_digest.go @@ -22,6 +22,38 @@ import ( corev1 "k8s.io/api/core/v1" ) +// DriverConfigDigestEnvName is the env var the operator sets on the driver pod +// template, carrying a hash of the install-relevant driver config (DriverInstallState). +const DriverConfigDigestEnvName = "DRIVER_CONFIG_DIGEST" + +// DriverConfigDigestFromPodSpec returns the DRIVER_CONFIG_DIGEST value from a driver +// pod spec, or "" if absent. The env is set identically on every driver container, so +// the first non-empty value (init containers first) is returned. +func DriverConfigDigestFromPodSpec(spec *corev1.PodSpec) string { + if spec == nil { + return "" + } + digestFromEnv := func(env []corev1.EnvVar) string { + for _, e := range env { + if e.Name == DriverConfigDigestEnvName { + return e.Value + } + } + return "" + } + for i := range spec.InitContainers { + if v := digestFromEnv(spec.InitContainers[i].Env); v != "" { + return v + } + } + for i := range spec.Containers { + if v := digestFromEnv(spec.Containers[i].Env); v != "" { + return v + } + } + return "" +} + // DriverInstallState lists all fields that affect driver installation. // Changes to these fields trigger a driver reinstall. // diff --git a/internal/config/driver_config_digest_test.go b/internal/config/driver_config_digest_test.go index b9adae2e41..10da9cd0a4 100644 --- a/internal/config/driver_config_digest_test.go +++ b/internal/config/driver_config_digest_test.go @@ -309,3 +309,81 @@ func TestExtractVolumes(t *testing.T) { }) } } + +// containerWithConfigDigest builds a container carrying the DRIVER_CONFIG_DIGEST env +// when digest is non-empty (matching how object_controls.go sets it). +func containerWithConfigDigest(name, digest string) corev1.Container { + c := corev1.Container{Name: name} + if digest != "" { + c.Env = []corev1.EnvVar{{Name: DriverConfigDigestEnvName, Value: digest}} + } + return c +} + +func TestDriverConfigDigestFromPodSpec(t *testing.T) { + tests := []struct { + name string + spec *corev1.PodSpec + want string + }{ + { + name: "digest on k8s-driver-manager init container", + spec: &corev1.PodSpec{ + InitContainers: []corev1.Container{containerWithConfigDigest("k8s-driver-manager", "abc123")}, + Containers: []corev1.Container{containerWithConfigDigest("nvidia-driver-ctr", "")}, + }, + want: "abc123", + }, + { + name: "digest on nvidia-driver-ctr main container", + spec: &corev1.PodSpec{ + Containers: []corev1.Container{containerWithConfigDigest("nvidia-driver-ctr", "def456")}, + }, + want: "def456", + }, + { + name: "digest on OCP openshift-driver-toolkit-ctr", + spec: &corev1.PodSpec{ + Containers: []corev1.Container{containerWithConfigDigest("openshift-driver-toolkit-ctr", "ocp789")}, + }, + want: "ocp789", + }, + { + name: "init container digest takes precedence over main container", + spec: &corev1.PodSpec{ + InitContainers: []corev1.Container{containerWithConfigDigest("k8s-driver-manager", "init-digest")}, + Containers: []corev1.Container{containerWithConfigDigest("nvidia-driver-ctr", "main-digest")}, + }, + want: "init-digest", + }, + { + name: "empty init digest is skipped; main container value used", + spec: &corev1.PodSpec{ + InitContainers: []corev1.Container{{ + Name: "k8s-driver-manager", + Env: []corev1.EnvVar{{Name: DriverConfigDigestEnvName, Value: ""}}, + }}, + Containers: []corev1.Container{containerWithConfigDigest("nvidia-driver-ctr", "main-digest")}, + }, + want: "main-digest", + }, + { + name: "no digest anywhere", + spec: &corev1.PodSpec{ + InitContainers: []corev1.Container{{Name: "k8s-driver-manager"}}, + Containers: []corev1.Container{{Name: "nvidia-driver-ctr"}}, + }, + want: "", + }, + { + name: "nil spec", + spec: nil, + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, DriverConfigDigestFromPodSpec(tt.spec)) + }) + } +}