Skip to content

Commit 4911d35

Browse files
committed
Further add resilliancy in execution
1 parent e9e0e7d commit 4911d35

File tree

1 file changed

+36
-12
lines changed

1 file changed

+36
-12
lines changed

vulnerability_fix_engine.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ class ForkAlreadyExistsException(Exception):
4141
pass
4242

4343

44+
class AmbiguousObjectNameHeadException(Exception):
45+
pass
46+
47+
4448
async def subprocess_run(args: List[str], cwd: str) -> Optional[str]:
4549
proc = await asyncio.create_subprocess_exec(
4650
args[0],
@@ -70,6 +74,8 @@ async def subprocess_run(args: List[str], cwd: str) -> Optional[str]:
7074
raise PullRequestAlreadyExistsException(error_msg)
7175
if 'Error creating fork' in msg and 'already exists on github.com' in msg:
7276
raise ForkAlreadyExistsException(error_msg)
77+
if ' Ambiguous object name: \'HEAD\'' in msg:
78+
raise AmbiguousObjectNameHeadException(error_msg)
7379
raise RuntimeError(error_msg)
7480
else:
7581
if stderr:
@@ -159,6 +165,7 @@ async def do_fix_vulnerable_file(self, project_name: str, file: str, expected_fi
159165
class VulnerabilityFixReport:
160166
files_fixed: int
161167
vulnerabilities_fixed: int
168+
file_name_fixed: List[str]
162169

163170

164171
@dataclass
@@ -276,6 +283,7 @@ async def do_fix_vulnerabilities(self) -> VulnerabilityFixReport:
276283
project_vulnerabilities_fixed = 0
277284
project_files_fixed = 0
278285
submodules = self.submodule_files()
286+
files_fixed: List[str] = []
279287
for file in self.project_files.files:
280288
# Skip submodule files
281289
skip = next((True for submodule in submodules if file.startswith(submodule)), False)
@@ -284,16 +292,21 @@ async def do_fix_vulnerabilities(self) -> VulnerabilityFixReport:
284292
if file_vulnerabilities_fixed > 0:
285293
project_vulnerabilities_fixed += file_vulnerabilities_fixed
286294
project_files_fixed += 1
287-
return VulnerabilityFixReport(project_files_fixed, project_vulnerabilities_fixed)
295+
files_fixed.append(file)
296+
return VulnerabilityFixReport(
297+
project_files_fixed,
298+
project_vulnerabilities_fixed,
299+
files_fixed
300+
)
288301

289302
async def do_create_branch(self):
290303
await self.do_run_in(['git', 'checkout', '-b', self.fix_module.branch_name])
291304

292-
async def do_stage_changes(self):
305+
async def do_stage_changes(self, project_report: VulnerabilityFixReport):
293306
command = ['git', 'add']
294307
# Only run add on the files we've modified
295308
# This hopefully limits CRLF changes
296-
files_trimmed = [file_name.lstrip('/') for file_name in self.project_files.files.keys()]
309+
files_trimmed = [file_name.lstrip('/') for file_name in project_report.file_name_fixed]
297310
command.extend(files_trimmed)
298311
await self.do_run_in(command)
299312

@@ -375,7 +388,7 @@ async def execute_vulnerability_fixer_engine(engine: VulnerabilityFixerEngine, l
375388
# If the LGTM data is out-of-date, there can be cases where no vulnerabilities are fixed
376389
if project_report.vulnerabilities_fixed != 0:
377390
await engine.do_create_branch()
378-
await engine.do_stage_changes()
391+
await engine.do_stage_changes(project_report)
379392
await engine.do_commit_changes()
380393

381394
if not engine.project_files.project_name.lower().startswith('jlleitschuh'):
@@ -387,9 +400,15 @@ async def execute_vulnerability_fixer_engine(engine: VulnerabilityFixerEngine, l
387400
return project_report
388401

389402

390-
async def execute_vulnerability_fixer_engine_checked(engine: VulnerabilityFixerEngine, lock) -> VulnerabilityFixReport:
403+
async def execute_vulnerability_fixer_engine_checked(
404+
engine: VulnerabilityFixerEngine,
405+
lock
406+
) -> Optional[VulnerabilityFixReport]:
391407
try:
392408
return await execute_vulnerability_fixer_engine(engine, lock)
409+
except AmbiguousObjectNameHeadException:
410+
# They named their main branch 'HEAD'... Why?! No fix for them.
411+
return None
393412
except BaseException as e:
394413
if 'CancelledError' in e.__class__.__name__:
395414
raise e
@@ -425,6 +444,8 @@ async def _do_execute_engines(engines: List[VulnerabilityFixerEngine]):
425444
print(f'Processing {len(waiting_reports)} Projects:')
426445
all_reports = await asyncio.gather(*waiting_reports)
427446
for report in all_reports:
447+
if report is None:
448+
continue
428449
if report.vulnerabilities_fixed > 0:
429450
projects_fixed += 1
430451
files_fixed += report.files_fixed
@@ -447,12 +468,14 @@ async def _do_execute_fix_module(fix_module: VulnerabilityFixModule, starting_le
447468
for vulnerable_project in vulnerable_projects:
448469
if not vulnerable_project.project_name.startswith(starting_letter):
449470
continue
450-
if is_archived_git_hub_repository(vulnerable_project):
451-
logging.info(f'Skipping project {vulnerable_project.project_name} since it is archived')
452-
continue
471+
# Check this first, it's going to be faster
453472
if os.path.exists(fix_module.save_point_file_name(vulnerable_project)):
454473
logging.info(f'Skipping project {vulnerable_project.project_name} since save point file already exists')
455474
continue
475+
# Check this second, it's going to be slower
476+
if is_archived_git_hub_repository(vulnerable_project):
477+
logging.info(f'Skipping project {vulnerable_project.project_name} since it is archived')
478+
continue
456479
print(f'Loading Execution for: {vulnerable_project.project_name}')
457480
engine = VulnerabilityFixerEngine(
458481
fix_module=fix_module,
@@ -463,10 +486,11 @@ async def _do_execute_fix_module(fix_module: VulnerabilityFixModule, starting_le
463486
size = 100
464487
engine_lists = x = [engines[i:i + size] for i in range(0, len(engines), size)]
465488
for engine_list in engine_lists:
466-
try:
467-
await _do_execute_engines(engine_list)
468-
except EnginesExecutionException as e:
469-
logging.exception(f'Failed while processing engine group. {str(e)}')
489+
await _do_execute_engines(engine_list)
490+
# try:
491+
# await _do_execute_engines(engine_list)
492+
# except EnginesExecutionException as e:
493+
# logging.exception(f'Failed while processing engine group. {str(e)}')
470494

471495

472496
def do_execute_fix_module(fix_module: VulnerabilityFixModule):

0 commit comments

Comments
 (0)