Skip to content

Commit 6894ad7

Browse files
committed
Fix overload
1 parent fffbf8a commit 6894ad7

3 files changed

Lines changed: 24 additions & 9 deletions

File tree

DifferentiationInterface/src/first_order/mixed_mode.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,10 @@ Appropriate mode type for `MixedMode` backends.
4141
"""
4242
struct ForwardAndReverseMode <: ADTypes.AbstractMode end
4343
ADTypes.mode(::MixedMode) = ForwardAndReverseMode()
44+
45+
function threshold_batchsize(backend::MixedMode, B::Integer)
46+
return MixedMode(
47+
threshold_batchsize(forward_backend(backend), B),
48+
threshold_batchsize(reverse_backend(backend), B),
49+
)
50+
end

DifferentiationInterface/src/misc/from_primitive.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,29 @@ abstract type FromPrimitive{inplace} <: AbstractADType end
33
check_available(backend::FromPrimitive) = check_available(backend.backend)
44
inplace_support(::FromPrimitive{true}) = InPlaceSupported()
55
inplace_support(::FromPrimitive{false}) = InPlaceNotSupported()
6-
function inner_preparation_behavior(backend::FromPrimitive)
7-
return inner_preparation_behavior(backend.backend)
6+
7+
function pick_batchsize(backend::FromPrimitive, x_or_y::AbstractArray)
8+
return pick_batchsize(backend.backend, x_or_y)
89
end
910

1011
function pick_batchsize(backend::FromPrimitive, N::Integer)
1112
return pick_batchsize(backend.backend, N)
1213
end
1314

15+
function inner_preparation_behavior(backend::FromPrimitive)
16+
return inner_preparation_behavior(backend.backend)
17+
end
18+
19+
function overloaded_input(::typeof(pushforward), f, backend::FromPrimitive, x, tx::NTuple)
20+
return overloaded_input(pushforward, f, backend.backend, x, tx)
21+
end
22+
23+
function overloaded_input(
24+
::typeof(pushforward), f!, y, backend::FromPrimitive, x, tx::NTuple
25+
)
26+
return overloaded_input(pushforward, f!, y, backend.backend, x, tx)
27+
end
28+
1429
"""
1530
AutoForwardFromPrimitive(backend::AbstractADType)
1631

DifferentiationInterface/src/utils/batchsize.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,6 @@ function threshold_batchsize(backend::SecondOrder, B::Integer)
112112
)
113113
end
114114

115-
function threshold_batchsize(backend::MixedMode, B::Integer)
116-
return MixedMode(
117-
threshold_batchsize(forward_backend(backend), B),
118-
threshold_batchsize(reverse_backend(backend), B),
119-
)
120-
end
121-
122115
"""
123116
reasonable_batchsize(N::Integer, Bmax::Integer)
124117

0 commit comments

Comments
 (0)