Skip to content

Commit f4ba734

Browse files
First test scenarios for Flux gradients (#352)
* Flux-Zygote : creating functions for test scenario * Weakdeps * Shorter workflows * Exclude * Exclude more * Working test on Zygote * Fix * Fix Flux tests * Enzyme working * Fixes * No random * Reactivate CI * Printing * Runtime activity * Make mergeable * Add FiniteDifferences * Imports * Group weird tests * define DIT * Higher tolerance in test --------- Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
1 parent a86ebb8 commit f4ba734

18 files changed

Lines changed: 402 additions & 92 deletions

File tree

.github/workflows/Test.yml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ jobs:
4848
- Back/Zygote
4949
- Down/Detector
5050
- Down/DifferentiateWith
51+
- Down/Flux
5152
exclude:
5253
# lts
5354
- version: 'lts'
@@ -68,8 +69,10 @@ jobs:
6869
group: Back/Symbolics
6970
- version: 'lts'
7071
group: Back/Tapir
71-
- version: 'pre'
72+
- version: 'lts'
7273
group: Down/Detector
74+
- version: 'lts'
75+
group: Down/Flux
7376
# pre-release
7477
- version: 'pre'
7578
group: Formalities
@@ -81,6 +84,8 @@ jobs:
8184
group: Back/SecondOrder
8285
- version: 'pre'
8386
group: Down/Detector
87+
- version: 'pre'
88+
group: Down/Flux
8489

8590
steps:
8691
- uses: actions/checkout@v4
@@ -124,10 +129,12 @@ jobs:
124129
- Formalities
125130
- Zero
126131
- ForwardDiff
127-
- Zygote
132+
- Weird
128133
exclude:
129134
- version: 'lts'
130135
group: Formalities
136+
- version: 'lts'
137+
group: Weird
131138

132139
steps:
133140
- uses: actions/checkout@v4

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ using Enzyme:
3535
gradient,
3636
gradient!,
3737
jacobian,
38-
make_zero
38+
make_zero,
39+
make_zero!
3940

4041
struct AutoDeferredEnzyme{M} <: ADTypes.AbstractADType
4142
mode::M

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function DI.value_and_pullback(
2626
f,
2727
backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}},
2828
x::Number,
29-
dy::AbstractArray,
29+
dy,
3030
::NoPullbackExtras,
3131
)
3232
tf, tx = typeof(f), typeof(x)
@@ -40,11 +40,28 @@ end
4040
function DI.value_and_pullback(
4141
f,
4242
backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}},
43-
x::AbstractArray,
44-
dy,
45-
extras::NoPullbackExtras,
43+
x,
44+
dy::Number,
45+
::NoPullbackExtras,
4646
)
47-
dx = similar(x)
47+
dx_sametype = make_zero(x)
48+
x_and_dx = Duplicated(x, dx_sametype)
49+
_, y = if backend isa AutoDeferredEnzyme
50+
autodiff_deferred(ReverseWithPrimal, Const(f), Active, x_and_dx)
51+
else
52+
autodiff(ReverseWithPrimal, Const(f), Active, x_and_dx)
53+
end
54+
if !isone(dy)
55+
# TODO: generalize beyond Arrays?
56+
dx_sametype .*= dy
57+
end
58+
return y, dx_sametype
59+
end
60+
61+
function DI.value_and_pullback(
62+
f, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, dy, extras::NoPullbackExtras
63+
)
64+
dx = make_zero(x)
4865
return DI.value_and_pullback!(f, dx, backend, x, dy, extras)
4966
end
5067

@@ -60,36 +77,34 @@ function DI.value_and_pullback!(
6077
f,
6178
dx,
6279
backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}},
63-
x::AbstractArray,
80+
x,
6481
dy::Number,
6582
::NoPullbackExtras,
6683
)
6784
dx_sametype = convert(typeof(x), dx)
68-
dx_sametype .= zero(eltype(x))
85+
make_zero!(dx_sametype)
6986
x_and_dx = Duplicated(x, dx_sametype)
7087
_, y = if backend isa AutoDeferredEnzyme
7188
autodiff_deferred(ReverseWithPrimal, Const(f), Active, x_and_dx)
7289
else
7390
autodiff(ReverseWithPrimal, Const(f), Active, x_and_dx)
7491
end
75-
dx_sametype .*= dy
92+
if !isone(dy)
93+
# TODO: generalize beyond Arrays?
94+
dx_sametype .*= dy
95+
end
7696
return y, copyto!(dx, dx_sametype)
7797
end
7898

