Skip to content

Commit a86ebb8

Browse files
authored
Use dy=true for gradient (smallest possible one) (#363)
1 parent 6653db9 commit a86ebb8

5 files changed

Lines changed: 37 additions & 35 deletions

File tree

DifferentiationInterface/src/first_order/gradient.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,7 @@ struct PullbackGradientExtras{E<:PullbackExtras} <: GradientExtras
6262
end
6363

6464
function prepare_gradient(f::F, backend::AbstractADType, x) where {F}
65-
y = f(x)
66-
dy = one(y)
67-
pullback_extras = prepare_pullback(f, backend, x, dy)
65+
pullback_extras = prepare_pullback(f, backend, x, true)
6866
return PullbackGradientExtras(pullback_extras)
6967
end
7068

@@ -93,23 +91,23 @@ end
9391
function value_and_gradient(
9492
f::F, backend::AbstractADType, x, extras::PullbackGradientExtras
9593
) where {F}
96-
return value_and_pullback(f, backend, x, one(eltype(x)), extras.pullback_extras)
94+
return value_and_pullback(f, backend, x, true, extras.pullback_extras)
9795
end
9896

9997
function value_and_gradient!(
10098
f::F, grad, backend::AbstractADType, x, extras::PullbackGradientExtras
10199
) where {F}
102-
return value_and_pullback!(f, grad, backend, x, one(eltype(x)), extras.pullback_extras)
100+
return value_and_pullback!(f, grad, backend, x, true, extras.pullback_extras)
103101
end
104102

105103
function gradient(
106104
f::F, backend::AbstractADType, x, extras::PullbackGradientExtras
107105
) where {F}
108-
return pullback(f, backend, x, one(eltype(x)), extras.pullback_extras)
106+
return pullback(f, backend, x, true, extras.pullback_extras)
109107
end
110108

111109
function gradient!(
112110
f::F, grad, backend::AbstractADType, x, extras::PullbackGradientExtras
113111
) where {F}
114-
return pullback!(f, grad, backend, x, one(eltype(x)), extras.pullback_extras)
112+
return pullback!(f, grad, backend, x, true, extras.pullback_extras)
115113
end

DifferentiationInterface/src/first_order/jacobian.jl

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ function prepare_jacobian(f!::F, y, backend::AbstractADType, x) where {F}
8686
return prepare_jacobian_aux((f!, y), backend, x, y, pushforward_performance(backend))
8787
end
8888

89-
function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardFast) where {FY}
89+
function prepare_jacobian_aux(
90+
f_or_f!y::FY, backend::AbstractADType, x, y, ::PushforwardFast
91+
) where {FY}
9092
N = length(x)
9193
B = pick_batchsize(backend, N)
9294
seeds = [basis(backend, x, ind) for ind in CartesianIndices(x)]
@@ -107,7 +109,9 @@ function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardFast) wh
107109
)
108110
end
109111

110-
function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardSlow) where {FY}
112+
function prepare_jacobian_aux(
113+
f_or_f!y::FY, backend::AbstractADType, x, y, ::PushforwardSlow
114+
) where {FY}
111115
M = length(y)
112116
B = pick_batchsize(backend, M)
113117
seeds = [basis(backend, y, ind) for ind in CartesianIndices(y)]
@@ -221,7 +225,7 @@ end
221225
## Common auxiliaries
222226

223227
function jacobian_aux(
224-
f_or_f!y::FY, backend, x::AbstractArray, extras::PushforwardJacobianExtras{B}
228+
f_or_f!y::FY, backend::AbstractADType, x, extras::PushforwardJacobianExtras{B}
225229
) where {FY,B}
226230
@compat (; batched_seeds, pushforward_batched_extras, N) = extras
227231

