using DifferentiationInterface
import Enzyme, Mooncake, ForwardDiff, Zygote, FiniteDiff, BenchmarkTools, Reactant
"""
    f(x)

Scalar-to-scalar test function used as the differentiation target in the
benchmarks.

Builds two 3×3 matrices from `x` and the integer offsets `1:9` laid out row by
row: `lhs` adds the offsets to `x` and `rhs` subtracts them. It then scales
`rhs` by `1.5` and returns the sum of all entries of `lhs' * scaled`.

For real `x` this simplifies to `40.5x^2 - 1255.5`, so the exact derivative is
`81x`. That is handy for sanity-checking the AD results.
"""
function f(x)
    lhs = [x+1 x+2 x+3; x+4 x+5 x+6; x+7 x+8 x+9]
    rhs = [x-1 x-2 x-3; x-4 x-5 x-6; x-7 x-8 x-9]
    scaled = rhs .* 1.5
    return sum(lhs' * scaled)
end
"""
    bench()

Time `derivative(f, prep, backend, x)` at `x = 1.0` for each AD backend listed
below. For every backend it prints the backend (via `@show`) and then the
`@btime` result.

`prepare_derivative` runs outside the timed expression, so only the prepared
`derivative` call is measured. Returns `nothing`.
"""
function bench()
    backends = (
        AutoEnzyme(mode=Enzyme.Forward),
        AutoEnzyme(mode=Enzyme.Reverse),
        AutoMooncake(),
        AutoMooncakeForward(),
        AutoForwardDiff(),
        AutoZygote(),
        AutoFiniteDiff(),
    )
    for backend ∈ backends
        prep = prepare_derivative(f, backend, 1.0)
        # The input is read from a `Ref` inside the benchmark (`$x[]`) so the
        # compiler cannot constant-fold a literal argument.
        x = Ref(1.0)
        @show backend
        print(" ")
        # `import BenchmarkTools` only binds the module name, so the macro
        # must be qualified. A bare `@btime` is undefined here.
        BenchmarkTools.@btime derivative(f, $prep, $backend, $x[])
        println()
    end
    return nothing
end
# Run the DifferentiationInterface comparison, then time Enzyme's native
# reverse-mode API directly for reference.
# `Enzyme` and `BenchmarkTools` are loaded with `import`, so their exported
# names (`Reverse`, `Active`, `@btime`) must be qualified to be in scope.
# `$(Ref(1.0))[]` keeps the input from being constant-folded into the
# benchmarked expression (same trick as in `bench`).
bench()
BenchmarkTools.@btime Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active($(Ref(1.0))[]))
BenchmarkTools.@btime Enzyme.gradient(Enzyme.Reverse, f, $(Ref(1.0))[])