2 changes: 1 addition & 1 deletion .github/workflows/Test.yml

@@ -25,7 +25,7 @@ jobs:
       actions: write
       contents: read
     strategy:
-      fail-fast: true # TODO: toggle
+      fail-fast: false # TODO: toggle
      matrix:
        version:
          - "1.10"
2 changes: 1 addition & 1 deletion DifferentiationInterface/docs/src/dev_guide.md

@@ -36,7 +36,7 @@ Your AD package needs to be registered first.
 ### Core code
 
 In the main package, you should define a new struct `SuperDiffBackend` which subtypes [`ADTypes.AbstractADType`](@extref ADTypes), and endow it with the fields you need to parametrize your differentiation routines.
-You also have to define [`ADTypes.mode`](@extref) and [`DifferentiationInterface.inplace_support`](@ref) on `SuperDiffBackend`.
+You also have to define [`ADTypes.mode`](@extref) and [`DifferentiationInterface.check_inplace`](@ref) on `SuperDiffBackend`.
 
 !!! info
     In the end, this backend struct will need to be contributed to [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
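For concreteness, here is a minimal sketch of what such a backend definition could look like under the renamed API; `SuperDiffBackend` is the guide's placeholder name, and the choice of forward mode and the lack of in-place support are illustrative assumptions.

```julia
using ADTypes
import DifferentiationInterface as DI

# Placeholder backend from the dev guide; real fields would parametrize the routines.
struct SuperDiffBackend <: ADTypes.AbstractADType end

ADTypes.mode(::SuperDiffBackend) = ADTypes.ForwardMode()
# After this PR, the in-place trait is a plain Bool rather than a trait object.
DI.check_inplace(::SuperDiffBackend) = false
```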
@@ -18,7 +18,7 @@ const AutoForwardChainRules = AutoChainRules{<:RuleConfig{>:HasForwardsMode}}
 const AutoReverseChainRules = AutoChainRules{<:RuleConfig{>:HasReverseMode}}
 
 DI.check_available(::AutoChainRules) = true
-DI.inplace_support(::AutoChainRules) = DI.InPlaceNotSupported()
+DI.check_inplace(::AutoChainRules) = false
 
 include("reverse_onearg.jl")
 include("differentiate_with.jl")
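As a hedged illustration of how downstream code might branch on the renamed trait, which now returns a `Bool` instead of a trait object; the Zygote rule config is an arbitrary choice of reverse-mode `RuleConfig`.

```julia
import DifferentiationInterface as DI
using ADTypes: AutoChainRules
using Zygote: ZygoteRuleConfig

backend = AutoChainRules(; ruleconfig=ZygoteRuleConfig())
if DI.check_inplace(backend)
    # mutating operators such as DI.gradient! would be usable here
else
    # AutoChainRules lands in this branch: stick to out-of-place operators
end
```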
@@ -6,11 +6,7 @@ struct ChainRulesPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep
 end
 
 function DI.prepare_pullback(
-    f,
-    ::AutoReverseChainRules,
-    x,
-    ty::NTuple,
-    contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
+    f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}
 ) where {C}
     return DI.NoPullbackPrep()
 end
@@ -21,7 +17,7 @@ function DI.prepare_pullback_same_point(
     backend::AutoReverseChainRules,
     x,
     ty::NTuple,
-    contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
+    contexts::Vararg{DI.GeneralizedConstant,C},
 ) where {C}
     rc = ruleconfig(backend)
     y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
@@ -34,7 +30,7 @@ function DI.value_and_pullback(
     backend::AutoReverseChainRules,
     x,
     ty::NTuple,
-    contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
+    contexts::Vararg{DI.GeneralizedConstant,C},
 ) where {C}
     rc = ruleconfig(backend)
     y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
@@ -50,7 +46,7 @@ function DI.value_and_pullback(
     ::AutoReverseChainRules,
     x,
     ty::NTuple,
-    contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
+    contexts::Vararg{DI.GeneralizedConstant,C},
 ) where {C}
     (; y, pb) = prep
     tx = map(ty) do dy
@@ -65,7 +61,7 @@ function DI.pullback(
     ::AutoReverseChainRules,
     x,
     ty::NTuple,
-    contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
+    contexts::Vararg{DI.GeneralizedConstant,C},
 ) where {C}
     (; pb) = prep
     tx = map(ty) do dy
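To see the `GeneralizedConstant` vararg in action, here is a hedged usage sketch; the function, the values, and the Zygote-backed rule config are all illustrative assumptions.

```julia
import DifferentiationInterface as DI
using ADTypes: AutoChainRules
using Zygote: ZygoteRuleConfig

f(x, c) = c * sum(abs2, x)  # c enters as a constant context
backend = AutoChainRules(; ruleconfig=ZygoteRuleConfig())
x = [1.0, 2.0]
# DI.Constant is one of the context wrappers matched by DI.GeneralizedConstant:
y, tx = DI.value_and_pullback(f, backend, x, (1.0,), DI.Constant(3.0))
# y == 15.0 and tx[1] == [6.0, 12.0]: the gradient 2c*x scaled by the seed
```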
@@ -5,7 +5,7 @@ import DifferentiationInterface as DI
 using Diffractor: DiffractorRuleConfig, TaylorTangentIndex, ZeroBundle, bundle, ∂☆
 
 DI.check_available(::AutoDiffractor) = true
-DI.inplace_support(::AutoDiffractor) = DI.InPlaceNotSupported()
+DI.check_inplace(::AutoDiffractor) = false
 DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow()
 
 ## Pushforward
@@ -54,7 +54,7 @@ force_annotation(f::F) where {F} = Const(f)
 end
 
 @inline function _translate(
-    backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.Cache
+    backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.GeneralizedCache
 ) where {B}
     if B == 1
         return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c)))
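To unpack the `B == 1` branch above, here is a standalone sketch of the same annotation pattern using plain Enzyme calls; the cache array is a made-up example.

```julia
using Enzyme: Duplicated, make_zero

# A cache context becomes a Duplicated annotation: the primal buffer plus a
# zeroed shadow of identical shape, mirroring the diff above.
cache = zeros(3)
annotated = Duplicated(cache, make_zero(cache))
```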
@@ -6,7 +6,7 @@ using FiniteDifferences: FiniteDifferences, grad, jacobian, jvp, j′vp
 using LinearAlgebra: dot
 
 DI.check_available(::AutoFiniteDifferences) = true
-DI.inplace_support(::AutoFiniteDifferences) = DI.InPlaceNotSupported()
+DI.check_inplace(::AutoFiniteDifferences) = false
 
 ## Pushforward
@@ -26,11 +26,11 @@ using ForwardDiff:
     value
 
 DI.check_available(::AutoForwardDiff) = true
+DI.check_operator_overloading(::AutoForwardDiff) = true
 
 include("utils.jl")
 include("onearg.jl")
 include("twoarg.jl")
 include("secondorder.jl")
 include("differentiate_with.jl")
 include("misc.jl")
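A quick hedged check of the two traits on this backend, as they should read after this change:

```julia
import DifferentiationInterface as DI
using ADTypes: AutoForwardDiff

backend = AutoForwardDiff()
DI.check_available(backend)             # true
DI.check_operator_overloading(backend)  # true, as defined in the diff above
```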
@@ -1,16 +1,58 @@
 ## Pushforward
 
+function DI.overloaded_input(
+    ::typeof(DI.pushforward),
+    f::F,
+    backend::AutoForwardDiff,
+    x,
+    tx::NTuple{B},
+    contexts::Vararg{DI.Context,C},
+) where {F,B,C}
+    T = tag_type(f, backend, x)
+    xdual = make_dual(T, x, tx)
+    return xdual
+end
+#=
+function DI.overloaded_input(
+    ::typeof(DI.pushforward),
+    f!::F,
+    y,
+    backend::AutoForwardDiff,
+    x,
+    tx::NTuple{B},
+    contexts::Vararg{DI.Context,C},
+) where {F,B,C}
+    T = tag_type(f!, backend, x)
+    xdual = if x isa Number
+        make_dual(T, x, tx)
+    else
+        make_dual_similar(T, x, tx)
+    end
+    return xdual
+end
+=#
+
+DI.overloaded_input_type(prep::ForwardDiffOneArgPushforwardPrep) = typeof(prep.xdual_tmp)
+DI.overloaded_input_type(prep::ForwardDiffTwoArgPushforwardPrep) = typeof(prep.xdual_tmp)
+
 ## Derivative
 
 function DI.overloaded_input_type(prep::ForwardDiffOneArgDerivativePrep)
     return DI.overloaded_input_type(prep.pushforward_prep)
 end
-DI.overloaded_input_type(prep::ForwardDiffTwoArgDerivativePrep) = typeof(prep.config.duals)
+function DI.overloaded_input_type(prep::ForwardDiffTwoArgDerivativePrep)
+    return typeof(prep.config.duals)
+end
 
 ## Gradient
 
 DI.overloaded_input_type(prep::ForwardDiffGradientPrep) = typeof(prep.config.duals)
 
 ## Jacobian
-DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep) = typeof(prep.config.duals[2])
-DI.overloaded_input_type(prep::ForwardDiffTwoArgJacobianPrep) = typeof(prep.config.duals[2])
+
+function DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep)
+    return typeof(prep.config.duals[2])
+end
+function DI.overloaded_input_type(prep::ForwardDiffTwoArgJacobianPrep)
+    return typeof(prep.config.duals[2])
+end
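As a hedged illustration of how these queries fit together, the sketch below prepares a pushforward and asks for the dual input type; the function and sizes are arbitrary.

```julia
import DifferentiationInterface as DI
using ADTypes: AutoForwardDiff

f(x) = sum(abs2, x)
x = rand(3)
prep = DI.prepare_pushforward(f, AutoForwardDiff(), x, (ones(3),))
# Per the methods above, this is the dual-number input type ForwardDiff
# substitutes for x during the overloaded call:
DI.overloaded_input_type(prep)  # a Vector of ForwardDiff.Dual
```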