Skip to content

Commit 5ce58ea

Browse files
committed
feat: DifferentiateWith with contexts
1 parent ce5819e commit 5ce58ea

9 files changed

Lines changed: 129 additions & 48 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ using ChainRulesCore:
99
RuleConfig,
1010
frule_via_ad,
1111
rrule_via_ad,
12-
unthunk
12+
unthunk,
13+
@not_implemented
1314
import DifferentiationInterface as DI
1415

1516
ruleconfig(backend::AutoChainRules) = backend.ruleconfig
Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,22 @@
1-
function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x)
2-
(; f, backend) = dw
3-
y = f(x)
4-
prep_same = DI.prepare_pullback_same_point_nokwarg(Val(false), f, backend, x, (y,))
5-
function pullbackfunc(dy)
6-
tx = DI.pullback(f, prep_same, backend, x, (dy,))
7-
return (NoTangent(), only(tx))
1+
function ChainRulesCore.rrule(
2+
dw::DI.DifferentiateWith{C}, x, contexts::Vararg{Any, C}
3+
) where {C}
4+
(; f, backend, context_wrappers) = dw
5+
y = f(x, contexts...)
6+
wrapped_contexts = map(DI.call, context_wrappers, contexts)
7+
prep_same = DI.prepare_pullback_same_point_nokwarg(
8+
Val(false), f, backend, x, (y,), wrapped_contexts...
9+
)
10+
function diffwith_pullbackfunc(dy)
11+
dx = DI.pullback(f, prep_same, backend, x, (dy,), wrapped_contexts...) |> only
12+
dc = map(contexts) do c
13+
@not_implemented(
14+
"""
15+
Derivatives with respect to context arguments are not implemented.
16+
"""
17+
)
18+
end
19+
return (NoTangent(), dx, dc...)
820
end
9-
return y, pullbackfunc
21+
return y, diffwith_pullbackfunc
1022
end

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/differentiate_with.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
function (dw::DI.DifferentiateWith)(x::Dual{T, V, N}) where {T, V, N}
1+
function (dw::DI.DifferentiateWith{0})(x::Dual{T, V, N}) where {T, V, N}
22
(; f, backend) = dw
33
xval = myvalue(T, x)
44
tx = mypartials(T, Val(N), x)
55
y, ty = DI.value_and_pushforward(f, backend, xval, tx)
66
return make_dual(T, y, ty)
77
end
88

9-
function (dw::DI.DifferentiateWith)(x::AbstractArray{Dual{T, V, N}}) where {T, V, N}
9+
function (dw::DI.DifferentiateWith{0})(x::AbstractArray{Dual{T, V, N}}) where {T, V, N}
1010
(; f, backend) = dw
1111
xval = myvalue(T, x)
1212
tx = mypartials(T, Val(N), x)

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ using Mooncake:
1818
value_and_pullback!!,
1919
zero_dual,
2020
zero_tangent,
21+
zero_rdata,
2122
rdata_type,
2223
fdata,
2324
rdata,
@@ -26,11 +27,13 @@ using Mooncake:
2627
@is_primitive,
2728
zero_fcodual,
2829
MinimalCtx,
30+
NoFData,
2931
NoRData,
3032
primal,
3133
_copy_output,
3234
_copy_to_output!!,
33-
tangent_to_primal!!
35+
tangent_to_primal!!,
36+
increment!!
3437

3538
const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}}
3639

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith, <:Any}
1+
const NumberOrArray = Union{Number, AbstractArray{<:Number}}
2+
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{0}, <:NumberOrArray}
3+
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{1}, <:NumberOrArray, <:NumberOrArray}
4+
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{2}, <:NumberOrArray, <:NumberOrArray, <:NumberOrArray}
5+
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{3}, <:NumberOrArray, <:NumberOrArray, <:NumberOrArray, <:NumberOrArray}
6+
# TODO: generate more cases programmatically
27

