diff --git a/cumulusci/tasks/bulkdata/load.py b/cumulusci/tasks/bulkdata/load.py index a6cbdee9ac..2bb148869d 100644 --- a/cumulusci/tasks/bulkdata/load.py +++ b/cumulusci/tasks/bulkdata/load.py @@ -127,6 +127,11 @@ def _init_options(self, kwargs): self.options["enable_rollback"] = process_bool_arg( self.options.get("enable_rollback", False) ) + if self.options["enable_rollback"] and self.options["ignore_row_errors"]: + self.logger.warning( + "enable_rollback=True has no effect on row-level errors when " + "ignore_row_errors=True, because row errors are suppressed before rollback can trigger." + ) self._id_generators = {} self._old_format = False self.ID_TABLE_NAME = ID_TABLE_NAME diff --git a/cumulusci/tasks/bulkdata/snowfakery.py b/cumulusci/tasks/bulkdata/snowfakery.py index 0f960bd29e..54645ac290 100644 --- a/cumulusci/tasks/bulkdata/snowfakery.py +++ b/cumulusci/tasks/bulkdata/snowfakery.py @@ -147,6 +147,9 @@ class Snowfakery(BaseSalesforceApiTask): "strict_mode": { "description": "Boolean: If True, validates the Snowfakery recipe and generated mapping against the org schema (strict mode) and then proceeds with the run", }, + "enable_rollback": { + "description": "Boolean: When True, performs a rollback of all loaded records in case of an error. Defaults to False." + }, } def _validate_options(self): @@ -169,6 +172,21 @@ def _validate_options(self): ) self.validate_only = process_bool_arg(self.options.get("validate_only", False)) self.strict_mode = process_bool_arg(self.options.get("strict_mode", False)) + self.enable_rollback = process_bool_arg( + self.options.get("enable_rollback", False) + ) + if self.enable_rollback and any( + self.options.get(k) + for k in ( + "run_until_records_in_org", + "run_until_records_loaded", + "run_until_recipe_repeated", + ) + ): + raise TaskOptionsError( + "enable_rollback=True cannot be combined with run_until_* options " + "because each batch commits independently; only the failing batch would be rolled back." + ) loading_rules = process_list_arg(self.options.get("loading_rules")) or [] self.loading_rules = [Path(path) for path in loading_rules if path] @@ -290,6 +308,7 @@ def _setup_channels_and_queues(self, working_directory): additional_load_options = { "ignore_row_errors": self.ignore_row_errors, "drop_missing_schema": self.drop_missing_schema, + "enable_rollback": self.enable_rollback, } subtask_configurator = SubtaskConfigurator( self.recipe, self.run_until, self.bulk_mode, additional_load_options @@ -619,6 +638,7 @@ def _run_generate_and_load_subtask( "drop_missing_schema": self.drop_missing_schema, "validate_only": validate_only, "strict_mode": self.strict_mode, + "enable_rollback": self.enable_rollback, } subtask_config = TaskConfig({"options": options}) subtask = GenerateAndLoadDataFromYaml( diff --git a/cumulusci/tasks/bulkdata/tests/test_load.py b/cumulusci/tasks/bulkdata/tests/test_load.py index f3657779c3..f413cf4ed7 100644 --- a/cumulusci/tasks/bulkdata/tests/test_load.py +++ b/cumulusci/tasks/bulkdata/tests/test_load.py @@ -131,6 +131,25 @@ def test_run(self, dml_mock): hh_ids = next(c.execute("SELECT * from cumulusci_id_table")) assert hh_ids == ("households-1", "001000000000000") + def test_enable_rollback_warns_when_ignore_row_errors_also_set(self): + task = _make_task( + LoadData, + { + "options": { + "mapping": "mapping.yml", + "database_url": "sqlite://", + "enable_rollback": True, + "ignore_row_errors": True, + } + }, + ) + with mock.patch.object(task, "logger") as mock_logger: + task._init_options({}) + warning_messages = [ + str(call) for call in mock_logger.warning.call_args_list + ] + assert any("enable_rollback" in msg for msg in warning_messages) + @responses.activate @mock.patch("cumulusci.tasks.bulkdata.load.get_dml_operation") def test__insert_rollback(self, dml_mock): diff --git a/cumulusci/tasks/bulkdata/tests/test_snowfakery.py b/cumulusci/tasks/bulkdata/tests/test_snowfakery.py index 7b30cbc6fc..22789aac9b 100644 --- a/cumulusci/tasks/bulkdata/tests/test_snowfakery.py +++ b/cumulusci/tasks/bulkdata/tests/test_snowfakery.py @@ -237,6 +237,81 @@ def _run_snowfakery_and_inspect_mapping(**options): return _run_snowfakery_and_inspect_mapping +@mock.patch("cumulusci.tasks.bulkdata.snowfakery.GenerateAndLoadDataFromYaml") +def test_enable_rollback_passes_flag_to_subtask(mock_subtask_cls, snowfakery): + mock_subtask = mock.Mock() + mock_subtask.__call__ = mock.Mock(return_value=None) + mock_subtask.return_values = { + "load_results": [ + { + "step_results": { + "Insert Account": { + "sobject": "Account", + "record_type": None, + "status": "Success", + "records_processed": 1, + "total_row_errors": 0, + } + } + } + ] + } + mock_subtask_cls.return_value = mock_subtask + + task = snowfakery( + recipe=str(simple_salesforce_yaml), + enable_rollback=True, + ) + + with TemporaryDirectory() as tmpdir: + task._run_generate_and_load_subtask( + Path(tmpdir), + DummyOrgConfig({}, "test"), + options={}, + ) + + call_kwargs = mock_subtask_cls.call_args.kwargs + task_config = call_kwargs["task_config"] + assert task_config.options["enable_rollback"] is True + + +@mock.patch("cumulusci.tasks.bulkdata.snowfakery.GenerateAndLoadDataFromYaml") +def test_enable_rollback_defaults_to_false(mock_subtask_cls, snowfakery): + mock_subtask = mock.Mock() + mock_subtask.__call__ = mock.Mock(return_value=None) + mock_subtask.return_values = { + "load_results": [ + { + "step_results": { + "Insert Account": { + "sobject": "Account", + "record_type": None, + "status": "Success", + "records_processed": 1, + "total_row_errors": 0, + } + } + } + ] + } + mock_subtask_cls.return_value = mock_subtask + + task = snowfakery( + recipe=str(simple_salesforce_yaml), + ) + + with TemporaryDirectory() as tmpdir: + task._run_generate_and_load_subtask( + Path(tmpdir), + DummyOrgConfig({}, "test"), + options={}, + ) + + call_kwargs = mock_subtask_cls.call_args.kwargs + task_config = call_kwargs["task_config"] + assert task_config.options["enable_rollback"] is False + + @mock.patch("cumulusci.tasks.bulkdata.snowfakery.GenerateAndLoadDataFromYaml") def test_snowfakery_validate_only_passes_flags(mock_subtask_cls, snowfakery): mock_subtask = mock.Mock() @@ -405,6 +480,24 @@ def test_small( for call in mock_load_data.mock_calls: assert call.task_config.config["options"]["drop_missing_schema"] is True + @pytest.mark.parametrize( + "run_until_option,run_until_value", + [ + ("run_until_recipe_repeated", "7"), + ("run_until_records_loaded", "Account:10"), + ("run_until_records_in_org", "Account:10"), + ], + ) + def test_enable_rollback_rejected_with_run_until( + self, run_until_option, run_until_value, snowfakery + ): + with pytest.raises(exc.TaskOptionsError, match="enable_rollback"): + snowfakery( + recipe=str(simple_salesforce_yaml), + enable_rollback=True, + **{run_until_option: run_until_value}, + ) + @mock.patch("cumulusci.tasks.bulkdata.snowfakery.MIN_PORTION_SIZE", 3) def test_multi_part( self, threads_instead_of_processes, mock_load_data, create_task_fixture