@@ -244,7 +248,7 @@ function jacobian_aux(
244248
end
245249

246250
function jacobian_aux(
247-
f_or_f!y::FY, backend, x::AbstractArray, extras::PullbackJacobianExtras{B}
251+
f_or_f!y::FY, backend::AbstractADType, x, extras::PullbackJacobianExtras{B}
248252
) where {FY,B}
249253
@compat (; batched_seeds, pullback_batched_extras, M) = extras
250254

@@ -267,11 +271,7 @@ function jacobian_aux(
267271
end
268272

269273
function jacobian_aux!(
270-
f_or_f!y::FY,
271-
jac::AbstractMatrix,
272-
backend,
273-
x::AbstractArray,
274-
extras::PushforwardJacobianExtras{B},
274+
f_or_f!y::FY, jac, backend::AbstractADType, x, extras::PushforwardJacobianExtras{B}
275275
) where {FY,B}
276276
@compat (; batched_seeds, batched_results, pushforward_batched_extras, N) = extras
277277

@@ -303,11 +303,7 @@ function jacobian_aux!(
303303
end
304304

305305
function jacobian_aux!(
306-
f_or_f!y::FY,
307-
jac::AbstractMatrix,
308-
backend,
309-
x::AbstractArray,
310-
extras::PullbackJacobianExtras{B},
306+
f_or_f!y::FY, jac, backend::AbstractADType, x, extras::PullbackJacobianExtras{B}
311307
) where {FY,B}
312308
@compat (; batched_seeds, batched_results, pullback_batched_extras, M) = extras
313309

DifferentiationInterface/src/first_order/pullback.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,23 +110,27 @@ function prepare_pullback(f!::F, y, backend::AbstractADType, x, dy) where {F}
110110
return prepare_pullback_aux(f!, y, backend, x, dy, pullback_performance(backend))
111111
end
112112

113-
function prepare_pullback_aux(f::F, backend, x, dy, ::PullbackSlow) where {F}
113+
function prepare_pullback_aux(
114+
f::F, backend::AbstractADType, x, dy, ::PullbackSlow
115+
) where {F}
114116
dx = x isa Number ? one(x) : basis(backend, x, first(CartesianIndices(x)))
115117
pushforward_extras = prepare_pushforward(f, backend, x, dx)
116118
return PushforwardPullbackExtras(pushforward_extras)
117119
end
118120

119-
function prepare_pullback_aux(f!::F, y, backend, x, dy, ::PullbackSlow) where {F}
121+
function prepare_pullback_aux(
122+
f!::F, y, backend::AbstractADType, x, dy, ::PullbackSlow
123+
) where {F}
120124
dx = x isa Number ? one(x) : basis(backend, x, first(CartesianIndices(x)))
121125
pushforward_extras = prepare_pushforward(f!, y, backend, x, dx)
122126
return PushforwardPullbackExtras(pushforward_extras)
123127
end
124128

125-
function prepare_pullback_aux(f, backend, x, dy, ::PullbackFast)
129+
function prepare_pullback_aux(f, backend::AbstractADType, x, dy, ::PullbackFast)
126130
throw(MissingBackendError(backend))
127131
end
128132

129-
function prepare_pullback_aux(f!, y, backend, x, dy, ::PullbackFast)
133+
function prepare_pullback_aux(f!, y, backend::AbstractADType, x, dy, ::PullbackFast)
130134
throw(MissingBackendError(backend))
131135
end
132136

@@ -177,7 +181,7 @@ end
177181
### With extras
178182

179183
function value_and_pullback(
180-
f::F, backend, x, dy, extras::PushforwardPullbackExtras
184+
f::F, backend::AbstractADType, x, dy, extras::PushforwardPullbackExtras
181185
) where {F}
182186
@compat (; pushforward_extras) = extras
183187
y = f(x)
@@ -241,7 +245,7 @@ end
241245
### With extras
242246

243247
function value_and_pullback(
244-
f!::F, y, backend, x, dy, extras::PushforwardPullbackExtras
248+
f!::F, y, backend::AbstractADType, x, dy, extras::PushforwardPullbackExtras
245249
) where {F}
246250
@compat (; pushforward_extras) = extras
247251
dx = if x isa Number && y isa AbstractArray

DifferentiationInterface/src/first_order/pushforward.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,24 +110,28 @@ function prepare_pushforward(f!::F, y, backend::AbstractADType, x, dx) where {F}
110110
return prepare_pushforward_aux(f!, y, backend, x, dx, pushforward_performance(backend))
111111
end
112112

113-
function prepare_pushforward_aux(f::F, backend, x, dx, ::PushforwardSlow) where {F}
113+
function prepare_pushforward_aux(
114+
f::F, backend::AbstractADType, x, dx, ::PushforwardSlow
115+
) where {F}
114116
y = f(x)
115117
dy = y isa Number ? one(y) : basis(backend, y, first(CartesianIndices(y)))
116118
pullback_extras = prepare_pullback(f, backend, x, dy)
117119
return PullbackPushforwardExtras(pullback_extras)
118120
end
119121

120-
function prepare_pushforward_aux(f!::F, y, backend, x, dx, ::PushforwardSlow) where {F}
122+
function prepare_pushforward_aux(
123+
f!::F, y, backend::AbstractADType, x, dx, ::PushforwardSlow
124+
) where {F}
121125
dy = y isa Number ? one(y) : basis(backend, y, first(CartesianIndices(y)))
122126
pullback_extras = prepare_pullback(f!, y, backend, x, dy)
123127
return PullbackPushforwardExtras(pullback_extras)
124128
end
125129

126-
function prepare_pushforward_aux(f, backend, x, dy, ::PushforwardFast)
130+
function prepare_pushforward_aux(f, backend::AbstractADType, x, dx, ::PushforwardFast)
127131
throw(MissingBackendError(backend))
128132
end
129133

130-
function prepare_pushforward_aux(f!, y, backend, x, dy, ::PushforwardFast)
134+
function prepare_pushforward_aux(f!, y, backend::AbstractADType, x, dx, ::PushforwardFast)
131135
throw(MissingBackendError(backend))
132136
end
133137

@@ -180,7 +184,7 @@ end
180184
### With extras
181185

182186
function value_and_pushforward(
183-
f::F, backend, x, dx, extras::PullbackPushforwardExtras
187+
f::F, backend::AbstractADType, x, dx, extras::PullbackPushforwardExtras
184188
) where {F}
185189
@compat (; pullback_extras) = extras
186190
y = f(x)
@@ -248,7 +252,7 @@ end
248252
### With extras
249253

250254
function value_and_pushforward(
251-
f!::F, y, backend, x, dx, extras::PullbackPushforwardExtras
255+
f!::F, y, backend::AbstractADType, x, dx, extras::PullbackPushforwardExtras
252256
) where {F}
253257
@compat (; pullback_extras) = extras
254258
dy = if x isa Number && y isa AbstractArray

DifferentiationInterface/src/sparse/jacobian.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ function prepare_jacobian(f!::F, y, backend::AutoSparse, x) where {F}
8484
end
8585

8686
function prepare_sparse_jacobian_aux(
87-
f_or_f!y::FY, backend, x, y, ::PushforwardFast
87+
f_or_f!y::FY, backend::AutoSparse, x, y, ::PushforwardFast
8888
) where {FY}
8989
dense_backend = dense_ad(backend)
9090
initial_sparsity = jacobian_sparsity(f_or_f!y..., x, sparsity_detector(backend))
@@ -116,7 +116,7 @@ function prepare_sparse_jacobian_aux(
116116
end
117117

118118
function prepare_sparse_jacobian_aux(
119-
f_or_f!y::FY, backend, x, y, ::PushforwardSlow
119+
f_or_f!y::FY, backend::AutoSparse, x, y, ::PushforwardSlow
120120
) where {FY}
121121
dense_backend = dense_ad(backend)
122122
initial_sparsity = jacobian_sparsity(f_or_f!y..., x, sparsity_detector(backend))

0 commit comments

Comments
 (0)