Skip to content

Commit 66f7ff3

Browse files
authored
feat: allow and test holomorphic derivatives (#687)
* feat: allow and test holomorphic derivatives * API doc * Add tests * Skip matrix tests for FastDifferentiation * Fix * Proper use of dot * Dep
1 parent 56f2a1c commit 66f7ff3

20 files changed

Lines changed: 177 additions & 78 deletions

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@
1010

1111
*.csv
1212

13-
playground.jl
13+
playground.jl
14+
.vscode

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.30"
4+
version = "0.6.31"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/docs/src/faq/limitations.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,7 @@ As a result, DifferentiationInterface only supports a single active argument, ca
77

88
## Complex numbers
99

10-
At the moment, complex numbers are only handled by a few AD backends, sometimes using different conventions.
11-
As a result, DifferentiationInterface is only tested on real numbers and complex number support is not part of its API guarantees.
10+
Complex derivatives are only handled by a few AD backends, sometimes using different conventions.
11+
To find the easiest common ground, DifferentiationInterface assumes that whenever complex numbers are involved, the function to differentiate is holomorphic.
12+
This functionality is still considered experimental and not yet part of the public API guarantees.
13+
If you work with non-holomorphic functions, you will need to manually separate real and imaginary parts.

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,9 @@ function _sparse_jacobian_aux!(
319319
)
320320

321321
for b in eachindex(batched_results[a])
322+
if eltype(x) <: Complex
323+
batched_results[a][b] .= conj.(batched_results[a][b])
324+
end
322325
copyto!(
323326
view(compressed_matrix, 1 + ((a - 1) * B + (b - 1)) % N, :),
324327
vec(batched_results[a][b]),

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,9 @@ function _sparse_jacobian_aux!(
220220
)
221221

222222
for b in eachindex(batched_results_reverse[a])
223+
if eltype(x) <: Complex
224+
batched_results_reverse[a][b] .= conj.(batched_results_reverse[a][b])
225+
end
223226
copyto!(
224227
view(compressed_matrix_reverse, 1 + ((a - 1) * Br + (b - 1)) % Nr, :),
225228
vec(batched_results_reverse[a][b]),

DifferentiationInterface/src/first_order/derivative.jl

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,14 @@ end
7474
function prepare_derivative(
7575
f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}
7676
) where {F,C}
77-
pushforward_prep = prepare_pushforward(f, backend, x, (realone(x),), contexts...)
77+
pushforward_prep = prepare_pushforward(f, backend, x, (one(x),), contexts...)
7878
return PushforwardDerivativePrep(pushforward_prep)
7979
end
8080

8181
function prepare_derivative(
8282
f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}
8383
) where {F,C}
84-
pushforward_prep = prepare_pushforward(f!, y, backend, x, (realone(x),), contexts...)
84+
pushforward_prep = prepare_pushforward(f!, y, backend, x, (one(x),), contexts...)
8585
return PushforwardDerivativePrep(pushforward_prep)
8686
end
8787

@@ -95,7 +95,7 @@ function value_and_derivative(
9595
contexts::Vararg{Context,C},
9696
) where {F,C}
9797
y, ty = value_and_pushforward(
98-
f, prep.pushforward_prep, backend, x, (realone(x),), contexts...
98+
f, prep.pushforward_prep, backend, x, (one(x),), contexts...
9999
)
100100
return y, only(ty)
101101
end
@@ -109,7 +109,7 @@ function value_and_derivative!(
109109
contexts::Vararg{Context,C},
110110
) where {F,C}
111111
y, _ = value_and_pushforward!(
112-
f, (der,), prep.pushforward_prep, backend, x, (realone(x),), contexts...
112+
f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...
113113
)
114114
return y, der
115115
end
@@ -121,7 +121,7 @@ function derivative(
121121
x,
122122
contexts::Vararg{Context,C},
123123
) where {F,C}
124-
ty = pushforward(f, prep.pushforward_prep, backend, x, (realone(x),), contexts...)
124+
ty = pushforward(f, prep.pushforward_prep, backend, x, (one(x),), contexts...)
125125
return only(ty)
126126
end
127127

@@ -133,7 +133,7 @@ function derivative!(
133133
x,
134134
contexts::Vararg{Context,C},
135135
) where {F,C}
136-
pushforward!(f, (der,), prep.pushforward_prep, backend, x, (realone(x),), contexts...)
136+
pushforward!(f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...)
137137
return der
138138
end
139139

@@ -148,7 +148,7 @@ function value_and_derivative(
148148
contexts::Vararg{Context,C},
149149
) where {F,C}
150150
y, ty = value_and_pushforward(
151-
f!, y, prep.pushforward_prep, backend, x, (realone(x),), contexts...
151+
f!, y, prep.pushforward_prep, backend, x, (one(x),), contexts...
152152
)
153153
return y, only(ty)
154154
end
@@ -163,7 +163,7 @@ function value_and_derivative!(
163163
contexts::Vararg{Context,C},
164164
) where {F,C}
165165
y, _ = value_and_pushforward!(
166-
f!, y, (der,), prep.pushforward_prep, backend, x, (realone(x),), contexts...
166+
f!, y, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...
167167
)
168168
return y, der
169169
end
@@ -176,7 +176,7 @@ function derivative(
176176
x,
177177
contexts::Vararg{Context,C},
178178
) where {F,C}
179-
ty = pushforward(f!, y, prep.pushforward_prep, backend, x, (realone(x),), contexts...)
179+
ty = pushforward(f!, y, prep.pushforward_prep, backend, x, (one(x),), contexts...)
180180
return only(ty)
181181
end
182182

@@ -189,9 +189,7 @@ function derivative!(
189189
x,
190190
contexts::Vararg{Context,C},
191191
) where {F,C}
192-
pushforward!(
193-
f!, y, (der,), prep.pushforward_prep, backend, x, (realone(x),), contexts...
194-
)
192+
pushforward!(f!, y, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...)
195193
return der
196194
end
197195

DifferentiationInterface/src/first_order/jacobian.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,9 @@ function _jacobian_aux(
319319
dx_batch = pullback(
320320
f_or_f!y..., pullback_prep, backend, x, only(batched_seeds), contexts...
321321
)
322+
if eltype(x) <: Complex
323+
dx_batch = map(conj, dx_batch)
324+
end
322325
block = stack_vec_row(dx_batch)
323326
if aligned
324327
return block
@@ -345,6 +348,9 @@ function _jacobian_aux(
345348
dx_batch = pullback(
346349
f_or_f!y..., pullback_prep_same, backend, x, batched_seeds[a], contexts...
347350
)
351+
if eltype(x) <: Complex
352+
dx_batch = map(conj, dx_batch)
353+
end
348354
block = stack_vec_row(dx_batch)
349355
if !aligned && a == A
350356
return block[1:B_last, :]
@@ -403,7 +409,7 @@ function _jacobian_aux!(
403409
(; N) = batch_size_settings
404410

405411
pullback_prep_same = prepare_pullback_same_point(
406-
f_or_f!y..., prep.pullback_prep, backend, x, batched_seeds[1], contexts...
412+
f_or_f!y..., pullback_prep, backend, x, batched_seeds[1], contexts...
407413
)
408414

409415
for a in eachindex(batched_seeds, batched_results)
@@ -418,6 +424,9 @@ function _jacobian_aux!(
418424
)
419425

420426
for b in eachindex(batched_results[a])
427+
if eltype(x) <: Complex
428+
batched_results[a][b] .= conj.(batched_results[a][b])
429+
end
421430
copyto!(
422431
view(jac, 1 + ((a - 1) * B + (b - 1)) % N, :), vec(batched_results[a][b])
423432
)

DifferentiationInterface/src/first_order/pullback.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ function _pullback_via_pushforward(
156156
contexts::Vararg{Context,C},
157157
) where {F,C}
158158
t1 = pushforward(f, pushforward_prep, backend, x, (one(x),), contexts...)
159-
dx = dot(dy, only(t1))
159+
dx = dot(only(t1), dy)
160160
return dx
161161
end
162162

@@ -170,7 +170,7 @@ function _pullback_via_pushforward(
170170
) where {F,C}
171171
dx = map(CartesianIndices(x)) do j
172172
t1 = pushforward(f, pushforward_prep, backend, x, (basis(backend, x, j),), contexts...)
173-
dot(dy, only(t1))
173+
dot(only(t1), dy)
174174
end
175175
return dx
176176
end
@@ -241,7 +241,7 @@ function _pullback_via_pushforward(
241241
contexts::Vararg{Context,C},
242242
) where {F,C}
243243
t1 = pushforward(f!, y, pushforward_prep, backend, x, (one(x),), contexts...)
244-
dx = dot(dy, only(t1))
244+
dx = dot(only(t1), dy)
245245
return dx
246246
end
247247

@@ -258,7 +258,7 @@ function _pullback_via_pushforward(
258258
t1 = pushforward(
259259
f!, y, pushforward_prep, backend, x, (basis(backend, x, j),), contexts...
260260
)
261-
dot(dy, only(t1))
261+
dot(only(t1), dy)
262262
end
263263
return dx
264264
end

DifferentiationInterface/src/first_order/pushforward.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ function _pushforward_via_pullback(
158158
contexts::Vararg{Context,C},
159159
) where {F,C}
160160
t1 = pullback(f, pullback_prep, backend, x, (one(y),), contexts...)
161-
dy = dot(dx, only(t1))
161+
dy = dot(only(t1), dx)
162162
return dy
163163
end
164164

@@ -173,7 +173,7 @@ function _pushforward_via_pullback(
173173
) where {F,C}
174174
dy = map(CartesianIndices(y)) do i
175175
t1 = pullback(f, pullback_prep, backend, x, (basis(backend, y, i),), contexts...)
176-
dot(dx, only(t1))
176+
dot(only(t1), dx)
177177
end
178178
return dy
179179
end
@@ -245,7 +245,7 @@ function _pushforward_via_pullback(
245245
) where {F,C}
246246
dy = map(CartesianIndices(y)) do i # preserve shape
247247
t1 = pullback(f!, y, pullback_prep, backend, x, (basis(backend, y, i),), contexts...)
248-
dot(dx, only(t1))
248+
dot(only(t1), dx)
249249
end
250250
return dy
251251
end

DifferentiationInterface/src/misc/from_primitive.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
abstract type FromPrimitive <: AbstractADType end
22

3-
function basis(fromprim::FromPrimitive, x::AbstractArray{<:Real}, i)
3+
function basis(fromprim::FromPrimitive, x::AbstractArray, i)
44
return basis(fromprim.backend, x, i)
55
end
66

7-
function multibasis(fromprim::FromPrimitive, x::AbstractArray{<:Real}, inds)
7+
function multibasis(fromprim::FromPrimitive, x::AbstractArray, inds)
88
return multibasis(fromprim.backend, x, inds)
99
end
1010

1111
check_available(fromprim::FromPrimitive) = check_available(fromprim.backend)
1212
inplace_support(fromprim::FromPrimitive) = inplace_support(fromprim.backend)
1313

14-
function BatchSizeSettings(fromprim::FromPrimitive, x::AbstractArray{<:Real})
14+
function BatchSizeSettings(fromprim::FromPrimitive, x::AbstractArray)
1515
return BatchSizeSettings(fromprim.backend, x)
1616
end
1717

0 commit comments

Comments
 (0)