Skip to content

Commit 47d0b12

Browse files
committed
Fix caches
1 parent e05f4fd commit 47d0b12

2 files changed

Lines changed: 7 additions & 2 deletions

File tree

DifferentiationInterface/src/second_order/hvp.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,10 @@ function _prepare_hvp_aux(
202202
xoi = overloaded_input(
203203
pushforward, shuffled_gradient!, grad_buffer, outer(backend), x, tx
204204
)
205-
inner_gradient_prep = prepare_gradient(f, inner(backend), xo, contexts...; strict)
206-
inner_gradient_in_prep = prepare_gradient(f, inner(backend), xoi, contexts...; strict)
205+
contextso = adapt_eltype.(contexts, Ref(eltype(xo)))
206+
contextsoi = adapt_eltype.(contexts, Ref(eltype(xoi)))
207+
inner_gradient_prep = prepare_gradient(f, inner(backend), xo, contextso...; strict)
208+
inner_gradient_in_prep = prepare_gradient(f, inner(backend), xoi, contextsoi...; strict)
207209
# Outer pushforward
208210
new_contexts = (
209211
FunctionContext(f),

DifferentiationInterface/src/utils/context.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,6 @@ function with_contexts(f::F, contexts::Vararg{Context,N}) where {F,N}
144144
tail_args = map(unwrap, contexts)
145145
return FixTail(f, tail_args...)
146146
end
147+
148+
adapt_eltype(c::Constant, ::Type) where {T} = c
149+
adapt_eltype(c::Cache, ::Type{T}) where {T} = Cache(similar(unwrap(c), T))

0 commit comments

Comments
 (0)