Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.44"
version = "0.6.45"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x)
(; f, backend) = dw
y = f(x)
prep_same = DI.prepare_pullback_same_point(f, backend, x, (y,))
prep_same = DI.prepare_pullback_same_point(Val(true), f, backend, x, (y,))
function pullbackfunc(dy)
tx = DI.pullback(f, prep_same, backend, x, (dy,))
return (NoTangent(), only(tx))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,37 +1,47 @@
## Pullback

struct ChainRulesPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep
struct ChainRulesPullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG}
_sig::Val{SIG}
y::Y
pb::PB
end

function DI.prepare_pullback(
f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}
strict::Val,
f,
backend::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{DI.GeneralizedConstant,C};
) where {C}
return DI.NoPullbackPrep()
_sig = DI.signature(f, backend, x, ty, contexts...; strict)
return DI.NoPullbackPrep(_sig)
end

function DI.prepare_pullback_same_point(
f,
::DI.NoPullbackPrep,
prep::DI.NoPullbackPrep,
backend::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{DI.GeneralizedConstant,C},
contexts::Vararg{DI.GeneralizedConstant,C};
) where {C}
DI.check_prep(f, prep, backend, x, ty, contexts...)
_sig = DI.signature(f, backend, x, ty, contexts...; strict=DI.is_strict(prep))
rc = ruleconfig(backend)
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
return ChainRulesPullbackPrepSamePoint(y, pb)
return ChainRulesPullbackPrepSamePoint(_sig, y, pb)
end

function DI.value_and_pullback(
f,
::DI.NoPullbackPrep,
prep::DI.NoPullbackPrep,
backend::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{DI.GeneralizedConstant,C},
) where {C}
DI.check_prep(f, prep, backend, x, ty, contexts...)
rc = ruleconfig(backend)
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
tx = map(ty) do dy
Expand All @@ -43,11 +53,12 @@ end
function DI.value_and_pullback(
f,
prep::ChainRulesPullbackPrepSamePoint,
::AutoReverseChainRules,
backend::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{DI.GeneralizedConstant,C},
) where {C}
DI.check_prep(f, prep, backend, x, ty, contexts...)
(; y, pb) = prep
tx = map(ty) do dy
unthunk(pb(dy)[2])
Expand All @@ -58,11 +69,12 @@ end
function DI.pullback(
f,
prep::ChainRulesPullbackPrepSamePoint,
::AutoReverseChainRules,
backend::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{DI.GeneralizedConstant,C},
) where {C}
DI.check_prep(f, prep, backend, x, ty, contexts...)
(; pb) = prep
tx = map(ty) do dy
unthunk(pb(dy)[2])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,15 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow()

## Pushforward

DI.prepare_pushforward(f, ::AutoDiffractor, x, tx::NTuple) = DI.NoPushforwardPrep()
function DI.prepare_pushforward(strict::Val, f, backend::AutoDiffractor, x, tx::NTuple)
_sig = DI.signature(f, backend, x, tx; strict)
return DI.NoPushforwardPrep(_sig)
end

function DI.pushforward(f, ::DI.NoPushforwardPrep, ::AutoDiffractor, x, tx::NTuple)
function DI.pushforward(
f, prep::DI.NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple
)
DI.check_prep(f, prep, backend, x, tx)
ty = map(tx) do dx
# code copied from Diffractor.jl
z = ∂☆{1}()(ZeroBundle{1}(f), bundle(x, dx))
Expand All @@ -25,6 +31,7 @@ end
function DI.value_and_pushforward(
f, prep::DI.NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple
)
DI.check_prep(f, prep, backend, x, tx)
return f(x), DI.pushforward(f, prep, backend, x, tx)
end

Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
## Pushforward

function DI.prepare_pushforward(
strict::Val,
f::F,
::AutoEnzyme{<:Union{ForwardMode,Nothing}},
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
contexts::Vararg{DI.Context,C},
contexts::Vararg{DI.Context,C};
) where {F,C}
return DI.NoPushforwardPrep()
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
return DI.NoPushforwardPrep(_sig)
end

function DI.value_and_pushforward(
f::F,
::DI.NoPushforwardPrep,
prep::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{1},
contexts::Vararg{DI.Context,C},
) where {F,C}
DI.check_prep(f, prep, backend, x, tx, contexts...)
mode = forward_withprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)
dx = only(tx)
Expand All @@ -29,12 +32,13 @@ end

function DI.value_and_pushforward(
f::F,
::DI.NoPushforwardPrep,
prep::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
DI.check_prep(f, prep, backend, x, tx, contexts...)
mode = forward_withprimal(backend)
f_and_df = get_f_and_df(f, backend, mode, Val(B))
x_and_tx = BatchDuplicated(x, tx)
Expand All @@ -45,12 +49,13 @@ end

function DI.pushforward(
f::F,
::DI.NoPushforwardPrep,
prep::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{1},
contexts::Vararg{DI.Context,C},
) where {F,C}
DI.check_prep(f, prep, backend, x, tx, contexts...)
mode = forward_noprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)
dx = only(tx)
Expand All @@ -62,12 +67,13 @@ end

function DI.pushforward(
f::F,
::DI.NoPushforwardPrep,
prep::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
DI.check_prep(f, prep, backend, x, tx, contexts...)
mode = forward_noprimal(backend)
f_and_df = get_f_and_df(f, backend, mode, Val(B))
x_and_tx = BatchDuplicated(x, tx)
Expand All @@ -85,6 +91,7 @@ function DI.value_and_pushforward!(
tx::NTuple,
contexts::Vararg{DI.Context,C},
) where {F,C}
DI.check_prep(f, prep, backend, x, tx, contexts...)
# dy cannot be passed anyway
y, new_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)
foreach(copyto!, ty, new_ty)
Expand All @@ -100,6 +107,7 @@ function DI.pushforward!(
tx::NTuple,
contexts::Vararg{DI.Context,C},
) where {F,C}
DI.check_prep(f, prep, backend, x, tx, contexts...)
# dy cannot be passed anyway
new_ty = DI.pushforward(f, prep, backend, x, tx, contexts...)
foreach(copyto!, ty, new_ty)
Expand All @@ -108,32 +116,33 @@ end

