Skip to content

Commit 8aa5bb4

Browse files
authored
Tapir Safe Mode (#259)
* Respect request for safe mode * Require ADTypes 1.2 * Make safety message appear * Disable safe mode to avoid polluting test logs * Formatting
1 parent 80325fc commit 8aa5bb4

5 files changed

Lines changed: 18 additions & 6 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ DifferentiationInterfaceTrackerExt = "Tracker"
4343
DifferentiationInterfaceZygoteExt = "Zygote"
4444

4545
[compat]
46-
ADTypes = "1.0.0"
46+
ADTypes = "1.2.0"
4747
ChainRulesCore = "1.23.0"
4848
Compat = "3,4"
4949
Diffractor = "=0.2.6"

DifferentiationInterface/ext/DifferentiationInterfaceTapirExt/DifferentiationInterfaceTapirExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ using Tapir:
1818
NoRData,
1919
fdata,
2020
rdata,
21-
__value_and_pullback!!
21+
__value_and_pullback!!,
22+
TapirInterpreter
2223

2324
DI.check_available(::AutoTapir) = true
2425

DifferentiationInterface/ext/DifferentiationInterfaceTapirExt/onearg.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@ end
55

66
function DI.prepare_pullback(f, backend::AutoTapir, x, dy)
77
y = f(x)
8-
rrule = build_rrule(f, x)
8+
rrule = build_rrule(
9+
TapirInterpreter(),
10+
Tuple{typeof(f),typeof(x)};
11+
safety_on=backend.safe_mode,
12+
silence_safety_messages=false,
13+
)
914
extras = TapirOneArgPullbackExtras(y, rrule)
1015
DI.value_and_pullback(f, backend, x, dy, extras) # warm up
1116
return extras

DifferentiationInterface/ext/DifferentiationInterfaceTapirExt/twoarg.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@ struct TapirTwoArgPullbackExtras{R} <: PullbackExtras
33
end
44

55
function DI.prepare_pullback(f!, y, backend::AutoTapir, x, dy)
6-
rrule = build_rrule(f!, y, x)
6+
rrule = build_rrule(
7+
TapirInterpreter(),
8+
Tuple{typeof(f!),typeof(y),typeof(x)};
9+
safety_on=backend.safe_mode,
10+
silence_safety_messages=false,
11+
)
712
extras = TapirTwoArgPullbackExtras(rrule)
813
DI.value_and_pullback(f!, y, backend, x, dy, extras) # warm up
914
return extras
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
using DifferentiationInterface, DifferentiationInterfaceTest
22
using Tapir: Tapir
33

4-
for backend in [AutoTapir()]
4+
for backend in [AutoTapir(; safe_mode=false)]
55
@test check_available(backend)
66
@test check_twoarg(backend)
77
@test !check_hessian(backend; verbose=false)
88
end
99

10-
test_differentiation(AutoTapir(); second_order=false, logging=LOGGING);
10+
# Safe mode switched off to avoid polluting the test suite with
11+
test_differentiation(AutoTapir(; safe_mode=false); second_order=false, logging=LOGGING);

0 commit comments

Comments
 (0)