diff --git a/bundle/generate/downloader.go b/bundle/generate/downloader.go index acfef1ff3d..bf12b37586 100644 --- a/bundle/generate/downloader.go +++ b/bundle/generate/downloader.go @@ -321,9 +321,8 @@ func (n *Downloader) FlushToDisk(ctx context.Context, force bool) error { if err != nil { return err } - defer file.Close() - _, err = io.Copy(file, reader) + err = writeAndClose(file, reader) if err != nil { return err } @@ -336,6 +335,18 @@ func (n *Downloader) FlushToDisk(ctx context.Context, force bool) error { return errs.Wait() } +// writeAndClose copies src into dst and closes dst. A copy error takes +// precedence; otherwise the Close error is returned because a failed Close +// can mean buffered writes were lost and the file is truncated. +func writeAndClose(dst io.WriteCloser, src io.Reader) error { + _, err := io.Copy(dst, src) + cerr := dst.Close() + if err == nil { + err = cerr + } + return err +} + func NewDownloader(w *databricks.WorkspaceClient, sourceDir, configDir string) *Downloader { return &Downloader{ files: make(map[string]exportFile), diff --git a/bundle/generate/downloader_test.go b/bundle/generate/downloader_test.go index 9363ce98d3..5ce4889c5a 100644 --- a/bundle/generate/downloader_test.go +++ b/bundle/generate/downloader_test.go @@ -1,13 +1,19 @@ package generate import ( + "bytes" "encoding/json" + "errors" + "io" "net/http" "net/http/httptest" "os" "path/filepath" + "strings" "testing" + "testing/iotest" + "github.com/databricks/cli/libs/cmdio" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/experimental/mocks" "github.com/databricks/databricks-sdk-go/service/jobs" @@ -261,6 +267,101 @@ func TestDownloader_MarkTasksForDownload_NoNotebooks(t *testing.T) { assert.Empty(t, downloader.files) } +func TestDownloader_FlushToDisk(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + + contents := map[string]string{ + "/Users/user/project/notebook": "# Databricks notebook source\nprint(1)", + "/Users/user/project/utils.py": "def helper(): pass", + } + w := newTestWorkspaceClient(t, func(rw http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/2.0/workspace/export" { + t.Fatalf("unexpected request path: %s", r.URL.Path) + } + content, ok := contents[r.URL.Query().Get("path")] + if !ok { + t.Fatalf("unexpected export path: %s", r.URL.Query().Get("path")) + } + _, err := rw.Write([]byte(content)) + if err != nil { + t.Fatal(err) + } + }) + + sourceDir := t.TempDir() + downloader := NewDownloader(w, sourceDir, "config") + downloader.files[filepath.Join(sourceDir, "notebook.py")] = exportFile{ + path: "/Users/user/project/notebook", + format: workspace.ExportFormatSource, + } + downloader.files[filepath.Join(sourceDir, "utils.py")] = exportFile{ + path: "/Users/user/project/utils.py", + format: workspace.ExportFormatSource, + } + + err := downloader.FlushToDisk(ctx, false) + require.NoError(t, err) + + data, err := os.ReadFile(filepath.Join(sourceDir, "notebook.py")) + require.NoError(t, err) + assert.Equal(t, contents["/Users/user/project/notebook"], string(data)) + + data, err = os.ReadFile(filepath.Join(sourceDir, "utils.py")) + require.NoError(t, err) + assert.Equal(t, contents["/Users/user/project/utils.py"], string(data)) +} + +type fakeWriteCloser struct { + bytes.Buffer + closeErr error +} + +func (f *fakeWriteCloser) Close() error { return f.closeErr } + +func TestWriteAndClose(t *testing.T) { + closeErr := errors.New("close failed") + readErr := errors.New("read failed") + + tests := []struct { + name string + src io.Reader + closeErr error + wantErr error + wantData string + }{ + { + name: "success", + src: strings.NewReader("data"), + wantData: "data", + }, + { + name: "close error is returned", + src: strings.NewReader("data"), + closeErr: closeErr, + wantErr: closeErr, + wantData: "data", + }, + { + name: "copy error takes precedence over close error", + src: iotest.ErrReader(readErr), + closeErr: closeErr, + wantErr: readErr, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dst := &fakeWriteCloser{closeErr: tt.closeErr} + err := writeAndClose(dst, tt.src) + if tt.wantErr == nil { + assert.NoError(t, err) + } else { + assert.ErrorIs(t, err, tt.wantErr) + } + assert.Equal(t, tt.wantData, dst.String()) + }) + } +} + func TestDownloader_CleanupOldFiles(t *testing.T) { ctx := t.Context() sourceDir := t.TempDir()