Skip to content

Commit 5f444f7

Browse files
authored
add HyperHessian definition (#146)
* add HyperHessian definition ref JuliaDiff/DifferentiationInterface.jl#940 * add to docs
1 parent 58528c1 commit 5f444f7

7 files changed

Lines changed: 54 additions & 2 deletions

File tree

docs/src/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ Taylor mode:
3636
```@docs
3737
AutoGTPSA
3838
AutoTaylorDiff
39+
AutoHyperHessians
3940
```
4041

4142
### Reverse mode

src/ADTypes.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ export AutoChainRules,
3636
AutoFiniteDifferences,
3737
AutoForwardDiff,
3838
AutoGTPSA,
39+
AutoHyperHessians,
3940
AutoModelingToolkit,
4041
AutoMooncake,
4142
AutoMooncakeForward,

src/dense.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,38 @@ struct AutoZygote <: AbstractADType end
562562

563563
mode(::AutoZygote) = ReverseMode()
564564

565+
"""
566+
AutoHyperHessians{chunksize}
567+
568+
Struct used to select the [HyperHessians.jl](https://github.com/KristofferC/HyperHessians.jl) backend for automatic differentiation.
569+
570+
Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
571+
572+
# Constructors
573+
574+
AutoHyperHessians(; chunksize=nothing)
575+
576+
# Type parameters
577+
578+
- `chunksize`: the preferred chunk size to evaluate several derivatives at once. If `nothing`, HyperHessians chooses automatically.
579+
"""
580+
struct AutoHyperHessians{chunksize} <: AbstractADType end
581+
582+
function AutoHyperHessians(; chunksize::Union{Nothing, Int} = nothing)
583+
if chunksize isa Int
584+
chunksize > 0 || throw(ArgumentError("chunksize must be positive, got $chunksize"))
585+
end
586+
return AutoHyperHessians{chunksize}()
587+
end
588+
589+
mode(::AutoHyperHessians) = ForwardMode()
590+
591+
function Base.show(io::IO, ::AutoHyperHessians{chunksize}) where {chunksize}
592+
print(io, AutoHyperHessians, "(")
593+
chunksize !== nothing && print(io, "chunksize=", repr(chunksize; context = io))
594+
return print(io, ")")
595+
end
596+
565597
"""
566598
NoAutoDiff
567599

src/symbols.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ Auto(package::Symbol, args...; kws...) = Auto(Val(package), args...; kws...)
2424

2525
for backend in (
2626
:ChainRules, :Diffractor, :Enzyme, :Reactant, :FastDifferentiation,
27-
:FiniteDiff, :FiniteDifferences, :ForwardDiff, :GTPSA, :Mooncake, :PolyesterForwardDiff,
28-
:ReverseDiff, :Symbolics, :Tapir, :TaylorDiff, :Tracker, :Zygote,
27+
:FiniteDiff, :FiniteDifferences, :ForwardDiff, :GTPSA, :HyperHessians, :Mooncake,
28+
:PolyesterForwardDiff, :ReverseDiff, :Symbolics, :Tapir, :TaylorDiff, :Tracker, :Zygote,
2929
)
3030
@eval Auto(::Val{$(QuoteNode(backend))}, args...; kws...) = $(Symbol(:Auto, backend))(
3131
args...; kws...

test/dense.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,20 @@ end
160160
@test ad.descriptor == Val(:descriptor)
161161
end
162162

163+
@testset "AutoHyperHessians" begin
164+
ad = AutoHyperHessians()
165+
@test ad isa AbstractADType
166+
@test ad isa AutoHyperHessians{nothing}
167+
@test mode(ad) isa ForwardMode
168+
169+
ad = AutoHyperHessians(; chunksize = 8)
170+
@test ad isa AbstractADType
171+
@test ad isa AutoHyperHessians{8}
172+
@test mode(ad) isa ForwardMode
173+
174+
@test_throws ArgumentError AutoHyperHessians(; chunksize = -1)
175+
end
176+
163177
@testset "AutoMooncake" begin
164178
ad = AutoMooncake(; config = :config)
165179
@test ad isa AbstractADType

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ function every_ad()
4646
AutoFiniteDifferences(; fdm = :fdm),
4747
AutoForwardDiff(),
4848
AutoGTPSA(),
49+
AutoHyperHessians(),
4950
AutoPolyesterForwardDiff(),
5051
AutoReverseDiff(),
5152
AutoSymbolics(),
@@ -72,6 +73,8 @@ function every_ad_with_options()
7273
AutoForwardDiff(chunksize = 3, tag = :tag),
7374
AutoGTPSA(),
7475
AutoGTPSA(descriptor = Val(:descriptor)),
76+
AutoHyperHessians(),
77+
AutoHyperHessians(chunksize = 8),
7578
AutoMooncake(; config = :config),
7679
AutoMooncakeForward(; config = :config),
7780
AutoPolyesterForwardDiff(),

test/symbols.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using Test
88
@test ADTypes.Auto(:FiniteDiff) isa AutoFiniteDiff
99
@test ADTypes.Auto(:FiniteDifferences, 1.0) isa AutoFiniteDifferences{Float64}
1010
@test ADTypes.Auto(:ForwardDiff) isa AutoForwardDiff
11+
@test ADTypes.Auto(:HyperHessians) isa AutoHyperHessians
1112
@test ADTypes.Auto(:Mooncake) isa AutoMooncake
1213
@test ADTypes.Auto(:PolyesterForwardDiff) isa AutoPolyesterForwardDiff
1314
@test ADTypes.Auto(:ReverseDiff) isa AutoReverseDiff

0 commit comments

Comments
 (0)