OpenNN
Open-source neural networks library
Coordinates the training of a NeuralNetwork on a Dataset.
#include <training_strategy.h>
Public Member Functions
| Return type | Member | Description |
|---|---|---|
| | TrainingStrategy (NeuralNetwork *new_neural_network=nullptr, Dataset *new_dataset=nullptr) | Constructs a training strategy for a given network and dataset. |
| const Dataset * | get_dataset () const | Returns the dataset associated with the training strategy. |
| Dataset * | get_dataset () | Returns the dataset associated with the training strategy. |
| const NeuralNetwork * | get_neural_network () const | Returns the neural network being trained. |
| NeuralNetwork * | get_neural_network () | Returns the neural network being trained. |
| const Loss * | get_loss () const | Returns the loss term used during training. |
| Loss * | get_loss () | Returns the loss term used during training. |
| const Optimizer * | get_optimization_algorithm () const | Returns the optimizer that performs parameter updates. |
| Optimizer * | get_optimization_algorithm () | Returns the optimizer that performs parameter updates. |
| void | set (NeuralNetwork *new_neural_network=nullptr, Dataset *new_dataset=nullptr) | Resets the strategy to point at a new network and dataset. |
| void | set_default () | Picks a default loss and optimizer based on the network architecture. |
| void | set_dataset (Dataset *new_dataset) | Replaces the dataset pointer without rebuilding loss/optimizer. |
| void | set_neural_network (NeuralNetwork *new_neural_network) | Replaces the neural network pointer without rebuilding loss/optimizer. |
| void | set_loss (const string &new_loss) | Selects the loss term by name. |
| void | set_optimization_algorithm (const string &new_optimization_algorithm) | Selects the optimization algorithm by name. |
| TrainingResults | train () | Runs the training loop. |
| void | from_JSON (const JsonDocument &document) | Restores the strategy state from a JSON document. |
| void | to_JSON (JsonWriter &writer) const | Serializes the strategy state to JSON. |
| void | save (const filesystem::path &file_name) const | Saves the strategy state to a JSON file on disk. |
| void | load (const filesystem::path &file_name) | Loads the strategy state from a JSON file on disk. |
Coordinates the training of a NeuralNetwork on a Dataset.
Aggregates a Loss (error term) and an Optimizer (parameter-update rule), keeps non-owning pointers to the network and dataset they operate on, and exposes a single train() entry point that delegates to the optimizer.
Sensible defaults are picked automatically based on the network's layer composition: AdaptiveMomentEstimation + appropriate loss for recurrent, convolutional, transformer-like and text-classification networks; QuasiNewtonMethod
opennn::TrainingStrategy::TrainingStrategy (NeuralNetwork *new_neural_network=nullptr, Dataset *new_dataset=nullptr)
Constructs a training strategy for a given network and dataset.
Both pointers may be null and supplied later via set(). When both are provided the constructor calls set_default() to pick a loss and optimizer appropriate for the network's architecture.
Parameters
| new_neural_network | Non-owning pointer to the network to be trained. |
| new_dataset | Non-owning pointer to the dataset providing samples. |
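The constructor contract can be illustrated with a minimal, self-contained sketch. It does not use OpenNN: NeuralNetwork, Dataset and TrainingStrategySketch are hypothetical stand-ins, and the private set_default() only records that it ran instead of choosing a loss and optimizer. What it shows is the documented behavior: both pointers are stored without taking ownership, and defaults are applied only when both are non-null.

```cpp
#include <cassert>

// Hypothetical stand-ins; these are NOT OpenNN's NeuralNetwork/Dataset types.
struct NeuralNetwork {};
struct Dataset {};

// Sketch of the documented constructor contract.
class TrainingStrategySketch
{
public:
    explicit TrainingStrategySketch(NeuralNetwork* new_neural_network = nullptr,
                                    Dataset* new_dataset = nullptr)
        : neural_network(new_neural_network), dataset(new_dataset)
    {
        // Defaults are picked only when both collaborators are known.
        if (neural_network && dataset)
            set_default();
    }

    bool has_defaults() const { return defaults_applied; }
    NeuralNetwork* get_neural_network() const { return neural_network; }
    Dataset* get_dataset() const { return dataset; }

private:
    // Placeholder for the real architecture-based loss/optimizer selection.
    void set_default() { defaults_applied = true; }

    NeuralNetwork* neural_network = nullptr; // not owned: caller manages lifetime
    Dataset* dataset = nullptr;              // not owned: caller manages lifetime
    bool defaults_applied = false;
};
```

Because the strategy does not own these objects, the network and dataset must outlive it.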
void opennn::TrainingStrategy::from_JSON (const JsonDocument &document)
Restores the strategy state from a JSON document.
Parameters
| document | Parsed JSON document produced by to_JSON(). |
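const Dataset * opennn::TrainingStrategy::get_dataset () const [inline]
Returns the dataset associated with the training strategy.
Dataset * opennn::TrainingStrategy::get_dataset () [inline]
Returns the dataset associated with the training strategy.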
const Loss * opennn::TrainingStrategy::get_loss () const [inline]
Returns the loss term used during training.
Loss * opennn::TrainingStrategy::get_loss () [inline]
Returns the loss term used during training.
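const NeuralNetwork * opennn::TrainingStrategy::get_neural_network () const [inline]
Returns the neural network being trained.
NeuralNetwork * opennn::TrainingStrategy::get_neural_network () [inline]
Returns the neural network being trained.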
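const Optimizer * opennn::TrainingStrategy::get_optimization_algorithm () const [inline]
Returns the optimizer that performs parameter updates.
Optimizer * opennn::TrainingStrategy::get_optimization_algorithm () [inline]
Returns the optimizer that performs parameter updates.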
void opennn::TrainingStrategy::load (const filesystem::path &file_name)
Loads the strategy state from a JSON file on disk.
Calls set_default() before reading, so any fields missing from the file keep architecture-appropriate defaults.
Parameters
| file_name | Source path. |
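The default-then-override behavior of load() can be sketched as follows. The sketch is hypothetical: a std::map stands in for the parsed JSON document, the keys "Loss" and "OptimizationAlgorithm" are illustrative rather than OpenNN's actual schema, and the default values are placeholders, not what set_default() really selects for a given network.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical settings record; field and key names are illustrative only.
struct StrategySettings
{
    std::string loss;
    std::string optimizer;
};

// Stand-in for set_default(); placeholder values chosen for illustration.
StrategySettings default_settings()
{
    return {"MeanSquaredError", "QuasiNewtonMethod"};
}

// Start from defaults, then overwrite only the fields present in the file.
StrategySettings load_settings(const std::map<std::string, std::string>& parsed)
{
    StrategySettings settings = default_settings();

    if (auto it = parsed.find("Loss"); it != parsed.end())
        settings.loss = it->second;

    if (auto it = parsed.find("OptimizationAlgorithm"); it != parsed.end())
        settings.optimizer = it->second;

    return settings;
}
```

A file that names only a loss therefore keeps the default optimizer, and an empty file yields the defaults unchanged.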
void opennn::TrainingStrategy::save (const filesystem::path &file_name) const
Saves the strategy state to a JSON file on disk.
Parameters
| file_name | Destination path. |
Exceptions
| runtime_error | if the file cannot be opened for writing. |
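A minimal sketch of the documented failure mode, assuming the file is written through a standard std::ofstream. The function name save_json is hypothetical, and the JSON body is passed in as a ready-made string rather than produced by to_JSON().

```cpp
#include <cassert>
#include <filesystem>
#include <fstream>
#include <stdexcept>
#include <string>

// Writes a pre-serialized JSON string, throwing std::runtime_error when the
// destination cannot be opened (e.g. the parent directory does not exist).
void save_json(const std::filesystem::path& file_name, const std::string& json)
{
    std::ofstream file(file_name);

    if (!file.is_open())
        throw std::runtime_error("Cannot open file for writing: " + file_name.string());

    file << json;
}
```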
void opennn::TrainingStrategy::set (NeuralNetwork *new_neural_network=nullptr, Dataset *new_dataset=nullptr)
Resets the strategy to point at a new network and dataset.
Stores both pointers and calls set_default() to rebuild the loss and optimizer to match the network's architecture.
Parameters
| new_neural_network | Non-owning pointer to the network. |
| new_dataset | Non-owning pointer to the dataset. |
void opennn::TrainingStrategy::set_dataset (Dataset *new_dataset) [inline]
Replaces the dataset pointer without rebuilding loss/optimizer.
Parameters
| new_dataset | Non-owning pointer to the new dataset. |
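The difference between set() and set_dataset() can be shown with a sketch in which a counter stands in for rebuilding the loss and optimizer. The class StrategyPointers and the stand-in types are hypothetical, not OpenNN code.

```cpp
#include <cassert>

// Hypothetical stand-ins; not OpenNN types.
struct NeuralNetwork {};
struct Dataset {};

class StrategyPointers
{
public:
    // Swaps both pointers and rebuilds loss/optimizer, as documented for set().
    void set(NeuralNetwork* new_neural_network, Dataset* new_dataset)
    {
        neural_network = new_neural_network;
        dataset = new_dataset;
        set_default();
    }

    // Swaps only the dataset pointer; loss and optimizer are left untouched.
    void set_dataset(Dataset* new_dataset) { dataset = new_dataset; }

    int rebuild_count() const { return rebuilds; }
    const NeuralNetwork* get_neural_network() const { return neural_network; }
    const Dataset* get_dataset() const { return dataset; }

private:
    // Counts rebuilds instead of re-creating a real loss and optimizer.
    void set_default() { ++rebuilds; }

    NeuralNetwork* neural_network = nullptr;
    Dataset* dataset = nullptr;
    int rebuilds = 0;
};
```

With this design, set_dataset() preserves any loss or optimizer configuration already made, whereas set() replaces it with architecture defaults.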
void opennn::TrainingStrategy::set_default ()
Picks a default loss and optimizer based on the network architecture.
Inspects the layers of the configured neural network and selects:
void opennn::TrainingStrategy::set_loss (const string &new_loss)
Selects the loss term by name.
Constructs a fresh Loss instance bound to the current network and dataset, sets the requested error type and re-binds the optimizer (if any).
Parameters
| new_loss | One of: "MeanSquaredError", "NormalizedSquaredError", "WeightedSquaredError", "CrossEntropy", "CrossEntropyError3d", "MinkowskiError". |
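A sketch of name-based selection against the six documented names. Rejecting unknown names with std::invalid_argument is an assumption of this sketch; the documentation above does not say how OpenNN reacts to an unrecognized name. loss_names and loss_index are hypothetical helpers.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <string_view>

// The loss names documented for set_loss().
constexpr std::array<std::string_view, 6> loss_names = {
    "MeanSquaredError", "NormalizedSquaredError", "WeightedSquaredError",
    "CrossEntropy", "CrossEntropyError3d", "MinkowskiError"};

// Returns the position of the requested loss; throws for unknown names (assumed policy).
std::size_t loss_index(const std::string& name)
{
    for (std::size_t i = 0; i < loss_names.size(); ++i)
        if (loss_names[i] == std::string_view(name))
            return i;

    throw std::invalid_argument("Unknown loss: " + name);
}
```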
void opennn::TrainingStrategy::set_neural_network (NeuralNetwork *new_neural_network) [inline]
Replaces the neural network pointer without rebuilding loss/optimizer.
Parameters
| new_neural_network | Non-owning pointer to the new network. |
void opennn::TrainingStrategy::set_optimization_algorithm (const string &new_optimization_algorithm)
Selects the optimization algorithm by name.
Looks up the optimizer in the registry, instantiates it and binds it to the current loss.
Parameters
| new_optimization_algorithm | One of: "AdaptiveMomentEstimation", "QuasiNewtonMethod", "StochasticGradientDescent", "LevenbergMarquardtAlgorithm". |
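A registry lookup of this kind is commonly built as a map from names to factory functions. The sketch below assumes that design; Loss, Optimizer, AdaptiveMomentEstimation and QuasiNewtonMethod here are simplified hypothetical stand-ins that share names with OpenNN classes but not their interfaces. Only two of the four documented names are registered, and throwing std::invalid_argument for an unknown name is an assumption.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

struct Loss {};

// Hypothetical optimizer hierarchy for illustration.
struct Optimizer
{
    virtual ~Optimizer() = default;
    virtual std::string name() const = 0;
    Loss* loss = nullptr; // non-owning binding to the current loss
};

struct AdaptiveMomentEstimation : Optimizer
{
    std::string name() const override { return "AdaptiveMomentEstimation"; }
};

struct QuasiNewtonMethod : Optimizer
{
    std::string name() const override { return "QuasiNewtonMethod"; }
};

using Factory = std::function<std::unique_ptr<Optimizer>()>;

// Name -> factory map, built once on first use.
const std::map<std::string, Factory>& optimizer_registry()
{
    static const std::map<std::string, Factory> registry = {
        {"AdaptiveMomentEstimation", [] { return std::make_unique<AdaptiveMomentEstimation>(); }},
        {"QuasiNewtonMethod", [] { return std::make_unique<QuasiNewtonMethod>(); }},
    };
    return registry;
}

// Looks up the name, instantiates the optimizer and binds it to the loss.
std::unique_ptr<Optimizer> make_optimizer(const std::string& name, Loss* loss)
{
    const auto& registry = optimizer_registry();
    const auto it = registry.find(name);

    if (it == registry.end())
        throw std::invalid_argument("Unknown optimization algorithm: " + name);

    std::unique_ptr<Optimizer> optimizer = it->second();
    optimizer->loss = loss;
    return optimizer;
}
```

Adding another algorithm under this design means adding one map entry; the lookup code stays unchanged.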
void opennn::TrainingStrategy::to_JSON (JsonWriter &writer) const
Serializes the strategy state to JSON.
Parameters
| writer | JSON writer that receives the strategy's element tree. |
TrainingResults opennn::TrainingStrategy::train ()
Runs the training loop.
Validates that network, dataset, loss and optimizer are all set, applies the forecasting batch-size adjustment when relevant, then delegates to the optimizer's train() method.
Exceptions
| runtime_error | if any required component is not configured. |
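The validate-then-delegate structure of train() is sketched below with hypothetical stand-in types. The free function train, the Optimizer stub and the epochs field of TrainingResults are illustrative only, and the forecasting batch-size adjustment is left out because its details are not documented here.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical stand-ins for the four components train() requires.
struct NeuralNetwork {};
struct Dataset {};
struct Loss {};
struct TrainingResults { int epochs = 0; };

struct Optimizer
{
    // Placeholder for the real optimization loop.
    TrainingResults train() { return TrainingResults{1}; }
};

// Checks every required component, then hands control to the optimizer.
TrainingResults train(NeuralNetwork* neural_network, Dataset* dataset,
                      Loss* loss, Optimizer* optimizer)
{
    if (!neural_network) throw std::runtime_error("Neural network is not set.");
    if (!dataset)        throw std::runtime_error("Dataset is not set.");
    if (!loss)           throw std::runtime_error("Loss is not set.");
    if (!optimizer)      throw std::runtime_error("Optimizer is not set.");

    return optimizer->train();
}
```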