Skip to content

Commit 02ae6b5

Browse files
committed
fix: handle constant derivatives with runtime activity for Enzyme
1 parent a5fb081 commit 02ae6b5

8 files changed

Lines changed: 139 additions & 8 deletions

File tree

DifferentiationInterface/CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.7...main)
99

10+
### Fixed
11+
12+
- Handle constant derivatives with runtime activity for Enzyme
13+
1014
## [0.7.7](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.6...DifferentiationInterface-v0.7.7)
1115

16+
### Fixed
17+
1218
- Improve support for empty inputs (still not guaranteed) ([#835](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/835))
1319

1420
## [0.7.6](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.5...DifferentiationInterface-v0.7.6)

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module DifferentiationInterfaceEnzymeExt
22

33
using ADTypes: ADTypes, AutoEnzyme
4-
using Base: Fix1
4+
using Base: Fix1, datatype_pointerfree
55
import DifferentiationInterface as DI
66
using EnzymeCore:
77
Active,
@@ -42,7 +42,8 @@ using Enzyme:
4242
jacobian,
4343
make_zero,
4444
make_zero!,
45-
onehot
45+
onehot,
46+
runtime_activity
4647

4748
DI.check_available(::AutoEnzyme) = true
4849

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ function DI.value_and_pushforward(
3737
x_and_dx = Duplicated(x, dx)
3838
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1))
3939
dy, y = autodiff(mode, f_and_df, x_and_dx, annotated_contexts...)
40+
dy = runtime_activity_safeguard(backend, y, dy)
4041
return y, (dy,)
4142
end
4243

@@ -54,8 +55,10 @@ function DI.value_and_pushforward(
5455
f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B))
5556
x_and_tx = BatchDuplicated(x, tx)
5657
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B))
57-
ty, y = autodiff(mode, f_and_df, x_and_tx, annotated_contexts...)
58-
return y, values(ty)
58+
ty_nt, y = autodiff(mode, f_and_df, x_and_tx, annotated_contexts...)
59+
ty = values(ty_nt)
60+
ty = runtime_activity_safeguard(backend, y, ty)
61+
return y, ty
5962
end
6063

6164
function DI.pushforward(
@@ -66,6 +69,9 @@ function DI.pushforward(
6669
tx::NTuple{1},
6770
contexts::Vararg{DI.Context,C},
6871
) where {F,C}
72+
if has_runtime_activity(backend)
73+
return DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)[2]
74+
end
6975
DI.check_prep(f, prep, backend, x, tx, contexts...)
7076
(; df, context_shadows) = prep
7177
mode = forward_noprimal(backend)
@@ -85,14 +91,18 @@ function DI.pushforward(
8591
tx::NTuple{B},
8692
contexts::Vararg{DI.Context,C},
8793
) where {F,B,C}
94+
if has_runtime_activity(backend)
95+
return DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)[2]
96+
end
8897
DI.check_prep(f, prep, backend, x, tx, contexts...)
8998
(; df, context_shadows) = prep
9099
mode = forward_noprimal(backend)
91100
f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B))
92101
x_and_tx = BatchDuplicated(x, tx)
93102
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B))
94-
ty = only(autodiff(mode, f_and_df, x_and_tx, annotated_contexts...))
95-
return values(ty)
103+
ty_nt = only(autodiff(mode, f_and_df, x_and_tx, annotated_contexts...))
104+
ty = values(ty_nt)
105+
return ty
96106
end
97107

