OpenNN
Open-source neural networks library
opennn::TrainingStrategy Class Reference

Coordinates the training of a NeuralNetwork on a Dataset.

#include <training_strategy.h>

Public Member Functions

 TrainingStrategy (NeuralNetwork *new_neural_network=nullptr, Dataset *new_dataset=nullptr)
 Constructs a training strategy for a given network and dataset.
 
const Dataset * get_dataset () const
 Returns the dataset associated with the training strategy.
 
Dataset * get_dataset ()
 Returns the dataset associated with the training strategy.
 
const NeuralNetwork * get_neural_network () const
 Returns the neural network being trained.
 
NeuralNetwork * get_neural_network ()
 Returns the neural network being trained.
 
const Loss * get_loss () const
 Returns the loss term used during training.
 
Loss * get_loss ()
 Returns the loss term used during training.
 
const Optimizer * get_optimization_algorithm () const
 Returns the optimizer that performs parameter updates.
 
Optimizer * get_optimization_algorithm ()
 Returns the optimizer that performs parameter updates.
 
void set (NeuralNetwork *new_neural_network=nullptr, Dataset *new_dataset=nullptr)
 Resets the strategy to point at a new network and dataset.
 
void set_default ()
 Picks a default loss and optimizer based on the network architecture.
 
void set_dataset (Dataset *new_dataset)
 Replaces the dataset pointer without rebuilding loss/optimizer.
 
void set_neural_network (NeuralNetwork *new_neural_network)
 Replaces the neural network pointer without rebuilding loss/optimizer.
 
void set_loss (const string &new_loss)
 Selects the loss term by name.
 
void set_optimization_algorithm (const string &new_optimization_algorithm)
 Selects the optimization algorithm by name.
 
TrainingResults train ()
 Runs the training loop.
 
void from_JSON (const JsonDocument &document)
 Restores the strategy state from a JSON document.
 
void to_JSON (JsonWriter &writer) const
 Serializes the strategy state to JSON.
 
void save (const filesystem::path &file_name) const
 Saves the strategy state to a JSON file on disk.
 
void load (const filesystem::path &file_name)
 Loads the strategy state from a JSON file on disk.
 

Detailed Description

Coordinates the training of a NeuralNetwork on a Dataset.

Aggregates a Loss (error term) and an Optimizer (parameter-update rule), keeps non-owning pointers to the network and dataset they operate on, and exposes a single train() entry point that delegates to the optimizer.

Sensible defaults are picked automatically based on the network's layer composition: AdaptiveMomentEstimation + an appropriate loss for recurrent, convolutional, transformer-like and text-classification networks; QuasiNewtonMethod + CrossEntropy / WeightedSquaredError / MeanSquaredError for the classification and approximation cases.
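The ownership model described above (borrowed network and dataset, owned loss and optimizer, one train() entry point) can be sketched in a few lines. This is a minimal, self-contained sketch, not OpenNN's implementation: NeuralNetwork, Dataset, Loss and Optimizer are hypothetical stand-ins, MiniStrategy is a hypothetical name, train() returns an int instead of TrainingResults, and the loss/optimizer pair chosen in set_default() is only a placeholder.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical stand-ins for the real OpenNN classes.
struct NeuralNetwork {};
struct Dataset {};
struct Loss { std::string error_type; };
struct Optimizer {
    std::string name;
    Loss* loss = nullptr;        // non-owning link to the strategy's loss
    int train() { return 0; }    // placeholder for the real training loop
};

// Hypothetical class illustrating the aggregation described above.
class MiniStrategy {
public:
    explicit MiniStrategy(NeuralNetwork* new_neural_network = nullptr,
                          Dataset* new_dataset = nullptr)
        : neural_network(new_neural_network), dataset(new_dataset)
    {
        // As documented for the constructor: pick defaults only when both are given.
        if (neural_network && dataset) set_default();
    }

    NeuralNetwork* get_neural_network() { return neural_network; }
    Dataset* get_dataset() { return dataset; }
    Loss* get_loss() { return loss.get(); }
    Optimizer* get_optimization_algorithm() { return optimizer.get(); }

    void set_default()
    {
        // Placeholder choice; the real rule inspects the network's layers.
        loss = std::make_unique<Loss>();
        loss->error_type = "MeanSquaredError";
        optimizer = std::make_unique<Optimizer>();
        optimizer->name = "QuasiNewtonMethod";
        optimizer->loss = loss.get();
    }

    int train()
    {
        if (!neural_network || !dataset || !loss || !optimizer)
            throw std::runtime_error("Training strategy is not fully configured.");
        return optimizer->train();   // single entry point: delegate to the optimizer
    }

private:
    NeuralNetwork* neural_network = nullptr;   // non-owning: caller keeps it alive
    Dataset* dataset = nullptr;                // non-owning: caller keeps it alive
    std::unique_ptr<Loss> loss;                // owned by the strategy
    std::unique_ptr<Optimizer> optimizer;      // owned by the strategy
};
```

Because the network and dataset are only borrowed, the caller has to keep them alive for as long as the strategy uses them; the loss and optimizer, by contrast, are destroyed together with the strategy.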

Constructor & Destructor Documentation

◆ TrainingStrategy()

opennn::TrainingStrategy::TrainingStrategy ( NeuralNetwork * new_neural_network = nullptr,
Dataset * new_dataset = nullptr )

Constructs a training strategy for a given network and dataset.

Both pointers may be null and supplied later via set(). When both are provided the constructor calls set_default() to pick a loss and optimizer appropriate for the network's architecture.

Parameters
new_neural_network: Non-owning pointer to the network to be trained.
new_dataset: Non-owning pointer to the dataset providing samples.

Member Function Documentation

◆ from_JSON()

void opennn::TrainingStrategy::from_JSON ( const JsonDocument & document)

Restores the strategy state from a JSON document.

Parameters
document: Parsed JSON document produced by to_JSON().

◆ get_dataset() [1/2]

Dataset * opennn::TrainingStrategy::get_dataset ( )
inline

Returns the dataset associated with the training strategy.

Returns
Mutable pointer to the dataset (may be nullptr).

◆ get_dataset() [2/2]

const Dataset * opennn::TrainingStrategy::get_dataset ( ) const
inline

Returns the dataset associated with the training strategy.

Returns
Const pointer to the dataset (may be nullptr).

◆ get_loss() [1/2]

Loss * opennn::TrainingStrategy::get_loss ( )
inline

Returns the loss term used during training.

Returns
Mutable pointer to the Loss instance owned by this strategy.

◆ get_loss() [2/2]

const Loss * opennn::TrainingStrategy::get_loss ( ) const
inline

Returns the loss term used during training.

Returns
Const pointer to the Loss instance owned by this strategy.

◆ get_neural_network() [1/2]

NeuralNetwork * opennn::TrainingStrategy::get_neural_network ( )
inline

Returns the neural network being trained.

Returns
Mutable pointer to the network (may be nullptr).

◆ get_neural_network() [2/2]

const NeuralNetwork * opennn::TrainingStrategy::get_neural_network ( ) const
inline

Returns the neural network being trained.

Returns
Const pointer to the network (may be nullptr).

◆ get_optimization_algorithm() [1/2]

Optimizer * opennn::TrainingStrategy::get_optimization_algorithm ( )
inline

Returns the optimizer that performs parameter updates.

Returns
Mutable pointer to the Optimizer instance owned by this strategy.

◆ get_optimization_algorithm() [2/2]

const Optimizer * opennn::TrainingStrategy::get_optimization_algorithm ( ) const
inline

Returns the optimizer that performs parameter updates.

Returns
Const pointer to the Optimizer instance owned by this strategy.
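Each getter above comes as a const/non-const pair, so a const TrainingStrategy only hands out const pointers. The sketch below shows that overload pattern on a hypothetical Holder class with a stand-in Loss; it illustrates the C++ idiom implied by the documented signatures, not OpenNN's source.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>
#include <utility>

struct Loss { double weight = 1.0; };   // hypothetical stand-in

// Hypothetical class owning a Loss, like the strategy does.
class Holder {
public:
    Holder() : loss(std::make_unique<Loss>()) {}

    Loss* get_loss() { return loss.get(); }               // chosen on non-const objects
    const Loss* get_loss() const { return loss.get(); }   // chosen on const objects

private:
    std::unique_ptr<Loss> loss;
};

// Compile-time checks: overload resolution follows the constness of the object.
static_assert(std::is_same_v<decltype(std::declval<Holder&>().get_loss()), Loss*>);
static_assert(std::is_same_v<decltype(std::declval<const Holder&>().get_loss()), const Loss*>);
```

Code that receives a const reference can therefore inspect the loss, network, dataset or optimizer but cannot reconfigure them through these getters.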

◆ load()

void opennn::TrainingStrategy::load ( const filesystem::path & file_name)

Loads the strategy state from a JSON file on disk.

Calls set_default() before reading, so any fields missing from the file keep architecture-appropriate defaults.

Parameters
file_name: Source path.
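The "defaults first, then overwrite what the file contains" behaviour of load() can be sketched as follows. To stay self-contained, the parsed JSON is replaced by a hypothetical flat map of field names to values, and the field names ("Loss", "OptimizationAlgorithm"), the StrategyState type and its default values are all assumptions made for illustration.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical flattened view of a parsed JSON document: field name -> value.
using Fields = std::map<std::string, std::string>;

// Hypothetical subset of the strategy state.
struct StrategyState {
    std::string loss;
    std::string optimization_algorithm;

    void set_default()
    {
        // Placeholder defaults standing in for the architecture-based choice.
        loss = "MeanSquaredError";
        optimization_algorithm = "QuasiNewtonMethod";
    }

    void load(const Fields& fields)
    {
        set_default();   // 1. start from defaults
        // 2. overwrite only the fields that are present in the file
        if (const auto it = fields.find("Loss"); it != fields.end())
            loss = it->second;
        if (const auto it = fields.find("OptimizationAlgorithm"); it != fields.end())
            optimization_algorithm = it->second;
    }
};
```

A file that only names an optimizer thus still ends up with a complete, usable configuration.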

◆ save()

void opennn::TrainingStrategy::save ( const filesystem::path & file_name) const

Saves the strategy state to a JSON file on disk.

Parameters
file_name: Destination path.
Exceptions
runtime_error: If the file cannot be opened for writing.
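The error contract of save() can be illustrated with the standard library alone. In this sketch the JSON text is passed in ready-made, whereas the real method builds it through to_JSON(); save_text() is a hypothetical helper name.

```cpp
#include <cassert>
#include <filesystem>
#include <fstream>
#include <stdexcept>
#include <string>

// Hypothetical helper: write already-serialized JSON to disk, throwing
// std::runtime_error when the destination cannot be opened.
void save_text(const std::filesystem::path& file_name, const std::string& json)
{
    std::ofstream file(file_name);
    if (!file.is_open())
        throw std::runtime_error("Cannot open file for writing: " + file_name.string());
    file << json;
}
```

A destination inside a directory that does not exist is a typical way to trigger the exception, since std::ofstream does not create missing directories.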

◆ set()

void opennn::TrainingStrategy::set ( NeuralNetwork * new_neural_network = nullptr,
Dataset * new_dataset = nullptr )

Resets the strategy to point at a new network and dataset.

Stores both pointers and calls set_default() to rebuild the loss and optimizer to match the network's architecture.

Parameters
new_neural_network: Non-owning pointer to the network.
new_dataset: Non-owning pointer to the dataset.
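The practical difference between set() and the pointer-only setters set_dataset() and set_neural_network() is whether the defaults are recomputed. The sketch below makes that visible with a rebuild counter; MiniStrategy, the recurrent flag and the counter are hypothetical, and the rebuild rule is reduced to the recurrent case of the documented defaults.

```cpp
#include <cassert>
#include <string>

struct NeuralNetwork { bool recurrent = false; };   // hypothetical stand-in
struct Dataset {};                                  // hypothetical stand-in

struct MiniStrategy {
    NeuralNetwork* neural_network = nullptr;   // non-owning
    Dataset* dataset = nullptr;                // non-owning
    std::string optimizer_name;
    int rebuilds = 0;                          // how often defaults were recomputed

    // Like set(): store both pointers, then rebuild the defaults.
    void set(NeuralNetwork* nn, Dataset* ds)
    {
        neural_network = nn;
        dataset = ds;
        set_default();
    }

    // Like set_dataset() / set_neural_network(): swap a pointer only.
    void set_dataset(Dataset* ds) { dataset = ds; }
    void set_neural_network(NeuralNetwork* nn) { neural_network = nn; }

    void set_default()
    {
        ++rebuilds;
        optimizer_name = (neural_network && neural_network->recurrent)
                             ? "AdaptiveMomentEstimation"
                             : "QuasiNewtonMethod";
    }
};
```

After swapping in a network with a different architecture through set_neural_network(), the previously chosen optimizer stays in place; calling set() (or set_default()) is what brings the defaults back in line.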

◆ set_dataset()

void opennn::TrainingStrategy::set_dataset ( Dataset * new_dataset)
inline

Replaces the dataset pointer without rebuilding loss/optimizer.

Parameters
new_dataset: Non-owning pointer to the new dataset.

◆ set_default()

void opennn::TrainingStrategy::set_default ( )

Picks a default loss and optimizer based on the network architecture.

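Inspects the layers of the configured neural network and selects:

  • AdaptiveMomentEstimation with an appropriate loss for recurrent, convolutional, transformer-like and text-classification networks;
  • QuasiNewtonMethod with CrossEntropy, WeightedSquaredError or MeanSquaredError for classification and approximation networks.
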
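The optimizer half of this rule can be sketched as a function over the network's layer kinds. The LayerKind enum is hypothetical (OpenNN's real layer classes differ), treating Attention/Embedding layers as the marker for transformer-like and text-classification networks is an assumption, and the loss half of the rule is not reproduced because its exact mapping is not spelled out here.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical layer categories used only for this sketch.
enum class LayerKind { Scaling, Dense, Recurrent, Convolutional, Attention, Embedding };

std::string default_optimizer(const std::vector<LayerKind>& layers)
{
    const auto has = [&](LayerKind kind) {
        return std::find(layers.begin(), layers.end(), kind) != layers.end();
    };

    // Recurrent, convolutional, transformer-like or text-classification networks.
    if (has(LayerKind::Recurrent) || has(LayerKind::Convolutional)
        || has(LayerKind::Attention) || has(LayerKind::Embedding))
        return "AdaptiveMomentEstimation";

    // Remaining classification and approximation networks.
    return "QuasiNewtonMethod";
}
```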
◆ set_loss()

void opennn::TrainingStrategy::set_loss ( const string & new_loss)

Selects the loss term by name.

Constructs a fresh Loss instance bound to the current network and dataset, sets the requested error type and re-binds the optimizer (if any).

Parameters
new_loss: One of "MeanSquaredError", "NormalizedSquaredError", "WeightedSquaredError", "CrossEntropy", "CrossEntropyError3d", "MinkowskiError".
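The replace-and-rebind step of set_loss() matters because the optimizer keeps a pointer to the loss: once a fresh Loss replaces the old one, that pointer must be updated. The sketch below uses hypothetical Loss, Optimizer and MiniStrategy stand-ins; rejecting unknown names with std::invalid_argument is an assumption, since the reference only lists the accepted names.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>

struct Loss { std::string error_type; };        // hypothetical stand-in
struct Optimizer { Loss* loss = nullptr; };     // hypothetical stand-in

// Names accepted by set_loss(), as listed in the reference.
constexpr std::array<const char*, 6> loss_names = {
    "MeanSquaredError", "NormalizedSquaredError", "WeightedSquaredError",
    "CrossEntropy", "CrossEntropyError3d", "MinkowskiError"};

struct MiniStrategy {
    std::unique_ptr<Loss> loss;
    std::unique_ptr<Optimizer> optimizer;

    void set_loss(const std::string& name)
    {
        const bool known = std::any_of(loss_names.begin(), loss_names.end(),
                                       [&](const char* n) { return name == n; });
        if (!known)   // assumption: unknown names are rejected before anything changes
            throw std::invalid_argument("Unknown loss: " + name);

        loss = std::make_unique<Loss>(Loss{name});      // fresh instance
        if (optimizer) optimizer->loss = loss.get();    // re-bind the optimizer, if any
    }
};
```

Without the re-bind, the optimizer would keep pointing at the destroyed Loss object.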

◆ set_neural_network()

void opennn::TrainingStrategy::set_neural_network ( NeuralNetwork * new_neural_network)
inline

Replaces the neural network pointer without rebuilding loss/optimizer.

Parameters
new_neural_network: Non-owning pointer to the new network.

◆ set_optimization_algorithm()

void opennn::TrainingStrategy::set_optimization_algorithm ( const string & new_optimization_algorithm)

Selects the optimization algorithm by name.

Looks up the optimizer in the registry, instantiates it and binds it to the current loss.

Parameters
new_optimization_algorithm: One of "AdaptiveMomentEstimation", "QuasiNewtonMethod", "StochasticGradientDescent", "LevenbergMarquardtAlgorithm".
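A name-keyed registry of factories is one common way to implement "look up, instantiate, bind". The sketch below is a hypothetical illustration of that technique, not OpenNN's actual registry: the Optimizer hierarchy is a stand-in, only two of the four documented names are registered (the others would be added the same way), and throwing std::invalid_argument for unknown names is an assumption.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

struct Loss {};   // hypothetical stand-in

struct Optimizer {   // hypothetical base class
    virtual ~Optimizer() = default;
    virtual std::string name() const = 0;
    Loss* loss = nullptr;   // non-owning link to the current loss
};
struct AdaptiveMomentEstimation : Optimizer {
    std::string name() const override { return "AdaptiveMomentEstimation"; }
};
struct QuasiNewtonMethod : Optimizer {
    std::string name() const override { return "QuasiNewtonMethod"; }
};

using Factory = std::function<std::unique_ptr<Optimizer>()>;

const std::map<std::string, Factory>& optimizer_registry()
{
    static const std::map<std::string, Factory> registry = {
        {"AdaptiveMomentEstimation",
         []() -> std::unique_ptr<Optimizer> { return std::make_unique<AdaptiveMomentEstimation>(); }},
        {"QuasiNewtonMethod",
         []() -> std::unique_ptr<Optimizer> { return std::make_unique<QuasiNewtonMethod>(); }},
    };
    return registry;
}

std::unique_ptr<Optimizer> make_optimizer(const std::string& name, Loss* loss)
{
    const auto& registry = optimizer_registry();
    const auto it = registry.find(name);               // look up
    if (it == registry.end())
        throw std::invalid_argument("Unknown optimization algorithm: " + name);
    std::unique_ptr<Optimizer> optimizer = it->second();   // instantiate
    optimizer->loss = loss;                                // bind to the current loss
    return optimizer;
}
```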

◆ to_JSON()

void opennn::TrainingStrategy::to_JSON ( JsonWriter & writer) const

Serializes the strategy state to JSON.

Parameters
writer: JSON writer that receives the strategy's element tree.

◆ train()

TrainingResults opennn::TrainingStrategy::train ( )

Runs the training loop.

Validates that network, dataset, loss and optimizer are all set, applies the forecasting batch-size adjustment when relevant, then delegates to the optimizer's train() method.

Returns
TrainingResults with the per-epoch loss/selection-error history and the final stopping condition.
Exceptions
runtime_error: If any required component is not configured.
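The validate-then-delegate flow of train() can be sketched as below. Every type here is a hypothetical stand-in (TrainingResults is reduced to an error history and a stopping-condition string, and the stub optimizer returns fixed values), the error messages are invented, and the forecasting batch-size adjustment is left as a comment because its rule is not documented on this page.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-ins.
struct NeuralNetwork {};
struct Dataset {};
struct Loss {};
struct TrainingResults {
    std::vector<double> training_error_history;
    std::string stopping_condition;
};
struct Optimizer {
    TrainingResults train() { return {{1.0, 0.5}, "stub"}; }   // fixed stub values
};

struct Components {
    NeuralNetwork* neural_network = nullptr;
    Dataset* dataset = nullptr;
    Loss* loss = nullptr;
    Optimizer* optimizer = nullptr;
};

TrainingResults train(const Components& c)
{
    // Validate every component, reporting the first one that is missing.
    if (!c.neural_network) throw std::runtime_error("Neural network is not set.");
    if (!c.dataset)        throw std::runtime_error("Dataset is not set.");
    if (!c.loss)           throw std::runtime_error("Loss is not set.");
    if (!c.optimizer)      throw std::runtime_error("Optimizer is not set.");

    // The real method applies a forecasting batch-size adjustment here when
    // relevant; that rule is not documented, so it is omitted.

    return c.optimizer->train();   // delegate to the optimizer
}
```

Failing fast with a message that names the missing component is usually easier to debug than a null-pointer dereference deep inside the optimizer.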