diff --git a/bundle/python/warning.go b/bundle/python/warning.go index 443b8fd27c..01b639ef11 100644 --- a/bundle/python/warning.go +++ b/bundle/python/warning.go @@ -20,12 +20,20 @@ func WrapperWarning() bundle.Mutator { } func (m *wrapperWarning) Apply(ctx context.Context, b *bundle.Bundle) error { + if isPythonWheelWrapperOn(b) { + return nil + } + if hasIncompatibleWheelTasks(ctx, b) { return fmt.Errorf("python wheel tasks with local libraries require compute with DBR 13.1+. Please change your cluster configuration or set experimental 'python_wheel_wrapper' setting to 'true'") } return nil } +func isPythonWheelWrapperOn(b *bundle.Bundle) bool { + return b.Config.Experimental != nil && b.Config.Experimental.PythonWheelWrapper +} + func hasIncompatibleWheelTasks(ctx context.Context, b *bundle.Bundle) bool { tasks := libraries.FindAllWheelTasksWithLocalLibraries(b) for _, task := range tasks { diff --git a/bundle/python/warning_test.go b/bundle/python/warning_test.go index 83bc142f1f..f822f113c8 100644 --- a/bundle/python/warning_test.go +++ b/bundle/python/warning_test.go @@ -209,6 +209,9 @@ func TestIncompatibleWheelTasksWithJobClusterKey(t *testing.T) { } require.True(t, hasIncompatibleWheelTasks(context.Background(), b)) + + err := bundle.Apply(context.Background(), b, WrapperWarning()) + require.ErrorContains(t, err, "python wheel tasks with local libraries require compute with DBR 13.1+.") } func TestIncompatibleWheelTasksWithExistingClusterId(t *testing.T) { @@ -337,6 +340,49 @@ func TestNoIncompatibleWheelTasks(t *testing.T) { require.False(t, hasIncompatibleWheelTasks(context.Background(), b)) } +func TestNoWarningWhenPythonWheelWrapperIsOn(t *testing.T) { + b := &bundle.Bundle{ + Config: config.Root{ + Experimental: &config.Experimental{ + PythonWheelWrapper: true, + }, + Resources: config.Resources{ + Jobs: map[string]*resources.Job{ + "job1": { + JobSettings: &jobs.JobSettings{ + Tasks: []jobs.Task{ + { + TaskKey: "key1", + PythonWheelTask: &jobs.PythonWheelTask{}, + NewCluster: &compute.ClusterSpec{ + SparkVersion: "12.2.x-scala2.12", + }, + Libraries: []compute.Library{ + {Whl: "./dist/test.whl"}, + }, + }, + { + TaskKey: "key2", + PythonWheelTask: &jobs.PythonWheelTask{}, + NewCluster: &compute.ClusterSpec{ + SparkVersion: "13.1.x-scala2.12", + }, + Libraries: []compute.Library{ + {Whl: "./dist/test.whl"}, + }, + }, + }, + }, + }, + }, + }, + }, + } + + err := bundle.Apply(context.Background(), b, WrapperWarning()) + require.NoError(t, err) +} + func TestSparkVersionLowerThanExpected(t *testing.T) { testCases := map[string]bool{ "13.1.x-scala2.12": false,