Optim algorithms like BFGS, multi-LSTM hybrids, yaml with bounds and PBM, regularization - L2, per parameter optimisers #267

Open
BernhardAhrens wants to merge 16 commits into main from ba/train_signatures

Conversation

@BernhardAhrens

Collaborator

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request refactors the train function by extracting its core logic into a private _train function and updating the public train method to support both explicit configuration objects and flat keyword arguments. Feedback points out a logic error where explicit configuration objects are overwritten if keyword arguments are provided, potentially leading to lost settings. It is also recommended to align the implementation with the documentation regarding the deprecation of flat keyword arguments by adding formal warnings.

Comment thread src/training/train.jl
Comment on lines +83 to +94
function train(
    model, data;
    train_cfg::TrainConfig = TrainConfig(),
    data_cfg::DataConfig = DataConfig(),
    kwargs...,
)
    if !isempty(kwargs)
        train_cfg, data_cfg = kwargs_to_configs((), kwargs)
    end

    return _train(model, data, train_cfg, data_cfg)
end

high

The current implementation has a logic error when mixing explicit configuration objects (train_cfg, data_cfg) with flat keyword arguments. If kwargs is not empty, the provided train_cfg and data_cfg are completely overwritten by new instances created from kwargs and defaults, leading to silent loss of user configuration.

If mixing is not intended to be supported, it is safer to use nothing as default values to detect explicit usage and prevent accidental overwriting. If mixing is intended, kwargs_to_configs should be updated to merge with the existing configuration objects.

function train(
    model, data;
    train_cfg::Union{TrainConfig, Nothing} = nothing,
    data_cfg::Union{DataConfig, Nothing} = nothing,
    kwargs...,
)
    if !isempty(kwargs)
        if !isnothing(train_cfg) || !isnothing(data_cfg)
            throw(ArgumentError("Cannot mix explicit `train_cfg`/`data_cfg` with flat keyword arguments."))
        end
        train_cfg, data_cfg = kwargs_to_configs((), kwargs)
    else
        train_cfg = isnothing(train_cfg) ? TrainConfig() : train_cfg
        data_cfg = isnothing(data_cfg) ? DataConfig() : data_cfg
    end

    return _train(model, data, train_cfg, data_cfg)
end
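
If mixing is meant to be supported instead, the merge option mentioned above could look roughly like the sketch below. The struct fields (`nepochs`, `batchsize`, `shuffle`) and the helper name `merge_config` are hypothetical stand-ins, since the real `TrainConfig` / `DataConfig` definitions are not shown in this thread.

```julia
# Hypothetical stand-ins; the real TrainConfig/DataConfig fields are not shown here.
Base.@kwdef struct TrainConfig
    nepochs::Int = 100
    batchsize::Int = 32
end

Base.@kwdef struct DataConfig
    shuffle::Bool = true
end

# Rebuild `cfg`, overriding only the fields that appear in `kwargs`.
# Keys that do not belong to this config type are left for the other config.
function merge_config(cfg::T, kwargs) where {T}
    names = fieldnames(T)
    current = NamedTuple{names}(Tuple(getfield(cfg, f) for f in names))
    overrides = (; (k => v for (k, v) in pairs(kwargs) if k in names)...)
    return T(; merge(current, overrides)...)
end

# Explicit config plus flat kwargs: the explicit `nepochs` survives.
base = TrainConfig(nepochs = 10)
flat = (; batchsize = 64, shuffle = false)
merged_train = merge_config(base, flat)
merged_data = merge_config(DataConfig(), flat)
```

With this approach a flat keyword only overrides the field it names, so an explicit `train_cfg` is never discarded wholesale.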

Comment thread src/training/train.jl Outdated
# Keyword Arguments
- `train_cfg`: Training configuration. See [`TrainConfig`](@ref) for all options.
- `data_cfg`: Data preparation configuration. See [`DataConfig`](@ref) for all options.
- Any other kwargs (deprecated) are forwarded as fields to `TrainConfig` / `DataConfig`.

medium

The docstring labels flat keyword arguments as "deprecated" in the context of the new train method. However, they are being explicitly supported in the method signature. If the intention is to phase them out, consider adding a formal deprecation warning when they are used via the new API, or clarify if they are intended as a permanent convenience feature.
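
A minimal sketch of such a warning, assuming a hypothetical helper `warn_flat_kwargs` that `train` would call before converting the kwargs. `Base.depwarn` is the standard mechanism; note that Julia hides deprecation warnings unless it is started with `--depwarn=yes`.

```julia
# Hypothetical helper: emits a deprecation warning when flat kwargs are used.
# Returns `true` if a warning was issued, `false` if `kwargs` was empty.
function warn_flat_kwargs(kwargs)
    isempty(kwargs) && return false
    Base.depwarn(
        "Passing flat keyword arguments to `train` is deprecated; " *
            "use `train_cfg = TrainConfig(...)` and `data_cfg = DataConfig(...)` instead.",
        :train,
    )
    return true
end
```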

Comment thread src/training/train.jl Outdated

unknown = [k for k in keys(kwargs) if k ∉ train_keys && k ∉ data_keys]
if !isempty(unknown)
    @warn "Unknown kwargs will be ignored: $(join(unknown, ", "))"

Member

Maybe it is better to throw an error instead?
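
A sketch of the stricter variant, assuming a hypothetical helper `check_unknown_kwargs` that receives the allowed field names as `train_keys` and `data_keys` (the names used in the excerpt above):

```julia
# Hypothetical helper: reject kwargs that match neither config instead of ignoring them.
function check_unknown_kwargs(kwargs, train_keys, data_keys)
    unknown = [k for k in keys(kwargs) if k ∉ train_keys && k ∉ data_keys]
    if !isempty(unknown)
        throw(ArgumentError("Unknown keyword arguments: $(join(unknown, ", "))"))
    end
    return nothing
end

# A misspelled key now fails loudly instead of silently falling back to a default.
result = try
    check_unknown_kwargs((; nepoch = 5), (:nepochs,), (:shuffle,))
catch err
    err
end
```

Throwing makes typos in keyword names visible immediately, at the cost of breaking callers that currently pass extra keys.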

Comment thread src/io/checkpoints.jl
target_names = model.targets
save_epoch = stopper.best_epoch == 0 ? 0 : stopper.best_epoch
- save_ps_st!(paths.best_model, model, cfg.cdev(ps), cfg.cdev(st), cfg.tracked_params, save_epoch)
+ save_ps_st(paths.best_model, model, cfg.cdev(ps), cfg.cdev(st), cfg.tracked_params, save_epoch)

Member

It should be the ! variant. At this stage the initial file has already been created. I think this will rewrite the whole file, or?

Comment thread src/io/save.jl
export get_all_groups
export load_group
- function save_ps_st(file_name, hm, ps, st, save_ps)
+ function save_ps_st(file_name, hm, ps, st, save_ps, epoch = 0)

Member

oh, I see. Same function in both places, no need for the ! variant?

Comment thread src/training/train.jl
when non-empty it is forwarded as `tracked_params` on the resulting `TrainConfig`.
"""
- function kwargs_to_configs(kwargs)
+ function kwargs_to_configs(save_ps, kwargs)

Member

So this makes the interface compatible with the old syntax? Good.

@BernhardAhrens BernhardAhrens changed the title from "support train kwargs and configs" to "Optim algorithms like BFGS, multi-LSTM hybrids, yaml with bounds and PBM, regularization - L2, per parameter optimisers" on Jun 12, 2026
@BernhardAhrens

Collaborator Author

@lazarusA I'd say my PR from hell is ready. I think I won't split it into several PRs and will just have a lengthy squash commit message.

@lazarusA

Member

/gemini summary

@gemini-code-assist


Summary of Changes

This pull request introduces significant architectural enhancements to the hybrid modeling framework, primarily by integrating Optimization.jl to support a wider range of optimization algorithms. It enables flexible per-branch optimizer configurations, adds robust L2 regularization for model weights, and improves the persistence of model configurations in YAML. Additionally, the training pipeline has been refactored to support both traditional Optimisers.jl loops and new Optimization.jl-based solvers, providing greater versatility for training complex hybrid models.

Highlights

  • Optimization Engine Integration: Integrated Optimization.jl to support advanced solvers like L-BFGS, enabling both full-batch and minibatch training strategies.
  • Per-Branch Optimizers: Enabled per-parameter/per-branch optimizer configurations using Optimisers.jl, allowing for more granular control over training dynamics.
  • Regularization and Utilities: Added L2 regularization support for model weights and introduced utility functions for weight extraction and norm calculation.
  • Enhanced Configuration Serialization: Improved YAML configuration saving to include parameter bounds and the full source code of mechanistic models for better reproducibility.
  • Data Handling Improvements: Refined data processing for MultiNNHybridModel, including better sequence filtering and support for multi-branch neural network architectures.

Activity
  • Author BernhardAhrens indicated the PR is ready and will be submitted as a single large commit.
  • Reviewer lazarusA requested a summary of the changes.
  • Previous automated reviews addressed logic for mixing configuration objects with keyword arguments and clarified deprecation warnings.

@lazarusA

Member

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces an Optimization.jl-based training driver path alongside the existing Optimisers.jl loop, adds utilities for weight extraction and L2 regularization (extract_weights and weight_l2), updates the extra_loss signature to accept model parameters, and improves YAML serialization for configurations. Feedback on these changes highlights several key issues: Meta.parse error expressions are not properly caught during YAML serialization, eachslice indexing may cause runtime crashes on Julia versions prior to 1.9, and the new weight extraction utilities silently ignore parameters grouped in Tuple containers.

Comment thread src/config/config_yaml.jl
Comment on lines +104 to +110
expr_and_next = try
    Meta.parse(text, idx; greedy = true, raise = false)
catch
    return nothing
end
expr_and_next === nothing && return nothing
_, next_idx = expr_and_next

medium

When Meta.parse is called with raise = false, parsing errors do not throw an exception. Instead, it returns an expression containing the error (e.g., Expr(:error, ...) or Expr(:incomplete, ...)). Checking if expr_and_next === nothing is insufficient to catch these parsing failures, which can lead to invalid source code being written to the YAML. We should explicitly check if the parsed expression is an error or incomplete.

    expr_and_next = try
        Meta.parse(text, idx; greedy = true, raise = false)
    catch
        return nothing
    end
    expr_and_next === nothing && return nothing
    expr, next_idx = expr_and_next
    if Meta.isexpr(expr, :error) || Meta.isexpr(expr, :incomplete)
        return nothing
    end
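
The behaviour described above can be checked in isolation. In this sketch `is_parse_failure` is a hypothetical helper name and the input strings are illustrative; only `Meta.parse` and `Meta.isexpr` come from Julia itself.

```julia
# A parse result counts as a failure if it is an :error or :incomplete expression.
is_parse_failure(expr) = Meta.isexpr(expr, :error) || Meta.isexpr(expr, :incomplete)

# Complete definition: parses to an ordinary expression.
ok_expr, _ = Meta.parse("f(x) = x + 1", 1; greedy = true, raise = false)

# Missing `end`: no exception is thrown, the failure is encoded in the returned expression.
bad_expr, _ = Meta.parse("function f(x)", 1; greedy = true, raise = false)
```

Because the failure arrives as a normal return value, the surrounding `try`/`catch` never sees it, which is why the explicit head check is needed.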

Comment on lines +486 to 492
# `eachslice(...; dims = 1)` (instead of `eachrow`) so this works for both
# the feed-forward case (2D `(param, batch)` output) and the recurrent/LSTM
# case (3D `(param, time, batch)` sequence output).
nn_cols = eachslice(nn_outputs[nn_name]; dims = 1)

# Create parameter for this NN
nn_param = NamedTuple{(param_name,), Tuple{typeof(nn_cols[1])}}((nn_cols[1],))

medium

In Julia versions prior to 1.9, eachslice returns a Generator which does not support indexing (e.g., nn_cols[1]), which will cause a runtime crash. Since out_dim is always 1, we only ever need the first slice along the first dimension. Using selectdim is fully compatible with all Julia versions, avoids creating an iterator, and directly returns the first slice as a view.

        # Use selectdim to get the first slice along the first dimension.
        # This works for both 2D and 3D outputs and is fully compatible with Julia < 1.9.
        nn_col = selectdim(nn_outputs[nn_name], 1, 1)

        # Create parameter for this NN
        nn_param = NamedTuple{(param_name,), Tuple{typeof(nn_col)}}((nn_col,))
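
A small standalone check of the `selectdim` approach on both output shapes. The array sizes are made up for illustration; the layouts `(param, batch)` and `(param, time, batch)` follow the code comment in the excerpt above.

```julia
ff_out = rand(Float32, 1, 8)       # feed-forward output: (param, batch)
rnn_out = rand(Float32, 1, 5, 8)   # recurrent output: (param, time, batch)

# Fix index 1 along dimension 1; the remaining dimensions are kept as a view.
ff_slice = selectdim(ff_out, 1, 1)     # one value per batch element
rnn_slice = selectdim(rnn_out, 1, 1)   # (time, batch) slice
```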

Comment on lines +37 to +39
_collect!(_, _, ::Symbol) = nothing

function _collect!(out, node::Union{NamedTuple, ComponentArray}, key::Symbol)

medium

If the parameters are grouped in a Tuple (which is a common container in Julia and Lux), _collect! will silently ignore them because it only has methods for NamedTuple and ComponentArray. Adding a method for Tuple allows recursing into tuple elements.

_collect!(_, _, ::Symbol) = nothing

function _collect!(out, node::Tuple, key::Symbol)
    for child in node
        _collect!(out, child, key)
    end
    return nothing
end

function _collect!(out, node::Union{NamedTuple, ComponentArray}, key::Symbol)

Comment on lines +74 to +76
_weight_l2_stats(::Any, ::Symbol) = (0.0f0, 0)

function _weight_l2_stats(node::Union{NamedTuple, ComponentArray}, key::Symbol)

medium

Similarly to _collect!, _weight_l2_stats will silently return (0.0f0, 0) for any parameters grouped in a Tuple. Adding a method for Tuple ensures that L2 regularization correctly covers all parameters.

_weight_l2_stats(::Any, ::Symbol) = (0.0f0, 0)

function _weight_l2_stats(node::Tuple, key::Symbol)
    s = 0.0f0
    n = 0
    for child in node
        cs, cn = _weight_l2_stats(child, key)
        s = s + cs
        n = n + cn
    end
    return s, n
end

function _weight_l2_stats(node::Union{NamedTuple, ComponentArray}, key::Symbol)
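
To illustrate why the `Tuple` method matters, here is a dependency-free sketch. `l2_stats` is a hypothetical stand-in for `_weight_l2_stats`: it assumes the helper returns the sum of squares and the element count of every array stored under `key`, and it omits the `ComponentArray` case. The parameter layout `ps` is invented for the example.

```julia
# Fallback: anything that is not a recognised container contributes nothing.
l2_stats(::Any, ::Symbol) = (0.0f0, 0)

function l2_stats(node::NamedTuple, key::Symbol)
    s, n = 0.0f0, 0
    for (k, child) in pairs(node)
        if k === key && child isa AbstractArray
            cs, cn = sum(abs2, child), length(child)
        else
            cs, cn = l2_stats(child, key)
        end
        s += cs
        n += cn
    end
    return s, n
end

# The Tuple method suggested above: recurse into each element.
function l2_stats(node::Tuple, key::Symbol)
    s, n = 0.0f0, 0
    for child in node
        cs, cn = l2_stats(child, key)
        s += cs
        n += cn
    end
    return s, n
end

# Two layers grouped in a Tuple, as a layer container might store them.
ps = (
    layers = (
        (weight = Float32[1 2; 3 4], bias = Float32[0, 0]),
        (weight = Float32[2 0], bias = Float32[1]),
    ),
)
total, count = l2_stats(ps, :weight)
```

Without the `Tuple` method, the `layers` entry would hit the `::Any` fallback and both weight matrices would be skipped without any error.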

Development

Successfully merging this pull request may close these issues.

nn weights regularisation Custom L2 and L1 regularization - only on NN parameters

2 participants