Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 29 additions & 16 deletions bundle/scripts/scripts.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package scripts

import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
Expand Down Expand Up @@ -43,13 +44,37 @@ func (m *script) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
return diag.FromErr(err)
}

cmd, out, err := executeHook(ctx, executor, command)
cmd, err := executeHook(ctx, executor, command)
if err != nil {
return diag.FromErr(fmt.Errorf("failed to execute script: %w", err))
}

cmdio.LogString(ctx, fmt.Sprintf("Executing '%s' script", m.scriptHook))

// Reading the pipes sequentially deadlocks once the script fills the ~64KiB
// stderr pipe buffer while stdout is still open, so drain stderr concurrently.
// Spooling it to memory preserves the stdout-then-stderr output order.
var stderr bytes.Buffer
stderrDone := make(chan struct{})
go func() {
defer close(stderrDone)
_, _ = io.Copy(&stderr, cmd.Stderr())
}()

logOutput(ctx, cmd.Stdout())
<-stderrDone
logOutput(ctx, &stderr)

err = cmd.Wait()
if err != nil {
return diag.FromErr(fmt.Errorf("failed to execute script: %w", err))
}

return nil
}

// logOutput logs output line by line, including a final line without a trailing newline.
func logOutput(ctx context.Context, out io.Reader) {
reader := bufio.NewReader(out)
for {
line, err := reader.ReadString('\n')
Expand All @@ -60,27 +85,15 @@ func (m *script) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
break
}
}

err = cmd.Wait()
if err != nil {
return diag.FromErr(fmt.Errorf("failed to execute script: %w", err))
}

return nil
}

func executeHook(ctx context.Context, executor *exec.Executor, command config.Command) (exec.Command, io.Reader, error) {
func executeHook(ctx context.Context, executor *exec.Executor, command config.Command) (exec.Command, error) {
// Don't run any arbitrary code when restricted execution is enabled.
if _, ok := env.RestrictedExecution(ctx); ok {
return nil, nil, errors.New("running scripts is not allowed when DATABRICKS_BUNDLE_RESTRICTED_CODE_EXECUTION is set")
}

cmd, err := executor.StartCommand(ctx, string(command))
if err != nil {
return nil, nil, err
return nil, errors.New("running scripts is not allowed when DATABRICKS_BUNDLE_RESTRICTED_CODE_EXECUTION is set")
}

return cmd, io.MultiReader(cmd.Stdout(), cmd.Stderr()), nil
return executor.StartCommand(ctx, string(command))
}

func getCommand(b *bundle.Bundle, hook config.ScriptHook) config.Command {
Expand Down
37 changes: 37 additions & 0 deletions bundle/scripts/scripts_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package scripts_test

import (
"context"
"runtime"
"strings"
"testing"
"time"

"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
Expand Down Expand Up @@ -38,3 +41,37 @@ func TestExecuteOutputWithoutTrailingNewline(t *testing.T) {
assert.Contains(t, output, "line2")
assert.Contains(t, output, "last line without newline")
}

func TestExecuteLargeStderrOutputDoesNotDeadlock(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipping on windows")
}

dir := t.TempDir()
b := &bundle.Bundle{
BundleRootPath: dir,
Config: config.Root{
Experimental: &config.Experimental{
Scripts: map[config.ScriptHook]config.Command{
// Overflows the ~64KiB stderr pipe buffer before stdout
// closes, which deadlocked the old sequential draining.
config.ScriptPreInit: "seq 100000 | tr -d '\\n' >&2; echo stdout-after-stderr",
},
},
},
}

// A reintroduced deadlock fails fast: the context timeout kills the script.
ctx, cancel := context.WithTimeout(t.Context(), time.Minute)
defer cancel()

ctx, stderr := cmdio.NewTestContextWithStderr(ctx)
diags := bundle.Apply(ctx, b, scripts.Execute(config.ScriptPreInit))
require.NoError(t, diags.Error())

output := stderr.String()
assert.Contains(t, output, "99999100000")
assert.Contains(t, output, "stdout-after-stderr")
// The script writes stderr first, but spooling preserves the stdout-then-stderr order.
assert.Less(t, strings.Index(output, "stdout-after-stderr"), strings.Index(output, "99999100000"))
}
Loading