38
struct MooncakeDifferentiateWithError <: Exception
49
F::Type
@@ -12,72 +17,87 @@ end
1217
function Base.showerror(io::IO, e::MooncakeDifferentiateWithError)
1318
return print(
1419
io,
15-
"MooncakeDifferentiateWithError: For the function type $(e.F) and input type $(e.X), the output type $(e.Y) is currently not supported.",
20+
"MooncakeDifferentiateWithError: For the function type $(e.F) and argument types $(e.X), the output type $(e.Y) is currently not supported.",
1621
)
1722
end
1823

19-
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number})
24+
function Mooncake.rrule!!(
25+
dw::CoDual{<:DI.DifferentiateWith{C}},
26+
x::CoDual{<:Number},
27+
contexts::Vararg{CoDual, C}
28+
) where {C}
29+
@assert tangent_type(typeof(dw)) == NoTangent
2030
primal_func = primal(dw)
2131
primal_x = primal(x)
22-
(; f, backend) = primal_func
23-
y = zero_fcodual(f(primal_x))
32+
primal_contexts = map(primal, contexts)
33+
(; f, backend, context_wrappers) = primal_func
34+
y = zero_fcodual(f(primal_x, primal_contexts...))
35+
wrapped_primal_contexts = map(DI.call, context_wrappers, primal_contexts)
2436

2537
# output is a vector, so we need to use the vector pullback
2638
function pullback_array!!(dy::NoRData)
27-
tx = DI.pullback(f, backend, primal_x, (y.dx,))
28-
@assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x)))
29-
return NoRData(), rdata(only(tx))
39+
dx = DI.pullback(f, backend, primal_x, (y.dx,), wrapped_primal_contexts...) |> only
40+
@assert rdata(only(dx)) isa rdata_type(tangent_type(typeof(primal_x)))
41+
rc = nanify_fdata_and_rdata!!(contexts...)
42+
return (NoRData(), rdata(dx), rc...)
3043
end
3144

3245
# output is a scalar, so we can use the scalar pullback
3346
function pullback_scalar!!(dy::Number)
34-
tx = DI.pullback(f, backend, primal_x, (dy,))
35-
@assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x)))
36-
return NoRData(), rdata(only(tx))
47+
dx = DI.pullback(f, backend, primal_x, (dy,), wrapped_primal_contexts...) |> only
48+
@assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x)))
49+
rc = nanify_fdata_and_rdata!!(contexts...)
50+
return (NoRData(), rdata(dx), rc...)
3751
end
3852

3953
pullback = if primal(y) isa Number
4054
pullback_scalar!!
4155
elseif primal(y) isa AbstractArray
4256
pullback_array!!
4357
else
44-
throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y)))
58+
throw(MooncakeDifferentiateWithError(primal_func, (primal_x, primal_contexts...), primal(y)))
4559
end
4660

4761
return y, pullback
4862
end
4963

5064
function Mooncake.rrule!!(
51-
dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray{<:Number}}
52-
)
65+
dw::CoDual{<:DI.DifferentiateWith{C}},
66+
x::CoDual{<:AbstractArray{<:Number}},
67+
contexts::Vararg{CoDual, C}
68+
) where {C}
69+
@assert tangent_type(typeof(dw)) == NoTangent
5370
primal_func = primal(dw)
5471
primal_x = primal(x)
55-
fdata_arg = x.dx
56-
(; f, backend) = primal_func
57-
y = zero_fcodual(f(primal_x))
72+
primal_contexts = map(primal, contexts)
73+
(; f, backend, context_wrappers) = primal_func
74+
y = zero_fcodual(f(primal_x, primal_contexts...))
75+
wrapped_primal_contexts = map(DI.call, context_wrappers, primal_contexts)
5876

