Skip to content

Commit 633af63

Browse files
authored
In-place second order stuff (#117)
1 parent e2611fa commit 633af63

15 files changed

Lines changed: 203 additions & 45 deletions

File tree

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ An interface to various automatic differentiation backends in Julia.
1111

1212
This package provides a backend-agnostic syntax to differentiate functions of the following types:
1313

14-
- **allocating**: `f(x) = y`
15-
- **mutating**: `f!(y, x) = nothing`
14+
- _allocating_: `f(x) = y`
15+
- _mutating_: `f!(y, x) = nothing`
1616

1717
## Features
1818

docs/src/backends.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ AutoFastDifferentiation
6767
You can use [`check_available`](@ref) to verify whether a given backend is loaded, like we did below:
6868

6969
```@example backends
70-
header = "| Backend | available |" # hide
70+
header = "| backend | available |" # hide
7171
subheader = "|---|---|" # hide
7272
rows = map(all_backends()) do backend # hide
73-
"| `$(backend_string(backend))` | $(check_available(backend) ? '' : '') |" # hide
73+
"| `$(backend_string(backend))` | $(check_available(backend) ? '✅' : '❌') |" # hide
7474
end # hide
7575
Markdown.parse(join(vcat(header, subheader, rows...), "\n")) # hide
7676
```
@@ -82,10 +82,10 @@ Only some are compatible with mutating functions `f!(y, x) = nothing`.
8282
You can use [`check_mutation`](@ref) to check that feature, like we did below:
8383

8484
```@example backends
85-
header = "| Backend | mutation |" # hide
85+
header = "| backend | mutation |" # hide
8686
subheader = "|---|---|" # hide
8787
rows = map(all_backends()) do backend # hide
88-
"| `$(backend_string(backend))` | $(check_mutation(backend) ? '' : '') |" # hide
88+
"| `$(backend_string(backend))` | $(check_mutation(backend) ? '✅' : '❌') |" # hide
8989
end # hide
9090
Markdown.parse(join(vcat(header, subheader, rows...), "\n")) # hide
9191
```

docs/src/overview.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,14 @@ Second-order differentiation is also supported, with the following operators:
5656
| [`hvp`](@ref) | `Any` | `Number` | same as `x` | `size(x)` |
5757
| [`hessian`](@ref) | `AbstractArray` | `Number` | `AbstractMatrix` | `(length(x), length(x))` |
5858

59+
We only define two variants for now:
60+
61+
| out-of-place | in-place (or not) |
62+
| --------------------------- | ----------------------------- |
63+
| [`second_derivative`](@ref) | [`second_derivative!!`](@ref) |
64+
| [`hvp`](@ref) | [`hvp!!`](@ref) |
65+
| [`hessian`](@ref) | [`hessian!!`](@ref) |
66+
5967
!!! danger
6068
This is an experimental functionality, use at your own risk.
6169

ext/DifferentiationInterfaceEnzymeExt/reverse_allocating.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,8 @@ function DI.gradient(f, ::AutoReverseEnzyme, x::AbstractArray, extras::Nothing)
5656
end
5757

5858
function DI.gradient!!(f, grad, ::AutoReverseEnzyme, x::AbstractArray, extras::Nothing)
59-
return gradient!(Reverse, grad, f, x)
59+
grad_sametype = convert(typeof(x), grad)
60+
gradient!(Reverse, grad_sametype, f, x)
61+
grad .= grad_sametype
62+
return grad
6063
end

lib/DifferentiationInterfaceTest/src/tests/benchmark.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ function run_benchmark!(
306306
)
307307
(; f, x, y, dy) = deepcopy(scen)
308308
extras = prepare_second_derivative(f, ba, x)
309-
bench1 = @be second_derivative(f, ba, x, extras)
309+
bench1 = @be mysimilar(dy) second_derivative!!(f, _, ba, x, extras)
310310
# only test allocations if the output is scalar
311311
if allocations && y isa Number
312312
@test 0 == minimum(bench1).allocs
@@ -326,7 +326,7 @@ function run_benchmark!(
326326
)
327327
(; f, x, y, dx) = deepcopy(scen)
328328
extras = prepare_hvp(f, ba, x)
329-
bench1 = @be hvp(f, ba, x, dx, extras)
329+
bench1 = @be mysimilar(dx) hvp!!(f, _, ba, x, dx, extras)
330330
# no test for now
331331
record!(data, ba, op, hvp, scen, bench1)
332332
return nothing
@@ -343,7 +343,8 @@ function run_benchmark!(
343343
)
344344
(; f, x, y) = deepcopy(scen)
345345
extras = prepare_hessian(f, ba, x)
346-
bench1 = @be hessian(f, ba, x, extras)
346+
hess_template = Matrix{typeof(y)}(undef, length(x), length(x))
347+
bench1 = @be similar(hess_template) hessian!!(f, _, ba, x, extras)
347348
# no test for now
348349
record!(data, ba, op, hessian, scen, bench1)
349350
return nothing

lib/DifferentiationInterfaceTest/src/tests/call_count.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,11 @@ end
138138
function test_call_count(
139139
ba::AbstractADType, ::typeof(second_derivative), scen::Scenario{false}
140140
)
141-
(; f, x, y) = deepcopy(scen)
141+
(; f, x, y, dy) = deepcopy(scen)
142142
extras = prepare_second_derivative(CallCounter(f), ba, x)
143143
cc = CallCounter(f)
144-
second_derivative(cc, ba, x, extras)
144+
der2_in = mysimilar(dy)
145+
second_derivative!!(cc, der2_in, ba, x, extras)
145146
# what to test?
146147
return nothing
147148
end
@@ -152,7 +153,8 @@ function test_call_count(ba::AbstractADType, ::typeof(hvp), scen::Scenario{false
152153
(; f, x, y, dx) = deepcopy(scen)
153154
extras = prepare_hvp(CallCounter(f), ba, x)
154155
cc = CallCounter(f)
155-
hvp(cc, ba, x, dx, extras)
156+
p_in = mysimilar(dx)
157+
hvp!!(cc, p_in, ba, x, dx, extras)
156158
# what to test?
157159
return nothing
158160
end
@@ -163,7 +165,8 @@ function test_call_count(ba::AbstractADType, ::typeof(hessian), scen::Scenario{f
163165
(; f, x, y, dx) = deepcopy(scen)
164166
extras = prepare_hessian(CallCounter(f), ba, x)
165167
cc = CallCounter(f)
166-
hessian(cc, ba, x, extras)
168+
hess_in = Matrix{typeof(y)}(undef, length(x), length(x))
169+
hessian!!(cc, hess_in, ba, x, extras)
167170
# what to test?
168171
return nothing
169172
end

lib/DifferentiationInterfaceTest/src/tests/correctness.jl

Lines changed: 61 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ function test_correctness(
1616
::typeof(pushforward),
1717
scen::Scenario{false};
1818
isapprox::Function,
19+
atol,
1920
rtol,
2021
)
2122
(; f, x, y, dx, dy, ref) = new_scen = deepcopy(scen)
@@ -31,7 +32,7 @@ function test_correctness(
3132
dy3 = pushforward(f, ba, x, dx)
3233
dy4 = pushforward!!(f, mysimilar(dy), ba, x, dx)
3334

34-
let (≈)(x, y) = isapprox(x, y; rtol)
35+
let (≈)(x, y) = isapprox(x, y; atol, rtol)
3536
@testset "Primal value" begin
3637
@test y1 ≈ y
3738
@test y2 ≈ y
@@ -52,6 +53,7 @@ function test_correctness(
5253
::typeof(pushforward),
5354
scen::Scenario{true};
5455
isapprox::Function,
56+
atol,
5557
rtol,
5658
)
5759
(; f, x, y, dx, dy, ref) = new_scen = deepcopy(scen)
@@ -65,7 +67,7 @@ function test_correctness(
6567
y10 = mysimilar(y)
6668
y1, dy1 = value_and_pushforward!!(f!, y10, mysimilar(dy), ba, x, dx)
6769

68-
let (≈)(x, y) = isapprox(x, y; rtol)
70+
let (≈)(x, y) = isapprox(x, y; atol, rtol)
6971
@testset "Primal value" begin
7072
@test y10 ≈ y
7173
@test y1 ≈ y
@@ -81,7 +83,12 @@ end
8183
## Pullback
8284

8385
function test_correctness(
84-
ba::AbstractADType, ::typeof(pullback), scen::Scenario{false}; isapprox::Function, rtol
86+
ba::AbstractADType,
87+
::typeof(pullback),
88+
scen::Scenario{false};
89+
isapprox::Function,
90+
atol,
91+
rtol,
8592
)
8693
(; f, x, y, dx, dy, ref) = new_scen = deepcopy(scen)
8794
dx_true = if ref isa AbstractADType
@@ -96,7 +103,7 @@ function test_correctness(
96103
dx3 = pullback(f, ba, x, dy)
97104
dx4 = pullback!!(f, mysimilar(dx), ba, x, dy)
98105

99-
let (≈)(x, y) = isapprox(x, y; rtol)
106+
let (≈)(x, y) = isapprox(x, y; atol, rtol)
100107
@testset "Primal value" begin
101108
@test y1 ≈ y
102109
@test y2 ≈ y
@@ -113,7 +120,12 @@ function test_correctness(
113120
end
114121

115122
function test_correctness(
116-
ba::AbstractADType, ::typeof(pullback), scen::Scenario{true}; isapprox::Function, rtol
123+
ba::AbstractADType,
124+
::typeof(pullback),
125+
scen::Scenario{true};
126+
isapprox::Function,
127+
atol,
128+
rtol,
117129
)
118130
(; f, x, y, dx, dy, ref) = new_scen = deepcopy(scen)
119131
f! = f
@@ -126,7 +138,7 @@ function test_correctness(
126138
y10 = mysimilar(y)
127139
y1, dx1 = value_and_pullback!!(f!, y10, mysimilar(dx), ba, x, dy)
128140

129-
let (≈)(x, y) = isapprox(x, y; rtol)
141+
let (≈)(x, y) = isapprox(x, y; atol, rtol)
130142
@testset "Primal value" begin
131143
@test y10 ≈ y
132144
@test y1 ≈ y
@@ -146,6 +158,7 @@ function test_correctness(
146158
::typeof(derivative),
147159
scen::Scenario{false};
148160
isapprox::Function,
161+
atol,
149162
rtol,
150163
)
151164
(; f, x, y, dx, dy, ref) = new_scen = deepcopy(scen)
@@ -161,7 +174,7 @@ function test_correctness(
161174
der3 = derivative(f, ba, x)
162175
der4 = derivative!!(f, mysimilar(dy), ba, x)
163176

164-
let (≈)(x, y) = isapprox(x, y; rtol)
177+
let (≈)(x, y) = isapprox(x, y; atol, rtol)
165178
@testset "Primal value" begin
166179
@test y1 ≈ y
167180
@test y2 ≈ y
@@ -178,7 +191,12 @@ function test_correctness(
178191
end
179192

180193
function test_correctness(
181-
ba::AbstractADType, ::typeof(derivative), scen::Scenario{true}; isapprox::Function, rtol
194+
ba::AbstractADType,
195+
::typeof(derivative),
196+
scen::Scenario{true};
197+
isapprox::Function,
198+
atol,
199+
rtol,
182200
)
183201
(; f, x, y, dx, dy, ref) = new_scen = deepcopy(scen)
184202
f! = f
@@ -191,7 +209,7 @@ function test_correctness(
191209
y10 = mysimilar(y)
192210
y1, der1 = value_and_derivative!!(f!, y10, mysimilar(dy), ba, x)
193211

194-
let (≈)(x, y) = isapprox(x, y; rtol)
212+
let (≈)(x, y) = isapprox(x, y; atol, rtol)
195213
@testset "Primal value" begin
196214
@test y10 ≈ y
197215
@test y1 ≈ y
@@ -207,7 +225,12 @@ end
207225
## Gradient
208226

209227
function test_correctness(
210-
ba::AbstractADType, ::typeof(gradient), scen::Scenario{false}; isapprox::Function, rtol
228+
ba::AbstractADType,
229+
::typeof(gradient),
230+
scen::Scenario{false};
231+
isapprox::Function,
232+
atol,
233+
rtol,
211234
)
212235
(; f, x, y, dx, dy, ref) = new_scen = deepcopy(scen)
213236
grad_true = if ref isa AbstractADType
@@ -222,7 +245,7 @@ function test_correctness(
222245
grad3 = gradient(f, ba, x)
223246
grad4 = gradient!!(f, mysimilar(dx), ba, x)
224247

225-
let (≈)(x, y) = isapprox(x, y; rtol)
248+
let (≈)(x, y) = isapprox(x, y; atol, rtol)
226249
@testset "Primal value" begin
227250
@test y1 ≈ y
228251
@test y2 ≈ y
@@ -241,7 +264,12 @@ end
241264
## Jacobian
242265

243266
function test_correctness(
244-
ba::AbstractADType, ::typeof(jacobian), scen::Scenario{false}; isapprox::Function, rtol
267+
ba::AbstractADType,
268+
::typeof(jacobian),
269+
scen::Scenario{false};
270+
isapprox::Function,
271+
atol,
272+
rtol,
245273
)
246274
(; f, x, y, ref) = new_scen = deepcopy(scen)
247275
jac_true = if ref isa AbstractADType
@@ -256,7 +284,7 @@ function test_correctness(
256284
jac3 = jacobian(f, ba, x)
257285
jac4 = jacobian!!(f, mysimilar(jac_true), ba, x)
258286

259-
let (≈)(x, y) = isapprox(x, y; rtol)
287+
let (≈)(x, y) = isapprox(x, y; atol, rtol)
260288
@testset "Primal value" begin
261289
@test y1 ≈ y
262290
@test y2 ≈ y
@@ -273,7 +301,12 @@ function test_correctness(
273301
end
274302

275303
function test_correctness(
276-
ba::AbstractADType, ::typeof(jacobian), scen::Scenario{true}; isapprox::Function, rtol
304+
ba::AbstractADType,
305+
::typeof(jacobian),
306+
scen::Scenario{true};
307+
isapprox::Function,
308+
atol,
309+
rtol,
277310
)
278311
(; f, x, y, dy, ref) = new_scen = deepcopy(scen)
279312
f! = f
@@ -287,7 +320,7 @@ function test_correctness(
287320
y10 = mysimilar(y)
288321
y1, jac1 = value_and_jacobian!!(f!, y10, mysimilar(jac_true), ba, x)
289322

290-
let (≈)(x, y) = isapprox(x, y; rtol)
323+
let (≈)(x, y) = isapprox(x, y; atol, rtol)
291324
@testset "Primal value" begin
292325
@test y10 ≈ y
293326
@test y1 ≈ y
@@ -307,20 +340,23 @@ function test_correctness(
307340
::typeof(second_derivative),
308341
scen::Scenario;
309342
isapprox::Function,
343+
atol,
310344
rtol,
311345
)
312-
(; f, x, ref) = new_scen = deepcopy(scen)
346+
(; f, x, dy, ref) = new_scen = deepcopy(scen)
313347
der2_true = if ref isa AbstractADType
314348
second_derivative(f, ref, x)
315349
else
316350
ref.second_derivative(x)
317351
end
318352

319353
der21 = second_derivative(f, ba, x)
354+
der22 = second_derivative!!(f, mysimilar(dy), ba, x)
320355

321-
let (≈)(x, y) = isapprox(x, y; rtol)
356+
let (≈)(x, y) = isapprox(x, y; atol, rtol)
322357
@testset "Second derivative value" begin
323358
@test der21 ≈ der2_true
359+
@test der22 ≈ der2_true
324360
end
325361
end
326362
test_scen_intact(new_scen, scen)
@@ -330,7 +366,7 @@ end
330366
## Hessian-vector product
331367

332368
function test_correctness(
333-
ba::AbstractADType, ::typeof(hvp), scen::Scenario; isapprox::Function, rtol
369+
ba::AbstractADType, ::typeof(hvp), scen::Scenario; isapprox::Function, atol, rtol
334370
)
335371
(; f, x, dx, ref) = new_scen = deepcopy(scen)
336372
hvp_true = if ref isa AbstractADType
@@ -340,10 +376,12 @@ function test_correctness(
340376
end
341377

342378
hvp1 = hvp(f, ba, x, dx)
379+
hvp2 = hvp!!(f, mysimilar(dx), ba, x, dx)
343380

344-
let (≈)(x, y) = isapprox(x, y; rtol)
381+
let (≈)(x, y) = isapprox(x, y; atol, rtol)
345382
@testset "HVP value" begin
346383
@test hvp1 ≈ hvp_true
384+
@test hvp2 ≈ hvp_true
347385
end
348386
end
349387
test_scen_intact(new_scen, scen)
@@ -353,7 +391,7 @@ end
353391
## Hessian
354392

355393
function test_correctness(
356-
ba::AbstractADType, ::typeof(hessian), scen::Scenario; isapprox::Function, rtol
394+
ba::AbstractADType, ::typeof(hessian), scen::Scenario; isapprox::Function, atol, rtol
357395
)
358396
(; f, x, y, ref) = new_scen = deepcopy(scen)
359397
hess_true = if ref isa AbstractADType
@@ -363,10 +401,12 @@ function test_correctness(
363401
end
364402

365403
hess1 = hessian(f, ba, x)
404+
hess2 = hessian!!(f, mysimilar(hess_true), ba, x)
366405

367-
let (≈)(x, y) = isapprox(x, y; rtol)
406+
let (≈)(x, y) = isapprox(x, y; atol, rtol)
368407
@testset "Hessian value" begin
369408
@test hess1 ≈ hess_true
409+
@test hess2 ≈ hess_true
370410
end
371411
end
372412
test_scen_intact(new_scen, scen)

0 commit comments

Comments
 (0)