
Inner second-order and outer optimization AD? #953

@zahachtah


Hi, I have a case where I want to use the DiffEqFlux.jl collocation example:

https://docs.sciml.ai/DiffEqFlux/stable/examples/collocation/

but instead of just training a neural net, I need the first and second derivatives with respect to one input.
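As a minimal sketch of what "first and second derivatives with respect to one input" looks like in DifferentiationInterface.jl, the snippet below differentiates a scalar function with respect to `x` while passing the other arguments as `Constant` contexts, so they are not differentiated. `g_toy` and its inputs are hypothetical stand-ins, not the network from this issue, and the sketch uses the non-mutating `value_derivative_and_second_derivative`, which returns the value and both derivatives as a tuple.

```julia
using DifferentiationInterface  # provides Constant and re-exports AutoForwardDiff
import ForwardDiff              # loads the ForwardDiff backend

# Hypothetical stand-in: scalar in x, with u and t as extra fixed arguments
g_toy(x, u, t) = u[1] * x^3 + t * x

backend = AutoForwardDiff()
u = [2.0, 0.5, 1.0]
t = 0.5

# Prepare once for these argument types, then reuse the prep object
prep = prepare_second_derivative(g_toy, backend, u[2], Constant(u), Constant(t))

# Value, d/dx and d²/dx² in a single call; u and t are held constant
y, d1, d2 = value_derivative_and_second_derivative(
    g_toy, prep, backend, u[2], Constant(u), Constant(t)
)
# expected: y ≈ 0.5, d1 ≈ 2.0 (3*u[1]*x^2 + t), d2 ≈ 6.0 (6*u[1]*x)
```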

I am trying to solve it, but I keep getting errors related to Zygote or ForwardDiff backend compatibility, which are a bit over my head.

Are there any things I should consider when doing this that might help me? I can also give more detailed code, but here are some relevant parts (note that I need to use `value_derivative_and_second_derivative!`, which I think is not available for `AutoZygote`?):

So, in the inner function I use:

```julia
using DifferentiationInterface, StaticArrays
import ForwardDiff

backend = AutoForwardDiff()

u0 = MVector{3,Float64}(rand(3))
prep = prepare_second_derivative(g, backend, u0[2], Constant(u0), Constant(0.0))

# Reusable buffers (scalars)
der = zeros(Float64, 1)   # df/du2
der2 = zeros(Float64, 1)  # d2f/du2^2

function rhs!(du, u, p, t) # the differential equation to run the simulation with
    # DI computes value + 1st + 2nd derivative wrt x in one call
    f, _, _ = value_derivative_and_second_derivative!(
        g, der, der2, prep, backend, u[2], Constant(u), Constant(t)
    )
    f1 = der[1]
    f2 = der2[1]
    # Moment-closure dynamics (M0..M2) with M3/M4 from drivers
    m0, m1, m2 = u
    drv_vec = driver_vec(t, eltype(u))
    m3 = drv_vec[end-1]
    m4 = drv_vec[end]
    d0 = f * m0 + 0.5 * f2 * m2
    d1 = f1 * m2 + 0.5 * f2 * m3
    d2 = f1 * m3 + 0.5 * f2 * (m4 - m2^2)
    # d3/d4 are provided by drivers (closure)
    #println((d0, d1, d2))
    du[1] = d0
    du[2] = d1
    du[3] = d2
    return nothing
end
```
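For comparison, here is a self-contained toy version of the `rhs!` above, assuming hypothetical `g_toy` and `driver_vec_toy` in place of the real network and driver data. It calls the non-mutating, unprepared `value_derivative_and_second_derivative`, so nothing is written into shared `der`/`der2` buffers. Zygote does not support array mutation, so this variant may be easier to place under an outer AD pass, but that is not verified here. Preparation is skipped because a `prep` object is intended for reuse with arguments of the same types and sizes as at preparation time.

```julia
using DifferentiationInterface
import ForwardDiff

# Hypothetical stand-ins for the network and the driver data
g_toy(x, u, t) = u[1] * x^3 + t * x
driver_vec_toy(t, T) = T[0.0, 0.0, t, t^2]   # last two entries play the role of M3, M4

function rhs_toy!(du, u, p, t)
    backend = AutoForwardDiff()
    # Non-mutating, unprepared call: (value, 1st, 2nd derivative) wrt u[2]
    f, f1, f2 = value_derivative_and_second_derivative(
        g_toy, backend, u[2], Constant(u), Constant(t)
    )
    m0, m1, m2 = u
    drv = driver_vec_toy(t, eltype(u))
    m3, m4 = drv[end-1], drv[end]
    # Same moment-closure equations as in rhs!
    du[1] = f * m0 + 0.5 * f2 * m2
    du[2] = f1 * m2 + 0.5 * f2 * m3
    du[3] = f1 * m3 + 0.5 * f2 * (m4 - m2^2)
    return nothing
end

du = zeros(3)
rhs_toy!(du, [2.0, 0.5, 1.0], nothing, 0.5)
```

With these toy inputs, `f`, `f1` and `f2` are `0.5`, `2.0` and `6.0`, so `du` comes out as `[4.0, 3.5, -1.25]`.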

Then I need to train the parameters to minimize the loss between the simulation and the actual data I have:

```julia
using Optimization, ComponentArrays
import Zygote

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit))
```
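To illustrate the nesting this issue is about (an inner second derivative with respect to one input inside an outer gradient with respect to the trainable parameters), here is a minimal sketch on a hypothetical `model` and `loss`. Both levels use plain ForwardDiff, which is a swapped-in choice rather than the `AutoZygote` setup above; forward mode is reasonable for a handful of parameters, but its cost grows with the parameter count, and whether this pattern carries over to the full neural ODE is not established here.

```julia
import ForwardDiff

# Hypothetical model: scalar in x, parameters p to be trained
model(x, p) = p[1] * x^3 + p[2] * x

# Inner level: second derivative of model with respect to x, p held fixed
d2model(x, p) = ForwardDiff.derivative(z -> ForwardDiff.derivative(y -> model(y, p), z), x)

# Outer level: hypothetical loss that depends on the inner second derivative
loss(p) = (d2model(0.5, p) - 1.0)^2

# Gradient of the loss wrt p, differentiating through the inner AD calls
grad = ForwardDiff.gradient(loss, [2.0, 1.0])
```

Here `d2model(0.5, p) = 3 * p[1]`, so the loss is `(3 * p[1] - 1)^2` and `grad` should be `[30.0, 0.0]`; `p[2]` drops out because the second derivative of `p[2] * x` is zero.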
