Skip to content

Commit 32849b8

Browse files
authored
Fix mode for second order (#205)
* Fix mode for second order * Fix benchmark * Add test * mode not imported
1 parent 023a8ea commit 32849b8

3 files changed

Lines changed: 17 additions & 2 deletions

File tree

DifferentiationInterface/src/second_order/second_order.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,10 @@ outer(backend::SecondOrder) = backend.outer
2626
function Base.show(io::IO, backend::SecondOrder)
2727
return print(io, "SecondOrder($(outer(backend)) / $(inner(backend)))")
2828
end
29+
30+
"""
31+
mode(backend::SecondOrder)
32+
33+
Return the _outer_ mode of the second-order backend.
34+
"""
35+
ADTypes.mode(backend::SecondOrder) = mode(outer(backend))

DifferentiationInterface/test/second_order.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ for backend in vcat(
2929
dense_second_order_backends, sparse_second_order_backends, mixed_second_order_backends
3030
)
3131
@test check_hessian(backend)
32+
@test ADTypes.mode(backend) isa ADTypes.AbstractMode
3233
end
3334

3435
test_differentiation(

DifferentiationInterfaceTest/test/zero_backends.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,18 @@ test_differentiation(
5151

5252
## Benchmark
5353

54-
data = benchmark_differentiation(
54+
data1 = benchmark_differentiation(
5555
[AutoZeroForward(), AutoZeroReverse()]; logging=get(ENV, "CI", "false") == "false"
5656
);
5757

58-
df = DataFrames.DataFrame(data)
58+
data2 = benchmark_differentiation(
59+
[SecondOrder(AutoZeroForward(), AutoZeroReverse())];
60+
first_order=false,
61+
logging=get(ENV, "CI", "false") == "false",
62+
);
63+
64+
df1 = DataFrames.DataFrame(data1)
65+
df2 = DataFrames.DataFrame(data2)
5966

6067
## Weird arrays
6168

0 commit comments

Comments
 (0)