Skip to content

Commit d970daf

Browse files
committed
Fixes
1 parent 401096d commit d970daf

6 files changed

Lines changed: 21 additions & 15 deletions

File tree

DifferentiationInterface/test/Back/DifferentiateWith/test.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@ function differentiatewith_scenarios()
1616
DIT.function_place(scen) == :out
1717
end
1818
good_scens = map(bad_scens) do scen
19-
DIT.change_function(
20-
scen, DifferentiateWith(scen.f, AutoFiniteDiff()); keep_smaller=false
21-
)
19+
DIT.change_function(scen, DifferentiateWith(scen.f, AutoFiniteDiff()))
2220
end
2321
return good_scens
2422
end

DifferentiationInterfaceTest/src/scenarios/modify.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,7 @@ end
3636
3737
Return a new `Scenario` identical to `scen` except for the function `f` which is changed to `new_f`.
3838
"""
39-
function change_function(
40-
scen::Scenario{op,pl_op,pl_fun}, new_f; keep_smaller
41-
) where {op,pl_op,pl_fun}
39+
function change_function(scen::Scenario{op,pl_op,pl_fun}, new_f) where {op,pl_op,pl_fun}
4240
return Scenario{op,pl_op,pl_fun}(;
4341
f=new_f,
4442
x=scen.x,
@@ -52,6 +50,8 @@ function change_function(
5250
)
5351
end
5452

53+
same_function(scen) = change_function(scen, scen.f)
54+
5555
"""
5656
batchify(scen::Scenario)
5757

DifferentiationInterfaceTest/src/scenarios/scenario.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,13 +199,15 @@ function Base.show(
199199
end
200200

201201
function adapt_batchsize(backend::AbstractADType, scen::Scenario)
202-
(; x, y) = scen
202+
(; x, y, prep_args) = scen
203+
xprep = prep_args.x
204+
yprep = hasproperty(prep_args, :y) ? prep_args.y : y
203205
Bmax = if x isa AbstractArray && y isa AbstractArray
204-
min(length(x), length(y))
206+
min(length(x), length(y), length(xprep), length(yprep))
205207
elseif x isa AbstractArray
206-
length(x)
208+
min(length(x), length(xprep))
207209
elseif y isa AbstractArray
208-
length(y)
210+
min(length(y), length(yprep))
209211
else
210212
typemax(Int)
211213
end

DifferentiationInterfaceTest/src/test_differentiation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ Each setting tests/benchmarks a different subset of calls:
4848
- `rtol=1e-3`: relative precision for correctness testing (when comparing to the reference outputs)
4949
- `scenario_intact=true`: whether to check that the scenario remains unchanged after the operators are applied
5050
- `sparsity=false`: whether to check sparsity patterns for Jacobians / Hessians
51+
- `reprepare::Bool=true`: whether to modify preparation before testing when the preparation arguments have the wrong size
5152
5253
**Type stability options:**
5354

DifferentiationInterfaceTest/test/standard.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ test_differentiation(
2020
test_differentiation(
2121
[AutoForwardDiff(), AutoFiniteDiff(; relstep=1e-5)],
2222
default_scenarios(;
23-
include_batchified=false, include_normal=false, include_constantorcachified=true
23+
include_batchified=false,
24+
include_normal=false,
25+
include_cachified=true,
26+
include_constantorcachified=true,
2427
);
2528
logging=LOGGING,
2629
)
@@ -35,7 +38,7 @@ sparse_backend = AutoSparse(
3538

3639
test_differentiation(
3740
sparse_backend,
38-
sparse_scenarios(; include_cachified=true, use_tuples=true);
41+
sparse_scenarios(; include_cachified=true, use_tuples=false);
3942
sparsity=true,
4043
logging=LOGGING,
4144
)

DifferentiationInterfaceTest/test/zero_backends.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,17 @@ LOGGING = get(ENV, "CI", "false") == "false"
1111

1212
test_differentiation(
1313
AutoZeroForward(),
14-
default_scenarios(; include_batchified=false);
15-
correctness=false,
14+
map(zero, default_scenarios(; include_batchified=false));
1615
type_stability=:full,
1716
logging=LOGGING,
1817
)
1918

2019
test_differentiation(
2120
AutoZeroReverse(),
22-
default_scenarios(; include_batchified=false);
21+
map(
22+
DifferentiationInterfaceTest.same_function,
23+
default_scenarios(; include_batchified=false),
24+
);
2325
correctness=false,
2426
type_stability=:prepared,
2527
logging=LOGGING,

0 commit comments

Comments
 (0)