Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 71 additions & 33 deletions cmd/nvidia-validator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,8 @@ const (
wslNvidiaSMIPath = "/usr/lib/wsl/lib/nvidia-smi"
// shell indicates what shell to use when invoking commands in a subprocess
shell = "sh"
// defaultVFWaitTimeout is the default timeout for waiting for VFs to be created
defaultVFWaitTimeout = 5 * time.Minute
// defaultVGPUReadinessTimeout is the default timeout for waiting for the vGPU stack to be ready
defaultVGPUReadinessTimeout = 5 * time.Minute
// constants for driver components
GDRCOPY = "gdrcopy"
NVIDIAFS = "nvidia-fs"
Expand Down Expand Up @@ -1747,9 +1747,9 @@ func (v *VGPUManager) validate() error {
return err
}

log.Info("Waiting for VFs to be available...")
if err := waitForVFs(ctx, defaultVFWaitTimeout); err != nil {
return fmt.Errorf("vGPU Manager VFs not ready: %w", err)
log.Info("Waiting for parent devices to be available...")
if err := waitForParentDevices(ctx, defaultVGPUReadinessTimeout); err != nil {
return fmt.Errorf("vGPU Manager parent devices not ready: %w", err)
}

statusFile := vGPUManagerStatusFile
Expand Down Expand Up @@ -1783,43 +1783,81 @@ func (v *VGPUManager) runValidation(silent bool) (hostDriver bool, err error) {
return hostDriver, runCommand(command, args, silent)
}

// waitForVFs waits for Virtual Functions to be created on all NVIDIA GPUs.
// It polls sriov_numvfs until all GPUs have their full VF count enabled.
func waitForVFs(ctx context.Context, timeout time.Duration) error {
// waitForParentDevices polls until the vGPU stack is ready — either NVIDIA
// mdev parent devices have been registered (PF on Turing, VFs on Ampere+
// SR-IOV) or all SR-IOV VFs are enabled.
func waitForParentDevices(ctx context.Context, timeout time.Duration) error {
pollInterval := time.Duration(sleepIntervalSecondsFlag) * time.Second
nvpciLib := nvpci.New()

return wait.PollUntilContextTimeout(ctx, pollInterval, timeout, true, func(ctx context.Context) (bool, error) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about the below suggestion?

return wait.PollUntilContextTimeout(ctx, pollInterval, timeout, true, func(ctx context.Context) (bool, error) {
    if mdevParentDevicesExist() || vfsExist() {
      return true, nil
    }
    return false, nil
})

func mdevParentDevicesExist() bool {
    nvmdev.Lib := nvmdev.New()
    mdevParentDevices, err := nvmdevLib.GetAllParentDevices()
    if err != nil {
      log.Warnf("could not get mdev parent devices: %v", err)
      return false
    }

    if len(mdevParentDevices) == 0 {
      log.Infof("found 0 mdev parent devices")
      return false
    } 
  
    log.Infof("found %d mdev parent devices", len(mdevParentDevices))
    return true
}

func vfsExist() bool {
        nvpciLib := nvpci.New()
  		gpus, err := nvpciLib.GetGPUs()
		if err != nil {
			log.Warnf("error getting GPUs: %v", err)
			return false
		}

		var totalExpected, totalEnabled uint64
		var pfCount int
		for _, gpu := range gpus {
			sriovInfo := gpu.SriovInfo
			if sriovInfo.IsPF() {
				pfCount++
				totalExpected += sriovInfo.PhysicalFunction.TotalVFs
				totalEnabled += sriovInfo.PhysicalFunction.NumVFs
			}
		}

		if totalExpected == 0 {
			log.Info("No SR-IOV capable GPUs found")
			return false
        }

        if totalEnabled == totalExpected {
			log.Infof("All %d VF(s) enabled on %d NVIDIA GPU(s)", totalEnabled, pfCount)
			return true
		}

        log.Infof("Not all VFs have been created. %d/%d enabled across %d GPU(s)", totalEnabled, totalExpected, pfCount)
		return false
}

gpus, err := nvpciLib.GetGPUs()
if err != nil {
log.Warnf("Error getting GPUs: %v", err)
return false, nil
}

var totalExpected, totalEnabled uint64
var pfCount int
for _, gpu := range gpus {
sriovInfo := gpu.SriovInfo
if sriovInfo.IsPF() {
pfCount++
totalExpected += sriovInfo.PhysicalFunction.TotalVFs
totalEnabled += sriovInfo.PhysicalFunction.NumVFs
}
if driverUsingSRIOV() {
return allVFsReady() && mdevParentDevicesExist(), nil
}
return mdevParentDevicesExist(), nil
})
}

if totalExpected == 0 {
log.Info("No SR-IOV capable GPUs found, skipping VF wait")
return true, nil
func driverUsingSRIOV() bool {
nvpciLib := nvpci.New()
gpus, err := nvpciLib.GetGPUs()
if err != nil {
log.Warnf("error getting GPUs: %v", err)
return false
}
for _, gpu := range gpus {
if gpu.SriovInfo.IsPF() && gpu.SriovInfo.PhysicalFunction.NumVFs > 0 {
return true
}
}
return false
}

if totalEnabled == totalExpected {
log.Infof("All %d VF(s) enabled on %d NVIDIA GPU(s)", totalEnabled, pfCount)
return true, nil
func mdevParentDevicesExist() bool {
nvmdevLib := nvmdev.New()
parents, err := nvmdevLib.GetAllParentDevices()
if err != nil {
log.Warnf("could not get mdev parent devices: %v", err)
return false
}
if len(parents) == 0 {
log.Info("found 0 mdev parent devices")
return false
}
log.Infof("found %d mdev parent devices", len(parents))
return true
}

func allVFsReady() bool {
nvpciLib := nvpci.New()
gpus, err := nvpciLib.GetGPUs()
if err != nil {
log.Warnf("error getting GPUs: %v", err)
return false
}

var totalExpected, totalEnabled uint64
var pfCount int
for _, gpu := range gpus {
sriovInfo := gpu.SriovInfo
if sriovInfo.IsPF() {
pfCount++
totalExpected += sriovInfo.PhysicalFunction.TotalVFs
totalEnabled += sriovInfo.PhysicalFunction.NumVFs
}
}

log.Infof("Waiting for VFs: %d/%d enabled across %d GPU(s)", totalEnabled, totalExpected, pfCount)
return false, nil
})
if totalExpected == 0 {
log.Info("no SR-IOV capable GPUs found")
return false
}

if totalEnabled == totalExpected {
log.Infof("all %d VF(s) enabled on %d NVIDIA GPU(s)", totalEnabled, pfCount)
return true
}

log.Infof("not all VFs have been created. %d/%d enabled across %d GPU(s)", totalEnabled, totalExpected, pfCount)
return false
}

func (c *CCManager) validate() error {
Expand Down
Loading