Skip to content

Commit 0f44787

Browse files
authored
Secret preparation modifier for resizing (#521)
* Secret preparation modifier * Add docstrings
1 parent 73f7314 commit 0f44787

12 files changed

Lines changed: 547 additions & 255 deletions

File tree

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ include("second_order/hvp.jl")
5353
include("second_order/hessian.jl")
5454

5555
include("fallbacks/no_prep.jl")
56+
include("fallbacks/change_prep.jl")
5657

5758
include("misc/differentiate_with.jl")
5859
include("misc/from_primitive.jl")
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
for op in [
2+
:derivative,
3+
:gradient,
4+
:jacobian,
5+
:second_derivative,
6+
:hessian,
7+
:pushforward,
8+
:pullback,
9+
:hvp,
10+
]
11+
op! = Symbol(op, "!")
12+
val_and_op = if op == :second_derivative
13+
:value_derivative_and_second_derivative
14+
elseif op == :hessian
15+
:value_gradient_and_hessian
16+
elseif op == :hvp
17+
nothing
18+
else
19+
Symbol("value_and_", op)
20+
end
21+
val_and_op! = Symbol(val_and_op, "!")
22+
prep_op = Symbol("prepare_", op)
23+
prep_op! = Symbol("prepare!_", op)
24+
prep_op_same_point = Symbol("prepare_", op, "_same_point")
25+
P = if op == :derivative
26+
DerivativePrep
27+
elseif op == :gradient
28+
GradientPrep
29+
elseif op == :jacobian
30+
JacobianPrep
31+
elseif op == :second_derivative
32+
SecondDerivativePrep
33+
elseif op == :hessian
34+
HessianPrep
35+
elseif op == :pushforward
36+
PushforwardPrep
37+
elseif op == :pullback
38+
PullbackPrep
39+
elseif op == :hvp
40+
HVPPrep
41+
end
42+
43+
if op in (:derivative, :gradient, :jacobian)
44+
# 1-arg
45+
@eval function $prep_op!(
46+
f::F, ::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}
47+
) where {F,C}
48+
return $prep_op(f, backend, x, contexts...)
49+
end
50+
op == :gradient && continue
51+
# 2-arg
52+
@eval function $prep_op!(
53+
f!::F, y, ::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}
54+
) where {F,C}
55+
return $prep_op(f!, y, backend, x, contexts...)
56+
end
57+
58+
elseif op in (:second_derivative, :hessian)
59+
# 1-arg
60+
@eval function $prep_op!(
61+
f::F, ::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}
62+
) where {F,C}
63+
return $prep_op(f, backend, x, contexts...)
64+
end
65+
66+
elseif op in (:pushforward, :pullback, :hvp)
67+
# 1-arg
68+
@eval function $prep_op!(
69+
f::F,
70+
::$P,
71+
backend::AbstractADType,
72+
x,
73+
seed::NTuple,
74+
contexts::Vararg{Context,C},
75+
) where {F,C}
76+
return $prep_op(f, backend, x, seed, contexts...)
77+
end
78+
@eval function $prep_op_same_point(
79+
f::F,
80+
prep::$P,
81+
backend::AbstractADType,
82+
x,
83+
seed::NTuple,
84+
contexts::Vararg{Context,C},
85+
) where {F,C}
86+
return prep
87+
end
88+
@eval function $prep_op_same_point(
89+
f::F, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C}
90+
) where {F,C}
91+
prep = $prep_op(f, backend, x, seed, contexts...)
92+
return $prep_op_same_point(f, prep, backend, x, seed, contexts...)
93+
end
94+
op == :hvp && continue
95+
# 2-arg
96+
@eval function $prep_op!(
97+
f!::F,
98+
y,
99+
::$P,
100+
backend::AbstractADType,
101+
x,
102+
seed::NTuple,
103+
contexts::Vararg{Context,C},
104+
) where {F,C}
105+
return $prep_op(f!, y, backend, x, seed, contexts...)
106+
end
107+
@eval function $prep_op_same_point(
108+
f!::F,
109+
y,
110+
prep::$P,
111+
backend::AbstractADType,
112+
x,
113+
seed::NTuple,
114+
contexts::Vararg{Context,C},
115+
) where {F,C}
116+
return prep
117+
end
118+
@eval function $prep_op_same_point(
119+
f!::F, y, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C}
120+
) where {F,C}
121+
prep = $prep_op(f!, y, backend, x, seed, contexts...)
122+
return $prep_op_same_point(f!, y, prep, backend, x, seed, contexts...)
123+
end
124+
end
125+
end

0 commit comments

Comments
 (0)