5977
# output is a vector, so we need to use the vector pullback
6078
function pullback_array!!(dy::NoRData)
61-
tx = DI.pullback(f, backend, primal_x, (y.dx,))
62-
@assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
63-
fdata_arg .+= only(tx)
64-
return NoRData(), dy
79+
dx = DI.pullback(f, backend, primal_x, (y.dx,), wrapped_primal_contexts...) |> only
80+
@assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x)))
81+
x.dx .+= dx
82+
rc = nanify_fdata_and_rdata!!(contexts...)
83+
return (NoRData(), dy, rc...)
6584
end
6685

6786
# output is a scalar, so we can use the scalar pullback
6887
function pullback_scalar!!(dy::Number)
69-
tx = DI.pullback(f, backend, primal_x, (dy,))
70-
@assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
71-
fdata_arg .+= only(tx)
72-
return NoRData(), NoRData()
88+
dx = DI.pullback(f, backend, primal_x, (dy,), wrapped_primal_contexts...) |> only
89+
@assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x)))
90+
x.dx .+= dx
91+
rc = nanify_fdata_and_rdata!!(contexts...)
92+
return (NoRData(), NoRData(), rc...)
7393
end
7494

7595
pullback = if primal(y) isa Number
7696
pullback_scalar!!
7797
elseif primal(y) isa AbstractArray
7898
pullback_array!!
7999
else
80-
throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y)))
100+
throw(MooncakeDifferentiateWithError(primal_func, (primal_x, primal_contexts...), primal(y)))
81101
end
82102

83103
return y, pullback

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,19 @@ function zero_tangent_or_primal(x, backend::AnyAutoMooncake)
1717
return zero_tangent(x)
1818
end
1919
end
20+
21+
nanify(x::AbstractFloat) = convert(typeof(x), NaN)
22+
# Array case: apply `nanify` to each element, so the element type picks the method
# (e.g. the `AbstractFloat` method turns every entry into NaN; nested arrays recurse).
nanify(x::AbstractArray) = map(nanify, x)
23+
# Tuple / NamedTuple case: apply `nanify` field by field; `map` preserves the container
# kind (and the field names of a NamedTuple), and nested containers recurse.
nanify(x::Union{Tuple, NamedTuple}) = map(nanify, x)
24+
nanify(::NoFData) = NoFData()
25+
nanify(::NoRData) = NoRData()
26+
27+
function nanify_fdata_and_rdata!!(contexts::Vararg{CoDual, C}) where {C}
28+
primal_contexts = map(primal, contexts)
29+
fdata_contexts = map(fdata, contexts)
30+
zero_rdata_contexts = map(zero_rdata, primal_contexts)
31+
foreach(fdata_contexts) do fc
32+
increment!!(fc, nanify(fc))
33+
end
34+
return map(nanify, zero_rdata_contexts)
35+
end