7999
function DI.value_and_pullback!(
80-
f,
81-
dx,
82-
backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}},
83-
x::AbstractArray,
84-
dy::AbstractArray,
85-
::NoPullbackExtras,
100+
f, dx, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, dy, ::NoPullbackExtras
86101
)
87102
tf, tx = typeof(f), typeof(x)
88103
forw, rev = autodiff_thunk(
89104
ReverseSplitWithPrimal, Const{tf}, Duplicated, Duplicated{tx}
90105
)
91106
dx_sametype = convert(typeof(x), dx)
92-
dx_sametype .= zero(eltype(x))
107+
make_zero!(dx_sametype)
93108
tape, y, new_dy = forw(Const(f), Duplicated(x, dx_sametype))
94109
copyto!(new_dy, dy)
95110
rev(Const(f), Duplicated(x, dx_sametype), tape)
@@ -133,7 +148,7 @@ function DI.gradient!(
133148
extras::NoGradientExtras,
134149
)
135150
grad_sametype = convert(typeof(x), grad)
136-
grad_sametype .= zero(eltype(x))
151+
make_zero!(grad_sametype)
137152
if backend isa AutoDeferredEnzyme
138153
autodiff_deferred(reverse_mode(backend), f, Active, Duplicated(x, grad_sametype))
139154
else
@@ -145,13 +160,13 @@ end
145160
function DI.value_and_gradient(
146161
f, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, ::NoGradientExtras
147162
)
148-
return DI.value_and_pullback(f, backend, x, one(eltype(x)), NoPullbackExtras())
163+
return DI.value_and_pullback(f, backend, x, true, NoPullbackExtras())
149164
end
150165

151166
function DI.value_and_gradient!(
152167
f, grad, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, ::NoGradientExtras
153168
)
154-
return DI.value_and_pullback!(f, grad, backend, x, one(eltype(x)), NoPullbackExtras())
169+
return DI.value_and_pullback!(f, grad, backend, x, true, NoPullbackExtras())
155170
end
156171

157172
## Jacobian
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[deps]
2+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
3+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
4+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
5+
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
6+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
7+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using DifferentiationInterface, DifferentiationInterfaceTest
2+
import DifferentiationInterfaceTest as DIT
3+
using FiniteDifferences: FiniteDifferences
4+
using Flux: Flux
5+
using Enzyme: Enzyme
6+
using Zygote: Zygote
7+
using Test
8+
9+
Enzyme.API.runtimeActivity!(true)
10+
11+
test_differentiation(
12+
[
13+
AutoZygote(),
14+
# AutoEnzyme() # TODO: fix
15+
],
16+
flux_scenarios();
17+
isequal=DIT.flux_isequal,
18+
isapprox=DIT.flux_isapprox,
19+
rtol=1e-2,
20+
atol=1e-6,
21+
)

DifferentiationInterface/test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ function subtest(category, folder)
2424
@testset "$file" for file in filter(
2525
endswith(".jl"), readdir(joinpath(@__DIR__, category, folder))
2626
)
27-
@info "Testing category/$folder/$file"
27+
@info "Testing $category/$folder/$file"
2828
include(joinpath(@__DIR__, category, folder, file))
2929
end
3030
Pkg.activate(TEST_ENV)

DifferentiationInterfaceTest/Project.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
1010
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
1111
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1212
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
13+
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1314
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
1415
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1516
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
@@ -20,11 +21,14 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2021

2122
[weakdeps]
2223
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
24+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
25+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
2326
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
2427
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2528

2629
[extensions]
2730
DifferentiationInterfaceTestComponentArraysExt = "ComponentArrays"
31+
DifferentiationInterfaceTestFluxExt = ["FiniteDifferences", "Flux"]
2832
DifferentiationInterfaceTestJLArraysExt = "JLArrays"
2933
DifferentiationInterfaceTestStaticArraysExt = "StaticArrays"
3034

@@ -36,6 +40,7 @@ ComponentArrays = "0.15"
3640
DataFrames = "1.6.1"
3741
DifferentiationInterface = "0.5.6"
3842
DocStringExtensions = "0.9"
43+
Functors = "0.4"
3944
JET = "0.4 - 0.8, 0.9"
4045
JLArrays = "0.1"
4146
LinearAlgebra = "<0.0.1,1"
@@ -54,6 +59,8 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
5459
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
5560
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
5661
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
62+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
63+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
5764
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
5865
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
5966
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
@@ -73,6 +80,8 @@ test = [
7380
"ComponentArrays",
7481
"DataFrames",
7582
"DifferentiationInterface",
83+
"FiniteDifferences",
84+
"Flux",
7685
"ForwardDiff",
7786
"JET",
7887
"JLArrays",

DifferentiationInterfaceTest/docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ sparse_scenarios
2525
component_scenarios
2626
gpu_scenarios
2727
static_scenarios
28+
flux_scenarios
2829
```
2930

3031
## Scenario types

0 commit comments

Comments
 (0)