## Gradient

struct EnzymeForwardGradientPrep{B,O} <: DI.GradientPrep
struct EnzymeForwardGradientPrep{SIG,B,O} <: DI.GradientPrep{SIG}
_sig::Val{SIG}
_valB::Val{B}
shadows::O
end

function EnzymeForwardGradientPrep(::Val{B}, shadows::O) where {B,O}
return EnzymeForwardGradientPrep{B,O}(shadows)
end

function DI.prepare_gradient(
strict::Val,
f::F,
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
x,
contexts::Vararg{DI.Constant,C},
contexts::Vararg{DI.Constant,C};
) where {F,C}
_sig = DI.signature(f, backend, x, contexts...; strict)
valB = to_val(DI.pick_batchsize(backend, x))
shadows = create_shadows(valB, x)
return EnzymeForwardGradientPrep(valB, shadows)
return EnzymeForwardGradientPrep(_sig, valB, shadows)
end

function DI.gradient(
f::F,
prep::EnzymeForwardGradientPrep{B},
prep::EnzymeForwardGradientPrep{SIG,B},
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
x,
contexts::Vararg{DI.Constant,C},
) where {F,B,C}
) where {F,SIG,B,C}
DI.check_prep(f, prep, backend, x, contexts...)
mode = forward_noprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)
annotated_contexts = translate(backend, mode, Val(B), contexts...)
Expand All @@ -145,11 +154,12 @@ end

function DI.value_and_gradient(
f::F,
prep::EnzymeForwardGradientPrep{B},
prep::EnzymeForwardGradientPrep{SIG,B},
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
x,
contexts::Vararg{DI.Constant,C},
) where {F,B,C}
) where {F,SIG,B,C}
DI.check_prep(f, prep, backend, x, contexts...)
mode = forward_withprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)
annotated_contexts = translate(backend, mode, Val(B), contexts...)
Expand All @@ -162,58 +172,59 @@ end
function DI.gradient!(
f::F,
grad,
prep::EnzymeForwardGradientPrep{B},
prep::EnzymeForwardGradientPrep{SIG,B},
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
x,
contexts::Vararg{DI.Constant,C},
) where {F,B,C}
) where {F,SIG,B,C}
DI.check_prep(f, prep, backend, x, contexts...)
return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...))
end

function DI.value_and_gradient!(
f::F,
grad,
prep::EnzymeForwardGradientPrep{B},
prep::EnzymeForwardGradientPrep{SIG,B},
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
x,
contexts::Vararg{DI.Constant,C},
) where {F,B,C}
) where {F,SIG,B,C}
DI.check_prep(f, prep, backend, x, contexts...)
y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...)
return y, copyto!(grad, new_grad)
end

## Jacobian

struct EnzymeForwardOneArgJacobianPrep{B,O} <: DI.JacobianPrep
struct EnzymeForwardOneArgJacobianPrep{SIG,B,O} <: DI.JacobianPrep{SIG}
_sig::Val{SIG}
_valB::Val{B}
shadows::O
output_length::Int
end

function EnzymeForwardOneArgJacobianPrep(
::Val{B}, shadows::O, output_length::Integer
) where {B,O}
return EnzymeForwardOneArgJacobianPrep{B,O}(shadows, output_length)
end

function DI.prepare_jacobian(
strict::Val,
f::F,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
x,
contexts::Vararg{DI.Constant,C},
contexts::Vararg{DI.Constant,C};
) where {F,C}
_sig = DI.signature(f, backend, x, contexts...; strict)
y = f(x, map(DI.unwrap, contexts)...)
valB = to_val(DI.pick_batchsize(backend, x))
shadows = create_shadows(valB, x)
return EnzymeForwardOneArgJacobianPrep(valB, shadows, length(y))
return EnzymeForwardOneArgJacobianPrep(_sig, valB, shadows, length(y))
end

function DI.jacobian(
f::F,
prep::EnzymeForwardOneArgJacobianPrep{B},
prep::EnzymeForwardOneArgJacobianPrep{SIG,B},
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
x,
contexts::Vararg{DI.Constant,C},
) where {F,B,C}
) where {F,SIG,B,C}
DI.check_prep(f, prep, backend, x, contexts...)
mode = forward_noprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)
annotated_contexts = translate(backend, mode, Val(B), contexts...)
Expand All @@ -226,11 +237,12 @@ end

function DI.value_and_jacobian(
f::F,
prep::EnzymeForwardOneArgJacobianPrep{B},
prep::EnzymeForwardOneArgJacobianPrep{SIG,B},
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
x,
contexts::Vararg{DI.Constant,C},
) where {F,B,C}
) where {F,SIG,B,C}
DI.check_prep(f, prep, backend, x, contexts...)
mode = forward_withprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)
annotated_contexts = translate(backend, mode, Val(B), contexts...)
Expand All @@ -249,6 +261,7 @@ function DI.jacobian!(
x,
contexts::Vararg{DI.Constant,C},
) where {F,C}
DI.check_prep(f, prep, backend, x, contexts...)
return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...))
end

Expand All @@ -260,6 +273,7 @@ function DI.value_and_jacobian!(
x,
contexts::Vararg{DI.Constant,C},
) where {F,C}
DI.check_prep(f, prep, backend, x, contexts...)
y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...)
return y, copyto!(jac, new_jac)
end
Loading