DifferentiationInterface/src/misc/differentiate_with.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be
1414
1515
!!! warning
1616
17-
`DifferentiateWith` only supports out-of-place functions `y = f(x)` without additional context arguments.
17+
`DifferentiateWith` only supports out-of-place functions `y = f(x, contexts...)`, where the derivatives with respect to `contexts` can be safely ignored in the rest of your code.
1818
It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), reverse-mode [Mooncake](https://github.com/chalk-lab/Mooncake.jl), or if it automatically imports rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules.
1919
For any other true backend, the differentiation behavior is not altered by `DifferentiateWith` (it becomes a transparent wrapper).
2020
@@ -25,16 +25,17 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be
2525
2626
# Fields
2727
28-
- `f`: the function in question, with signature `f(x)`
28+
- `f`: the function in question, with signature `f(x, contexts...)`
2929
- `backend::AbstractADType`: the substitute backend to use for differentiation
30+
- `context_wrappers::NTuple`: a tuple like `(Constant, Cache)`, meaning that `f(x, a, b)` will be differentiated with `Constant(a)` and `Cache(b)` as contexts.
3031
3132
!!! note
3233
3334
For the substitute AD backend to be called under the hood, its package needs to be loaded in addition to the package of the true AD backend.
3435
3536
# Constructor
3637
37-
DifferentiateWith(f, backend)
38+
DifferentiateWith(f, backend, context_wrappers)
3839
3940
# Example
4041
@@ -69,12 +70,17 @@ julia> Zygote.gradient(alg, [3.0, 5.0])[1]
6970
70.0
7071
```
7172
"""
72-
struct DifferentiateWith{F, B <: AbstractADType}
73+
struct DifferentiateWith{C, F, B <: AbstractADType, N <: NTuple{C, Any}}
7374
f::F
7475
backend::B
76+
context_wrappers::N
7577
end
7678

77-
(dw::DifferentiateWith)(x) = dw.f(x)
79+
DifferentiateWith(f::F, backend::AbstractADType) where {F} = DifferentiateWith(f, backend, ())
80+
81+
function (dw::DifferentiateWith{C})(x, args::Vararg{Any, C}) where {C}
82+
return dw.f(x, args...)
83+
end
7884

7985
function Base.show(io::IO, dw::DifferentiateWith)
8086
(; f, backend) = dw

DifferentiationInterface/src/utils/context.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,5 @@ Convenience for constructing a [`FixTail`](@ref), with a shortcut when there are
179179
"""
180180
@inline fix_tail(f::F) where {F} = f
181181
fix_tail(f::F, args::Vararg{Any, N}) where {F, N} = FixTail(f, args...)
182+
183+
@inline call(f::F, x) where {F} = f(x)

DifferentiationInterface/test/Back/DifferentiateWith/test.jl

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,14 @@ function (adb::ADBreaker)(x::AbstractArray)
2424
return adb.f(x)
2525
end
2626

27-
function differentiatewith_scenarios()
28-
outofplace_scens = filter(DIT.default_scenarios()) do scen
29-
DIT.function_place(scen) == :out
27+
# TODO: break Mooncake with overlay?
28+
29+
function differentiatewith_scenarios(; kwargs...)
30+
outofplace_scens = filter(DIT.default_scenarios(; kwargs...)) do scen
31+
DIT.function_place(scen) == :out &&
32+
# save some time
33+
!isa(scen.x, AbstractMatrix) &&
34+
!isa(scen.y, AbstractMatrix)
3035
end
3136
# with bad_scens, everything would break
3237
bad_scens = map(outofplace_scens) do scen
@@ -44,7 +49,23 @@ test_differentiation(
4449
differentiatewith_scenarios();
4550
excluded = SECOND_ORDER,
4651
logging = LOGGING,
47-
testset_name = "DI tests",
52+
testset_name = "DI tests - normal",
53+
)
54+
55+
test_differentiation(
56+
[AutoZygote(), AutoMooncake(; config = nothing)],
57+
map(DIT.constantify, differentiatewith_scenarios());
58+
excluded = SECOND_ORDER,
59+
logging = LOGGING,
60+
testset_name = "DI tests - Constant",
61+
)
62+
63+
test_differentiation(
64+
[AutoMooncake(; config = nothing)],
65+
map(DIT.cachify, differentiatewith_scenarios());
66+
excluded = SECOND_ORDER,
67+
logging = LOGGING,
68+
testset_name = "DI tests - Cache",
4869
)
4970

5071
@testset "ChainRules tests" begin
@@ -71,7 +92,7 @@ end;
7192

7293
e = MooncakeDifferentiateWithError(identity, 1.0, 2.0)
7394
@test sprint(showerror, e) ==
74-
"MooncakeDifferentiateWithError: For the function type typeof(identity) and input type Float64, the output type Float64 is currently not supported."
95+
"MooncakeDifferentiateWithError: For the function type typeof(identity) and argument types Float64, the output type Float64 is currently not supported."
7596

7697
f_num2tup(x::Number) = (x,)
7798
f_vec2tup(x::Vector) = (first(x),)

0 commit comments

Comments
 (0)