Skip to content

Commit 4feb596

Browse files
ErikQQYgdalle
andauthored
feat: access overloaded inputs from preparation result (#672)
* Add overloaded_inputs for preparations * Format * Change to dualized array and increase test coverage * Proper import * Add tests for one arg and two arg prep * Add overloaded_inputs for pushforwardprep * Add overloaded_inputs for derivative and gradient * Revamp API * No fail fast * Simplify API * Fail fast * Codecov --------- Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
1 parent 240e7e8 commit 4feb596

10 files changed

Lines changed: 90 additions & 1 deletion

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.29"
4+
version = "0.6.30"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,6 @@ include("onearg.jl")
3232
include("twoarg.jl")
3333
include("secondorder.jl")
3434
include("differentiate_with.jl")
35+
include("misc.jl")
3536

3637
end # module
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
## Pushforward
2+
DI.overloaded_input_type(prep::ForwardDiffOneArgPushforwardPrep) = typeof(prep.xdual_tmp)
3+
DI.overloaded_input_type(prep::ForwardDiffTwoArgPushforwardPrep) = typeof(prep.xdual_tmp)
4+
5+
## Derivative
6+
function DI.overloaded_input_type(prep::ForwardDiffOneArgDerivativePrep)
7+
return DI.overloaded_input_type(prep.pushforward_prep)
8+
end
9+
DI.overloaded_input_type(prep::ForwardDiffTwoArgDerivativePrep) = typeof(prep.config.duals)
10+
11+
## Gradient
12+
DI.overloaded_input_type(prep::ForwardDiffGradientPrep) = typeof(prep.config.duals)
13+
14+
## Jacobian
15+
DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep) = typeof(prep.config.duals[2])
16+
DI.overloaded_input_type(prep::ForwardDiffTwoArgJacobianPrep) = typeof(prep.config.duals[2])

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,6 @@ end
3232

3333
include("onearg.jl")
3434
include("twoarg.jl")
35+
include("utils.jl")
3536

3637
end # module
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
## Gradient
2+
DI.overloaded_input_type(prep::ReverseDiffGradientPrep) = typeof(prep.config.input)
3+
4+
## Jacobian
5+
DI.overloaded_input_type(prep::ReverseDiffOneArgJacobianPrep) = typeof(prep.config.input)
6+
DI.overloaded_input_type(prep::ReverseDiffTwoArgJacobianPrep) = typeof(prep.config.input)

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,3 +329,9 @@ function _sparse_jacobian_aux!(
329329
decompress!(jac, compressed_matrix, coloring_result)
330330
return jac
331331
end
332+
333+
## Operator overloading
334+
335+
function DI.overloaded_input_type(prep::PushforwardSparseJacobianPrep)
336+
return DI.overloaded_input_type(prep.pushforward_prep)
337+
end

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ include("misc/from_primitive.jl")
6363
include("misc/sparsity_detector.jl")
6464
include("misc/simple_finite_diff.jl")
6565
include("misc/zero_backends.jl")
66+
include("misc/overloading.jl")
6667

6768
## Exported
6869

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""
2+
overloaded_input_type(prep)
3+
4+
If it exists, return the overloaded input type which will be passed to the differentiated function when preparation result `prep` is reused.
5+
6+
!!! danger
7+
This function is experimental and not part of the public API.
8+
"""
9+
function overloaded_input_type end

DifferentiationInterface/test/Back/ForwardDiff/test.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,31 @@ test_differentiation(AutoForwardDiff(), static_scenarios(); logging=LOGGING)
104104
end
105105
end
106106
end;
107+
108+
@testset verbose = true "Overloaded inputs" begin
109+
backend = AutoForwardDiff()
110+
sparse_backend = MyAutoSparse(AutoForwardDiff())
111+
112+
# Derivative
113+
x = 1.0
114+
y = [1.0, 1.0]
115+
@test DI.overloaded_input_type(prepare_derivative(copy, backend, x)) ==
116+
ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy),Float64},Float64,1}
117+
@test DI.overloaded_input_type(prepare_derivative(copyto!, y, backend, x)) ==
118+
Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!),Float64},Float64,1}}
119+
120+
# Gradient
121+
x = [1.0, 1.0]
122+
@test DI.overloaded_input_type(prepare_gradient(sum, backend, x)) ==
123+
Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(sum),Float64},Float64,2}}
124+
125+
# Jacobian
126+
x = [1.0, 0.0, 0.0]
127+
@test DI.overloaded_input_type(prepare_jacobian(copy, backend, x)) ==
128+
ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy),Float64},Float64,3}
129+
@test DI.overloaded_input_type(prepare_jacobian(copyto!, similar(x), backend, x)) ==
130+
Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!),Float64},Float64,3}}
131+
@test DI.overloaded_input_type(
132+
prepare_jacobian(copyto!, similar(x), sparse_backend, x)
133+
) == Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!),Float64},Float64,1}}
134+
end;

DifferentiationInterface/test/Back/ReverseDiff/test.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using Pkg
22
Pkg.add(["ForwardDiff", "ReverseDiff"]) # ForwardDiff already in ReverseDiff's deps
33

44
using DifferentiationInterface, DifferentiationInterfaceTest
5+
import DifferentiationInterface as DI
56
using ForwardDiff: ForwardDiff
67
using ReverseDiff: ReverseDiff
78
using StaticArrays: StaticArrays
@@ -38,3 +39,23 @@ test_differentiation(
3839
sparsity=true,
3940
logging=LOGGING,
4041
);
42+
43+
@testset verbose = true "Overloaded inputs" begin
44+
backend = AutoReverseDiff()
45+
46+
# Derivative
47+
x = 1.0
48+
@test_skip DI.overloaded_input_type(prepare_derivative(copy, backend, x)) ==
49+
ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}}
50+
51+
# Gradient
52+
x = [1.0; 0.0; 0.0]
53+
@test DI.overloaded_input_type(prepare_gradient(sum, backend, x)) ==
54+
ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}}
55+
56+
# Jacobian
57+
@test DI.overloaded_input_type(prepare_jacobian(copy, backend, x)) ==
58+
ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}}
59+
@test DI.overloaded_input_type(prepare_jacobian(copyto!, similar(x), backend, x)) ==
60+
ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}}
61+
end;

0 commit comments

Comments
 (0)