98108
function DI.value_and_pushforward!(
@@ -168,7 +178,9 @@ function DI.gradient(
168178
derivs = gradient(
169179
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows
170180
)
171-
return first(derivs)
181+
deriv = first(derivs)
182+
deriv = runtime_activity_safeguard(backend, x, deriv)
183+
return deriv
172184
end
173185

174186
function DI.value_and_gradient(
@@ -186,7 +198,9 @@ function DI.value_and_gradient(
186198
(; derivs, val) = gradient(
187199
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows
188200
)
189-
return val, first(derivs)
201+
deriv = first(derivs)
202+
deriv = runtime_activity_safeguard(backend, x, deriv)
203+
return val, deriv
190204
end
191205

192206
function DI.gradient!(

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ function seeded_autodiff_thunk(
77
) where {ReturnPrimal,FA<:Annotation,RA<:Annotation,N}
88
forward, reverse = autodiff_thunk(rmode, FA, RA, typeof.(args)...)
99
tape, result, shadow_result = forward(f, args...)
10+
shadow_result = runtime_activity_safeguard(rmode, result, shadow_result)
1011
if RA <: Active
1112
dinputs = only(reverse(f, args..., dresult, tape))
1213
else
@@ -30,6 +31,7 @@ function batch_seeded_autodiff_thunk(
3031
rmode_rightwidth = ReverseSplitWidth(rmode, Val(B))
3132
forward, reverse = autodiff_thunk(rmode_rightwidth, FA, RA, typeof.(args)...)
3233
tape, result, shadow_results = forward(f, args...)
34+
shadow_results = runtime_activity_safeguard(rmode_rightwidth, result, shadow_results)
3335
if RA <: Active
3436
dinputs = only(reverse(f, args..., dresults, tape))
3537
else

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,39 @@ end
193193

194194
batchify_activity(::Type{Active{T}}, ::Val{B}) where {T,B} = Active{T}
195195
batchify_activity(::Type{Duplicated{T}}, ::Val{B}) where {T,B} = BatchDuplicated{T,B}
196+
197+
has_runtime_activity(mode::Mode) = runtime_activity(mode)
198+
has_runtime_activity(::AutoEnzyme{Nothing}) = false
199+
has_runtime_activity(backend::AutoEnzyme{<:Mode}) = has_runtime_activity(backend.mode)
200+
201+
function runtime_activity_safeguard(
202+
backend_or_mode::Union{<:AutoEnzyme,<:Mode}, primal::T, shadow::T
203+
) where {T}
204+
# TODO: improve datatype_pointerfree to take Ptr into account
205+
if has_runtime_activity(backend_or_mode) &&
206+
!datatype_pointerfree(T) &&
207+
pointer(primal) === pointer(shadow) # TODO: doesn't work beyond arrays
208+
return make_zero(shadow)
209+
else
210+
return shadow
211+
end
212+
end
213+
214+
function runtime_activity_safeguard(
215+
backend_or_mode::Union{<:AutoEnzyme,<:Mode},
216+
primal::T,
217+
shadow::Union{NTuple{N,T},NamedTuple},
218+
) where {T,N}
219+
# TODO: improve datatype_pointerfree to take Ptr into account
220+
if has_runtime_activity(backend_or_mode) &&
221+
!datatype_pointerfree(T) &&
222+
pointer(primal) === pointer(shadow[1]) # TODO: doesn't work beyond arrays
223+
return make_zero(shadow)
224+
else
225+
return shadow
226+
end
227+
end
228+
229+
function runtime_activity_safeguard(::Union{<:AutoEnzyme,<:Mode}, primal, shadow::Nothing)
230+
return nothing
231+
end

DifferentiationInterface/test/Back/Enzyme/test.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,3 +196,19 @@ end
196196
excluded=[:jacobian],
197197
)
198198
end;
199+
200+
@testset "Runtime activity" begin
201+
# TODO: higher-level operators not tested
202+
test_differentiation(
203+
AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Forward)),
204+
DIT.unknown_activity(default_scenarios());
205+
excluded=vcat(SECOND_ORDER, :jacobian, :gradient, :derivative, :pullback),
206+
logging=LOGGING,
207+
)
208+
test_differentiation(
209+
AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse)),
210+
DIT.unknown_activity(default_scenarios());
211+
excluded=vcat(SECOND_ORDER, :jacobian, :gradient, :derivative, :pushforward),
212+
logging=LOGGING,
213+
)
214+
end

DifferentiationInterfaceTest/src/scenarios/modify.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,54 @@ function closurify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
163163
)
164164
end
165165

166+
struct UnknownActivityReturn{pl_fun,F}
167+
f::F
168+
end
169+
170+
function Base.show(io::IO, f::UnknownActivityReturn)
171+
return print(io, "UnknownActivityReturn($(f.f))")
172+
end
173+
174+
function (f::UnknownActivityReturn{:out})(x, yc, return_constant::Bool)
175+
if return_constant
176+
return copy(yc)
177+
else
178+
return f.f(x)
179+
end
180+
end
181+
182+
function (f::UnknownActivityReturn{:in})(y, x, yc, return_constant::Bool)
183+
if return_constant
184+
copyto!(y, copy(yc))
185+
else
186+
f.f(y, x)
187+
end
188+
return nothing
189+
end
190+
191+
"""
192+
unknown_activity(scen::Scenario)
193+
194+
Return a new scenario identical to `scen` except that the function now takes an additional constant argument which is the theoretical output, and a constant boolean condition stating whether or not that output should be recomputed.
195+
"""
196+
function unknown_activity(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
197+
(; f) = deepcopy(scen)
198+
zero_scen = deepcopy(zero(scen))
199+
@assert isempty(scen.contexts)
200+
unknown_f = UnknownActivityReturn{pl_fun,typeof(f)}(f)
201+
return Scenario{op,pl_op,pl_fun}(;
202+
f=unknown_f,
203+
x=scen.x,
204+
y=scen.y,
205+
t=scen.t,
206+
contexts=(Constant(scen.y), Constant(true)),
207+
res1=zero_scen.res1,
208+
res2=zero_scen.res2,
209+
prep_args=(; scen.prep_args..., contexts=(Constant(scen.y), Constant(true))),
210+
name=isnothing(scen.name) ? nothing : scen.name * " [unknown activity]",
211+
)
212+
end
213+
166214
struct MultiplyByConstant{pl_fun,F} <: FunctionModifier
167215
f::F
168216
end
@@ -366,6 +414,7 @@ closurify(scens::AbstractVector{<:Scenario}) = closurify.(scens)
366414
constantify(scens::AbstractVector{<:Scenario}) = constantify.(scens)
367415
cachify(scens::AbstractVector{<:Scenario}; use_tuples) = cachify.(scens; use_tuples)
368416
constantorcachify(scens::AbstractVector{<:Scenario}) = constantorcachify.(scens)
417+
unknown_activity(scens::AbstractVector{<:Scenario}) = unknown_activity.(scens)
369418

370419
## Compute results with backend
371420

DifferentiationInterfaceTest/test/weird.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,13 @@ test_differentiation(
6565
logging=LOGGING,
6666
);
6767

68+
test_differentiation(
69+
AutoFiniteDiff(),
70+
unknown_activity(default_scenarios);
71+
excluded=SECOND_ORDER,
72+
logging=LOGGING,
73+
);
74+
6875
## Neural nets
6976

7077
test_differentiation(

0 commit comments

Comments
 (0)