Optim algorithms like BFGS, multi-LSTM hybrids, yaml with bounds and PBM, regularization - L2, per parameter optimisers#267
Optim algorithms like BFGS, multi-LSTM hybrids, yaml with bounds and PBM, regularization - L2, per parameter optimisers#267BernhardAhrens wants to merge 16 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request refactors the train function by extracting its core logic into a private _train function and updating the public train method to support both explicit configuration objects and flat keyword arguments. Feedback points out a logic error where explicit configuration objects are overwritten if keyword arguments are provided, potentially leading to lost settings. It is also recommended to align the implementation with the documentation regarding the deprecation of flat keyword arguments by adding formal warnings.
| function train( | ||
| model, data; | ||
| train_cfg::TrainConfig = TrainConfig(), | ||
| data_cfg::DataConfig = DataConfig(), | ||
| kwargs..., | ||
| ) | ||
| if !isempty(kwargs) | ||
| train_cfg, data_cfg = kwargs_to_configs((), kwargs) | ||
| end | ||
|
|
||
| return _train(model, data, train_cfg, data_cfg) | ||
| end |
There was a problem hiding this comment.
The current implementation has a logic error when mixing explicit configuration objects (train_cfg, data_cfg) with flat keyword arguments. If kwargs is not empty, the provided train_cfg and data_cfg are completely overwritten by new instances created from kwargs and defaults, leading to silent loss of user configuration.
If mixing is not intended to be supported, it is safer to use nothing as default values to detect explicit usage and prevent accidental overwriting. If mixing is intended, kwargs_to_configs should be updated to merge with the existing configuration objects.
function train(
model, data;
train_cfg::Union{TrainConfig, Nothing} = nothing,
data_cfg::Union{DataConfig, Nothing} = nothing,
kwargs...,
)
if !isempty(kwargs)
if !isnothing(train_cfg) || !isnothing(data_cfg)
throw(ArgumentError("Cannot mix explicit `train_cfg`/`data_cfg` with flat keyword arguments."))
end
train_cfg, data_cfg = kwargs_to_configs((), kwargs)
else
train_cfg = isnothing(train_cfg) ? TrainConfig() : train_cfg
data_cfg = isnothing(data_cfg) ? DataConfig() : data_cfg
end
return _train(model, data, train_cfg, data_cfg)
end| # Keyword Arguments | ||
| - `train_cfg`: Training configuration. See [`TrainConfig`](@ref) for all options. | ||
| - `data_cfg`: Data preparation configuration. See [`DataConfig`](@ref) for all options. | ||
| - Any other kwargs (deprecated) are forwarded as fields to `TrainConfig` / `DataConfig`. |
There was a problem hiding this comment.
The docstring labels flat keyword arguments as "deprecated" in the context of the new train method. However, they are being explicitly supported in the method signature. If the intention is to phase them out, consider adding a formal deprecation warning when they are used via the new API, or clarify if they are intended as a permanent convenience feature.
|
|
||
| unknown = [k for k in keys(kwargs) if k ∉ train_keys && k ∉ data_keys] | ||
| if !isempty(unknown) | ||
| @warn "Unknown kwargs will be ignored: $(join(unknown, ", "))" |
There was a problem hiding this comment.
maybe is better to throw error instead?
…del state to not get epoch already exists error
| target_names = model.targets | ||
| save_epoch = stopper.best_epoch == 0 ? 0 : stopper.best_epoch | ||
| save_ps_st!(paths.best_model, model, cfg.cdev(ps), cfg.cdev(st), cfg.tracked_params, save_epoch) | ||
| save_ps_st(paths.best_model, model, cfg.cdev(ps), cfg.cdev(st), cfg.tracked_params, save_epoch) |
There was a problem hiding this comment.
it should be !. At this stage, the initial file was created already. I think this will write fully the file, or?
| export get_all_groups | ||
| export load_group | ||
| function save_ps_st(file_name, hm, ps, st, save_ps) | ||
| function save_ps_st(file_name, hm, ps, st, save_ps, epoch = 0) |
There was a problem hiding this comment.
oh, I see. Same function in both places, no need for the ! variant?
| when non-empty it is forwarded as `tracked_params` on the resulting `TrainConfig`. | ||
| """ | ||
| function kwargs_to_configs(kwargs) | ||
| function kwargs_to_configs(save_ps, kwargs) |
There was a problem hiding this comment.
so, this makes the interface compatible with the old syntax ? good.
|
@lazarusA I'd say my PR from hell is ready. I think I'll do not split into several PRs and just have a length squash commit messsage |
|
/gemini summary |
Summary of ChangesThis pull request introduces significant architectural enhancements to the hybrid modeling framework, primarily by integrating Optimization.jl to support a wider range of optimization algorithms. It enables flexible per-branch optimizer configurations, adds robust L2 regularization for model weights, and improves the persistence of model configurations in YAML. Additionally, the training pipeline has been refactored to support both traditional Optimisers.jl loops and new Optimization.jl-based solvers, providing greater versatility for training complex hybrid models. Highlights
New Features🧠 You can now enable Memory (public preview) to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Activity
|
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces an Optimization.jl-based training driver path alongside the existing Optimisers.jl loop, adds utilities for weight extraction and L2 regularization (extract_weights and weight_l2), updates the extra_loss signature to accept model parameters, and improves YAML serialization for configurations. Feedback on these changes highlights several key issues: Meta.parse error expressions are not properly caught during YAML serialization, eachslice indexing may cause runtime crashes on Julia versions prior to 1.9, and the new weight extraction utilities silently ignore parameters grouped in Tuple containers.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| expr_and_next = try | ||
| Meta.parse(text, idx; greedy = true, raise = false) | ||
| catch | ||
| return nothing | ||
| end | ||
| expr_and_next === nothing && return nothing | ||
| _, next_idx = expr_and_next |
There was a problem hiding this comment.
When Meta.parse is called with raise = false, parsing errors do not throw an exception. Instead, it returns an expression containing the error (e.g., Expr(:error, ...) or Expr(:incomplete, ...)). Checking if expr_and_next === nothing is insufficient to catch these parsing failures, which can lead to invalid source code being written to the YAML. We should explicitly check if the parsed expression is an error or incomplete.
expr_and_next = try
Meta.parse(text, idx; greedy = true, raise = false)
catch
return nothing
end
expr_and_next === nothing && return nothing
expr, next_idx = expr_and_next
if Meta.isexpr(expr, :error) || Meta.isexpr(expr, :incomplete)
return nothing
end| # `eachslice(...; dims = 1)` (instead of `eachrow`) so this works for both | ||
| # the feed-forward case (2D `(param, batch)` output) and the recurrent/LSTM | ||
| # case (3D `(param, time, batch)` sequence output). | ||
| nn_cols = eachslice(nn_outputs[nn_name]; dims = 1) | ||
|
|
||
| # Create parameter for this NN | ||
| nn_param = NamedTuple{(param_name,), Tuple{typeof(nn_cols[1])}}((nn_cols[1],)) |
There was a problem hiding this comment.
In Julia versions prior to 1.9, eachslice returns a Generator which does not support indexing (e.g., nn_cols[1]), which will cause a runtime crash. Since out_dim is always 1, we only ever need the first slice along the first dimension. Using selectdim is fully compatible with all Julia versions, avoids creating an iterator, and directly returns the first slice as a view.
# Use selectdim to get the first slice along the first dimension.
# This works for both 2D and 3D outputs and is fully compatible with Julia < 1.9.
nn_col = selectdim(nn_outputs[nn_name], 1, 1)
# Create parameter for this NN
nn_param = NamedTuple{(param_name,), Tuple{typeof(nn_col)}}((nn_col,))| _collect!(_, _, ::Symbol) = nothing | ||
|
|
||
| function _collect!(out, node::Union{NamedTuple, ComponentArray}, key::Symbol) |
There was a problem hiding this comment.
If the parameters are grouped in a Tuple (which is a common container in Julia and Lux), _collect! will silently ignore them because it only has methods for NamedTuple and ComponentArray. Adding a method for Tuple allows recursing into tuple elements.
_collect!(_, _, ::Symbol) = nothing
function _collect!(out, node::Tuple, key::Symbol)
for child in node
_collect!(out, child, key)
end
return nothing
end
function _collect!(out, node::Union{NamedTuple, ComponentArray}, key::Symbol)| _weight_l2_stats(::Any, ::Symbol) = (0.0f0, 0) | ||
|
|
||
| function _weight_l2_stats(node::Union{NamedTuple, ComponentArray}, key::Symbol) |
There was a problem hiding this comment.
Similarly to _collect!, _weight_l2_stats will silently return (0.0f0, 0) for any parameters grouped in a Tuple. Adding a method for Tuple ensures that L2 regularization correctly covers all parameters.
_weight_l2_stats(::Any, ::Symbol) = (0.0f0, 0)
function _weight_l2_stats(node::Tuple, key::Symbol)
s = 0.0f0
n = 0
for child in node
cs, cn = _weight_l2_stats(child, key)
s = s + cs
n = n + cn
end
return s, n
end
function _weight_l2_stats(node::Union{NamedTuple, ComponentArray}, key::Symbol)
No description provided.