diff --git a/bundle/scripts/scripts.go b/bundle/scripts/scripts.go index 29cef46041f..4efcd22a4fd 100644 --- a/bundle/scripts/scripts.go +++ b/bundle/scripts/scripts.go @@ -2,6 +2,7 @@ package scripts import ( "bufio" + "bytes" "context" "errors" "fmt" @@ -43,13 +44,37 @@ func (m *script) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics { return diag.FromErr(err) } - cmd, out, err := executeHook(ctx, executor, command) + cmd, err := executeHook(ctx, executor, command) if err != nil { return diag.FromErr(fmt.Errorf("failed to execute script: %w", err)) } cmdio.LogString(ctx, fmt.Sprintf("Executing '%s' script", m.scriptHook)) + // Reading the pipes sequentially deadlocks once the script fills the ~64KiB + // stderr pipe buffer while stdout is still open, so drain stderr concurrently. + // Spooling it to memory preserves the stdout-then-stderr output order. + var stderr bytes.Buffer + stderrDone := make(chan struct{}) + go func() { + defer close(stderrDone) + _, _ = io.Copy(&stderr, cmd.Stderr()) + }() + + logOutput(ctx, cmd.Stdout()) + <-stderrDone + logOutput(ctx, &stderr) + + err = cmd.Wait() + if err != nil { + return diag.FromErr(fmt.Errorf("failed to execute script: %w", err)) + } + + return nil +} + +// logOutput logs output line by line, including a final line without a trailing newline. +func logOutput(ctx context.Context, out io.Reader) { reader := bufio.NewReader(out) for { line, err := reader.ReadString('\n') @@ -60,27 +85,15 @@ func (m *script) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics { break } } - - err = cmd.Wait() - if err != nil { - return diag.FromErr(fmt.Errorf("failed to execute script: %w", err)) - } - - return nil } -func executeHook(ctx context.Context, executor *exec.Executor, command config.Command) (exec.Command, io.Reader, error) { +func executeHook(ctx context.Context, executor *exec.Executor, command config.Command) (exec.Command, error) { // Don't run any arbitrary code when restricted execution is enabled. if _, ok := env.RestrictedExecution(ctx); ok { - return nil, nil, errors.New("running scripts is not allowed when DATABRICKS_BUNDLE_RESTRICTED_CODE_EXECUTION is set") - } - - cmd, err := executor.StartCommand(ctx, string(command)) - if err != nil { - return nil, nil, err + return nil, errors.New("running scripts is not allowed when DATABRICKS_BUNDLE_RESTRICTED_CODE_EXECUTION is set") } - return cmd, io.MultiReader(cmd.Stdout(), cmd.Stderr()), nil + return executor.StartCommand(ctx, string(command)) } func getCommand(b *bundle.Bundle, hook config.ScriptHook) config.Command { diff --git a/bundle/scripts/scripts_test.go b/bundle/scripts/scripts_test.go index 04aeafe5760..39042fce021 100644 --- a/bundle/scripts/scripts_test.go +++ b/bundle/scripts/scripts_test.go @@ -1,8 +1,11 @@ package scripts_test import ( + "context" "runtime" + "strings" "testing" + "time" "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" @@ -38,3 +41,37 @@ func TestExecuteOutputWithoutTrailingNewline(t *testing.T) { assert.Contains(t, output, "line2") assert.Contains(t, output, "last line without newline") } + +func TestExecuteLargeStderrOutputDoesNotDeadlock(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping on windows") + } + + dir := t.TempDir() + b := &bundle.Bundle{ + BundleRootPath: dir, + Config: config.Root{ + Experimental: &config.Experimental{ + Scripts: map[config.ScriptHook]config.Command{ + // Overflows the ~64KiB stderr pipe buffer before stdout + // closes, which deadlocked the old sequential draining. + config.ScriptPreInit: "seq 100000 | tr -d '\\n' >&2; echo stdout-after-stderr", + }, + }, + }, + } + + // A reintroduced deadlock fails fast: the context timeout kills the script. + ctx, cancel := context.WithTimeout(t.Context(), time.Minute) + defer cancel() + + ctx, stderr := cmdio.NewTestContextWithStderr(ctx) + diags := bundle.Apply(ctx, b, scripts.Execute(config.ScriptPreInit)) + require.NoError(t, diags.Error()) + + output := stderr.String() + assert.Contains(t, output, "99999100000") + assert.Contains(t, output, "stdout-after-stderr") + // The script writes stderr first, but spooling preserves the stdout-then-stderr order. + assert.Less(t, strings.Index(output, "stdout-after-stderr"), strings.Index(output, "99999100000")) +}