@@ -74,32 +74,34 @@ struct DifferentiateWith{C, F, B <: AbstractADType, N <: NTuple{C, Any}}
7474 f:: F
7575 backend:: B
7676 context_wrappers:: N
77- end
7877
79- function DifferentiateWith (
80- f:: F ,
81- backend:: B ,
82- context_wrappers:: NTuple{C, Any} ,
83- ) where {F, B <: AbstractADType , C}
84- for (i, wrapper) in pairs (context_wrappers)
85- # Accept typical constructor-like values: functions or types.
86- if ! (wrapper isa Function || wrapper isa Type)
87- throw (
88- ArgumentError (
89- " Each context wrapper must be a callable object or type " *
90- " (e.g., a wrapper constructor like `Constant` or `Cache`), " *
91- " but element $i has type $(typeof (wrapper)) ." ,
92- ),
93- )
78+ function DifferentiateWith (
79+ f:: F ,
80+ backend:: B ,
81+ context_wrappers:: NTuple{C, Any} ,
82+ ) where {F, B <: AbstractADType , C}
83+ for (i, wrapper) in pairs (context_wrappers)
84+ # Accept typical constructor-like values: functions or types.
85+ if ! (wrapper isa Function || wrapper isa Type)
86+ throw (
87+ ArgumentError (
88+ " Each context wrapper must be a callable object or type " *
89+ " (e.g., a wrapper constructor like `Constant` or `Cache`), " *
90+ " but element $i has type $(typeof (wrapper)) ." ,
91+ ),
92+ )
93+ end
9494 end
95+ return new {C, F, B, typeof(context_wrappers)} (
96+ f,
97+ backend,
98+ context_wrappers,
99+ )
95100 end
96- return DifferentiateWith {C, F, B, typeof(context_wrappers)} (
97- f,
98- backend,
99- context_wrappers,
100- )
101101end
102102
103+ DifferentiateWith (f:: F , backend:: AbstractADType ) where {F} = DifferentiateWith (f, backend, ())
104+
103105function (dw:: DifferentiateWith{C} )(x, args:: Vararg{Any, C} ) where {C}
104106 return dw. f (x, args... )
105107end
0 commit comments