Skip to content

Commit 4d1b4a8

Browse files
authored
Add jacobian!! and derivative!! for mutating functions (#160)
1 parent a1dfaf1 commit 4d1b4a8

11 files changed

Lines changed: 197 additions & 9 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_allocating.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,24 +73,20 @@ end
7373

7474
DI.prepare_gradient(f, ::AutoReverseEnzyme, x) = NoGradientExtras()
7575

76-
function DI.gradient(f, ::AutoReverseEnzyme, x::AbstractArray, ::NoGradientExtras)
76+
function DI.gradient(f, ::AutoReverseEnzyme, x, ::NoGradientExtras)
7777
return gradient(Reverse, f, x)
7878
end
7979

80-
function DI.gradient!!(f, grad, ::AutoReverseEnzyme, x::AbstractArray, ::NoGradientExtras)
80+
function DI.gradient!!(f, grad, ::AutoReverseEnzyme, x, ::NoGradientExtras)
8181
grad_sametype = convert(typeof(x), grad)
8282
gradient!(Reverse, grad_sametype, f, x)
8383
return grad_sametype
8484
end
8585

86-
function DI.value_and_gradient(
87-
f, backend::AutoReverseEnzyme, x::AbstractArray, ::NoGradientExtras
88-
)
86+
function DI.value_and_gradient(f, backend::AutoReverseEnzyme, x, ::NoGradientExtras)
8987
return DI.value_and_pullback(f, backend, x, one(eltype(x)), NoPullbackExtras())
9088
end
9189

92-
function DI.value_and_gradient!!(
93-
f, grad, backend::AutoReverseEnzyme, x::AbstractArray, ::NoGradientExtras
94-
)
90+
function DI.value_and_gradient!!(f, grad, backend::AutoReverseEnzyme, x, ::NoGradientExtras)
9591
return DI.value_and_pullback!!(f, grad, backend, x, one(eltype(x)), NoPullbackExtras())
9692
end

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/mutating.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,20 @@ function DI.value_and_pushforward!!(
3535
return y, dy
3636
end
3737

38+
function DI.pushforward!!(
39+
f!,
40+
y,
41+
dy,
42+
::AnyAutoFastDifferentiation,
43+
x,
44+
dx,
45+
extras::FastDifferentiationMutatingPushforwardExtras,
46+
)
47+
v_vec = vcat(myvec(x), myvec(dx))
48+
extras.jvp_exe!(vec(dy), v_vec)
49+
return dy
50+
end
51+
3852
## Derivative
3953

4054
struct FastDifferentiationMutatingDerivativeExtras{E} <: DerivativeExtras
@@ -66,6 +80,18 @@ function DI.value_and_derivative!!(
6680
return y, der
6781
end
6882

83+
function DI.derivative!!(
84+
f!,
85+
y,
86+
der,
87+
::AnyAutoFastDifferentiation,
88+
x,
89+
extras::FastDifferentiationMutatingDerivativeExtras,
90+
)
91+
extras.der_exe!(der, monovec(x))
92+
return der
93+
end
94+
6995
## Jacobian
7096

7197
struct FastDifferentiationMutatingJacobianExtras{E} <: JacobianExtras
@@ -100,3 +126,15 @@ function DI.value_and_jacobian!!(
100126
extras.jac_exe!(jac, vec(x))
101127
return y, jac
102128
end
129+
130+
function DI.jacobian!!(
131+
f!,
132+
y,
133+
jac,
134+
::AnyAutoFastDifferentiation,
135+
x,
136+
extras::FastDifferentiationMutatingJacobianExtras,
137+
)
138+
extras.jac_exe!(jac, vec(x))
139+
return jac
140+
end

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/mutating.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,18 @@ function DI.value_and_derivative!!(
4747
return y, der
4848
end
4949

50+
function DI.derivative!!(
51+
f!,
52+
y::AbstractArray,
53+
der::AbstractArray,
54+
backend::AnyAutoFiniteDiff,
55+
x,
56+
::FiniteDiffMutatingDerivativeExtras,
57+
)
58+
finite_difference_gradient!(der, f!, x, fdtype(backend), eltype(y), FUNCTION_INPLACE)
59+
return der
60+
end
61+
5062
## Jacobian
5163

5264
struct FiniteDiffMutatingJacobianExtras{C}
@@ -73,3 +85,15 @@ function DI.value_and_jacobian!!(
7385
f!(y, x)
7486
return y, jac
7587
end
88+
89+
function DI.jacobian!!(
90+
f!,
91+
y::AbstractArray,
92+
jac::AbstractMatrix,
93+
::AnyAutoFiniteDiff,
94+
x,
95+
extras::FiniteDiffMutatingJacobianExtras,
96+
)
97+
finite_difference_jacobian!(jac, f!, x, extras.cache)
98+
return jac
99+
end

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/mutating.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,18 @@ function DI.value_and_derivative!!(
3535
return DiffResults.value(result), DiffResults.derivative(result)
3636
end
3737

38+
function DI.derivative!!(
39+
f!,
40+
y::AbstractArray,
41+
der::AbstractArray,
42+
::AnyAutoForwardDiff,
43+
x::Number,
44+
extras::ForwardDiffMutatingDerivativeExtras,
45+
)
46+
der = derivative!(der, f!, y, x, extras.config)
47+
return der
48+
end
49+
3850
## Jacobian
3951

4052
struct ForwardDiffMutatingJacobianExtras{C} <: JacobianExtras
@@ -61,3 +73,15 @@ function DI.value_and_jacobian!!(
6173
result = jacobian!(result, f!, y, x, extras.config)
6274
return DiffResults.value(result), DiffResults.jacobian(result)
6375
end
76+
77+
function DI.jacobian!!(
78+
f!,
79+
y::AbstractArray,
80+
jac::AbstractMatrix,
81+
::AnyAutoForwardDiff,
82+
x::AbstractArray,
83+
extras::ForwardDiffMutatingJacobianExtras,
84+
)
85+
jac = jacobian!(jac, f!, y, x, extras.config)
86+
return jac
87+
end

DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/mutating.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ function DI.value_and_derivative!!(
2222
return DI.value_and_derivative!!(f!, y, der, single_threaded(backend), x, extras)
2323
end
2424

25+
function DI.derivative!!(
26+
f!, y, der, backend::AnyAutoPolyForwardDiff, x, extras::DerivativeExtras
27+
)
28+
return DI.derivative!!(f!, y, der, single_threaded(backend), x, extras)
29+
end
30+
2531
## Jacobian
2632

2733
DI.prepare_jacobian(f!, ::AnyAutoPolyForwardDiff, y, x) = NoJacobianExtras()
@@ -34,7 +40,19 @@ function DI.value_and_jacobian!!(
3440
x::AbstractArray,
3541
::NoJacobianExtras,
3642
) where {C}
37-
f!(y, x)
3843
threaded_jacobian!(f!, y, jac, x, Chunk{C}())
44+
f!(y, x)
3945
return y, jac
4046
end
47+
48+
function DI.jacobian!!(
49+
f!,
50+
y::AbstractArray,
51+
jac::AbstractMatrix,
52+
::AnyAutoPolyForwardDiff{C},
53+
x::AbstractArray,
54+
::NoJacobianExtras,
55+
) where {C}
56+
threaded_jacobian!(f!, y, jac, x, Chunk{C}())
57+
return jac
58+
end

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/mutating.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,24 @@ function DI.value_and_pullback!!(
2323
return y, dx
2424
end
2525

26+
function DI.pullback!!(
27+
f!,
28+
y::AbstractArray,
29+
dx::AbstractArray,
30+
::AnyAutoReverseDiff,
31+
x::AbstractArray,
32+
dy::AbstractArray,
33+
::NoPullbackExtras,
34+
)
35+
function dotproduct_closure(x)
36+
y_copy = similar(y, eltype(x))
37+
f!(y_copy, x)
38+
return dot(y_copy, dy)
39+
end
40+
dx = gradient!(dx, dotproduct_closure, x)
41+
return dx
42+
end
43+
2644
### Number in, not supported
2745

2846
function DI.value_and_pullback!!(
@@ -72,3 +90,15 @@ function DI.value_and_jacobian!!(
7290
result = jacobian!(result, extras.tape, x)
7391
return DiffResults.value(result), DiffResults.derivative(result)
7492
end
93+
94+
function DI.jacobian!!(
95+
_f!,
96+
y::AbstractArray,
97+
jac::AbstractMatrix,
98+
::AnyAutoReverseDiff,
99+
x::AbstractArray,
100+
extras::ReverseDiffMutatingJacobianExtras,
101+
)
102+
jac = jacobian!(jac, extras.tape, x)
103+
return jac
104+
end

DifferentiationInterface/ext/DifferentiationInterfaceSparseDiffToolsExt/mutating.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,17 @@ for AutoSparse in SPARSE_BACKENDS
2727
f!(y, x)
2828
return y, jac
2929
end
30+
31+
function DI.jacobian!!(
32+
f!,
33+
y,
34+
jac,
35+
backend::$AutoSparse,
36+
x,
37+
extras::SparseDiffToolsMutatingJacobianExtras,
38+
)
39+
sparse_jacobian!(jac, backend, extras.cache, f!, y, x)
40+
return jac
41+
end
3042
end
3143
end

DifferentiationInterface/src/derivative.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,17 @@ function value_and_derivative!!(
9696
f!, y, der, backend, x, one(x), extras.pushforward_extras
9797
)
9898
end
99+
100+
"""
101+
derivative!!(f!, y, der, backend, x, [extras]) -> der
102+
"""
103+
function derivative!!(
104+
f!,
105+
y,
106+
der,
107+
backend::AbstractADType,
108+
x,
109+
extras::DerivativeExtras=prepare_derivative(f!, backend, y, x),
110+
)
111+
return pushforward!!(f!, y, der, backend, x, one(x), extras.pushforward_extras)
112+
end

DifferentiationInterface/src/jacobian.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,3 +203,17 @@ function value_and_jacobian_aux!!(
203203
end
204204
return y, jac
205205
end
206+
207+
"""
208+
jacobian!!(f!, y, jac, backend, x, [extras]) -> jac
209+
"""
210+
function jacobian!!(
211+
f!,
212+
y,
213+
jac,
214+
backend::AbstractADType,
215+
x,
216+
extras::JacobianExtras=prepare_jacobian(f!, backend, y, x),
217+
)
218+
return value_and_jacobian!!(f!, y, jac, backend, x, extras)[2]
219+
end

DifferentiationInterfaceTest/src/tests/benchmark.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,15 +245,19 @@ function run_benchmark!(
245245
bench1 = @be (mysimilar(y), mysimilar(y)) value_and_derivative!!(
246246
f!, _[1], _[2], ba, x, extras
247247
)
248+
bench2 = @be (mysimilar(y), mysimilar(y)) derivative!!(f!, _[1], _[2], ba, x, extras)
248249
# count
249250
cc! = CallCounter(f!)
250251
extras = prepare_derivative(cc!, ba, y, x)
251252
calls0 = reset_count!(cc!)
252253
value_and_derivative!!(cc!, mysimilar(y), mysimilar(y), ba, x, extras)
253254
calls1 = reset_count!(cc!)
255+
derivative!!(cc!, mysimilar(y), mysimilar(y), ba, x, extras)
256+
calls2 = reset_count!(cc!)
254257
# record
255258
record!(data, ba, scen, prepare_derivative, bench0, calls0)
256259
record!(data, ba, scen, value_and_derivative!!, bench1, calls1)
260+
record!(data, ba, scen, derivative!!, bench2, calls2)
257261
return nothing
258262
end
259263

@@ -322,15 +326,21 @@ function run_benchmark!(
322326
bench1 = @be (mysimilar(y), mysimilar(jac_template)) value_and_jacobian!!(
323327
f!, _[1], _[2], ba, x, extras
324328
)
329+
bench2 = @be (mysimilar(y), mysimilar(jac_template)) jacobian!!(
330+
f!, _[1], _[2], ba, x, extras
331+
)
325332
# count
326333
cc! = CallCounter(f!)
327334
extras = prepare_jacobian(cc!, ba, y, x)
328335
calls0 = reset_count!(cc!)
329336
value_and_jacobian!!(cc!, mysimilar(y), mysimilar(jac_template), ba, x, extras)
330337
calls1 = reset_count!(cc!)
338+
jacobian!!(cc!, mysimilar(y), mysimilar(jac_template), ba, x, extras)
339+
calls2 = reset_count!(cc!)
331340
# record
332341
record!(data, ba, scen, prepare_jacobian, bench0, calls0)
333342
record!(data, ba, scen, value_and_jacobian!!, bench1, calls1)
343+
record!(data, ba, scen, jacobian!!, bench2, calls2)
334344
return nothing
335345
end
336346

0 commit comments

Comments
 (0)