diff --git a/Scripts/format.sh b/Scripts/format.sh new file mode 100755 index 00000000..2607603a --- /dev/null +++ b/Scripts/format.sh @@ -0,0 +1,9 @@ +#!/bin/zsh + +set -e + +echo "Formatting Swift sources in $(pwd)" + +# Run the format / lint commands +git ls-files -z '*.swift' | xargs -0 swift format format --parallel --in-place +git ls-files -z '*.swift' | xargs -0 swift format lint --strict --parallel \ No newline at end of file diff --git a/Sources/ComputeCxx/Closure/AGClosure.cpp b/Sources/ComputeCxx/Closure/AGClosure.cpp index 223e6fa0..6ea5b1bb 100644 --- a/Sources/ComputeCxx/Closure/AGClosure.cpp +++ b/Sources/ComputeCxx/Closure/AGClosure.cpp @@ -3,12 +3,17 @@ #include AGClosureStorage AGRetainClosure(const void *thunk, const void *_Nullable context) { - void *mutable_context = const_cast(context); - const void *retained_context = ::swift::swift_retain(reinterpret_cast<::swift::HeapObject *>(mutable_context)); + const void *retained_context = context; + if (context) { + void *mutable_context = const_cast(context); + retained_context = ::swift::swift_retain(reinterpret_cast<::swift::HeapObject *>(mutable_context)); + } return AGClosureStorage((void *)thunk, retained_context); } void AGReleaseClosure(AGClosureStorage closure) { - void *context = const_cast(closure.context); - ::swift::swift_release(reinterpret_cast<::swift::HeapObject *>(context)); + if (closure.context) { + void *mutable_context = const_cast(closure.context); + ::swift::swift_release(reinterpret_cast<::swift::HeapObject *>(mutable_context)); + } } diff --git a/Sources/ComputeCxx/Closure/ClosureFunction.h b/Sources/ComputeCxx/Closure/ClosureFunction.h index 20ea6d50..9e980d35 100644 --- a/Sources/ComputeCxx/Closure/ClosureFunction.h +++ b/Sources/ComputeCxx/Closure/ClosureFunction.h @@ -19,13 +19,17 @@ template class ClosureFunction { public: inline ClosureFunction(std::nullptr_t): _function(nullptr), _context(nullptr) {} inline ClosureFunction(Function function, Context context) noexcept : _function(function), _context(context) { - void *mutable_context = const_cast(_context); - ::swift::swift_retain(reinterpret_cast<::swift::HeapObject *>(mutable_context)); + if (_context) { + void *mutable_context = const_cast(_context); + ::swift::swift_retain(reinterpret_cast<::swift::HeapObject *>(mutable_context)); + } } inline ~ClosureFunction() { - void *context = const_cast(_context); - ::swift::swift_release(reinterpret_cast<::swift::HeapObject *>(context)); + if (_context) { + void *mutable_context = const_cast(_context); + ::swift::swift_release(reinterpret_cast<::swift::HeapObject *>(mutable_context)); + } } // Copyable @@ -39,13 +43,15 @@ template class ClosureFunction { ClosureFunction &operator=(const ClosureFunction &other) noexcept { if (this != &other) { - _function = other._function; - if (_context) { - ::swift::swift_release((::swift::HeapObject *)_context); + Context new_context = other._context; + if (new_context) { + new_context = ::swift::swift_retain((::swift::HeapObject *)new_context); } - _context = other._context; - if (_context) { - ::swift::swift_retain((::swift::HeapObject *)_context); + Context old_context = _context; + _function = other._function; + _context = new_context; + if (old_context) { + ::swift::swift_release((::swift::HeapObject *)old_context); } } return *this; @@ -58,13 +64,14 @@ template class ClosureFunction { ClosureFunction &operator=(ClosureFunction &&other) noexcept { if (this != &other) { + Context old_context = _context; _function = other._function; - other._function = nullptr; - if (_context) { - ::swift::swift_release((::swift::HeapObject *)_context); - } _context = other._context; + other._function = nullptr; other._context = nullptr; + if (old_context) { + ::swift::swift_release((::swift::HeapObject *)old_context); + } } return *this; } diff --git a/Tests/ComputeSwiftTests/Shared/MetadataTests.swift b/Tests/ComputeSwiftTests/Shared/MetadataTests.swift index 5b983245..6b5e5e3f 100644 --- a/Tests/ComputeSwiftTests/Shared/MetadataTests.swift +++ b/Tests/ComputeSwiftTests/Shared/MetadataTests.swift @@ -127,8 +127,8 @@ struct MetadataTests { signatures.append(Metadata(TestPackedGenericStruct.self).signature) signatures.append(Metadata(TestPackedGenericStruct.self).signature) - signatures.combinations(ofCount: 2).forEach { elements in - #expect(elements[0] != elements[1]) + for combination in signatures.combinations(ofCount: 2) { + #expect(combination[0] != combination[1]) } } diff --git a/Tests/UtilitiesTests/HashTableTests.swift b/Tests/UtilitiesTests/HashTableTests.swift index 964f2c78..d6c25561 100644 --- a/Tests/UtilitiesTests/HashTableTests.swift +++ b/Tests/UtilitiesTests/HashTableTests.swift @@ -119,16 +119,19 @@ struct HashTableTests { // Count iterations via for_each - if there's a cycle, this will exceed count // or hang forever. We use a manual iteration limit to detect cycles. var iterationCount = 0 - let maxIterations = 1000 // Way more than count, to detect infinite loop - - table.for_each({ _, _, context in - let countPtr = context.assumingMemoryBound(to: Int.self) - countPtr.pointee += 1 - // If we've iterated too many times, we have a cycle - if countPtr.pointee > 100 { - fatalError("Cycle detected in hash table - iteration count exceeded expected") - } - }, &iterationCount) + let maxIterations = 1000 // Way more than count, to detect infinite loop + + table.for_each( + { _, _, context in + let countPtr = context.assumingMemoryBound(to: Int.self) + countPtr.pointee += 1 + // If we've iterated too many times, we have a cycle + if countPtr.pointee > 100 { + fatalError("Cycle detected in hash table - iteration count exceeded expected") + } + }, + &iterationCount + ) #expect(iterationCount == 32, "for_each should visit exactly count() items") } @@ -201,10 +204,13 @@ struct HashTableTests { // Count via for_each should be 2 var iterationCount = 0 - table.for_each({ _, _, context in - let countPtr = context.assumingMemoryBound(to: Int.self) - countPtr.pointee += 1 - }, &iterationCount) + table.for_each( + { _, _, context in + let countPtr = context.assumingMemoryBound(to: Int.self) + countPtr.pointee += 1 + }, + &iterationCount + ) #expect(iterationCount == 2, "for_each should visit exactly 2 items after reinsertion") } @@ -242,10 +248,13 @@ struct HashTableTests { // Count via for_each should match var iterationCount = 0 - table.for_each({ _, _, context in - let countPtr = context.assumingMemoryBound(to: Int.self) - countPtr.pointee += 1 - }, &iterationCount) + table.for_each( + { _, _, context in + let countPtr = context.assumingMemoryBound(to: Int.self) + countPtr.pointee += 1 + }, + &iterationCount + ) #expect(iterationCount == itemCount, "for_each should visit all \(itemCount) items") }