Training strategy class
The TrainingStrategy class encapsulates the concept of a training strategy. As mentioned in earlier chapters, it is responsible for training a neural network. The class is built upon two abstract components: LossIndex and OptimizationAlgorithm. LossIndex defines the error measured during training, while OptimizationAlgorithm specifies the method used to minimize that error by adjusting the network's parameters. The appropriate error function and optimization algorithm depend on the specific requirements of the application.
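The split between a loss and an optimizer can be sketched in plain C++. The classes Loss, Optimizer, SquaredDistanceLoss, GradientDescent and Strategy below are hypothetical stand-ins for LossIndex, OptimizationAlgorithm and TrainingStrategy, not OpenNN's actual classes. They only illustrate how a strategy object can hold an interchangeable loss and optimizer behind abstract interfaces, using non-owning pointers as the TrainingStrategy constructor does.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical, simplified interfaces; NOT OpenNN's LossIndex/OptimizationAlgorithm.
using Vector = std::vector<double>;

// Plays the role of LossIndex: maps parameters to a loss value and its gradient.
class Loss {
public:
    virtual ~Loss() = default;
    virtual double value(const Vector& parameters) const = 0;
    virtual Vector gradient(const Vector& parameters) const = 0;
};

// Plays the role of OptimizationAlgorithm: updates parameters to reduce a Loss.
class Optimizer {
public:
    virtual ~Optimizer() = default;
    virtual double minimize(const Loss& loss, Vector& parameters) const = 0;
};

// Example loss: sum of squared differences between parameters and a target.
class SquaredDistanceLoss : public Loss {
public:
    explicit SquaredDistanceLoss(Vector target) : target_(std::move(target)) {}

    double value(const Vector& p) const override {
        assert(p.size() == target_.size());
        double sum = 0.0;
        for (std::size_t i = 0; i < p.size(); ++i) {
            const double d = p[i] - target_[i];
            sum += d * d;
        }
        return sum;
    }

    Vector gradient(const Vector& p) const override {
        assert(p.size() == target_.size());
        Vector g(p.size());
        for (std::size_t i = 0; i < p.size(); ++i) g[i] = 2.0 * (p[i] - target_[i]);
        return g;
    }

private:
    Vector target_;
};

// Example optimizer: fixed-step gradient descent for a fixed number of epochs.
class GradientDescent : public Optimizer {
public:
    GradientDescent(double step, int epochs) : step_(step), epochs_(epochs) {}

    double minimize(const Loss& loss, Vector& p) const override {
        for (int e = 0; e < epochs_; ++e) {
            const Vector g = loss.gradient(p);
            for (std::size_t i = 0; i < p.size(); ++i) p[i] -= step_ * g[i];
        }
        return loss.value(p);
    }

private:
    double step_;
    int epochs_;
};

// Plays the role of TrainingStrategy: holds non-owning pointers to its parts.
class Strategy {
public:
    Strategy(const Loss* loss, const Optimizer* optimizer)
        : loss_(loss), optimizer_(optimizer) {
        assert(loss_ != nullptr && optimizer_ != nullptr);
    }

    // Trains the parameters in place and returns the final loss.
    double train(Vector& parameters) const { return optimizer_->minimize(*loss_, parameters); }

private:
    const Loss* loss_;
    const Optimizer* optimizer_;
};
```

Swapping the optimizer or the loss only requires passing a different object to the Strategy constructor; the training call itself does not change. This separation is presumably what allows TrainingStrategy to switch between error functions and optimization algorithms through a single setting.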
As in previous chapters, we will use the iris data set to illustrate the application of some of the main methods of the TrainingStrategy class. Therefore, reading the NeuralNetwork and DataSet sections is advisable before continuing.
The most common way to create a training strategy object is through the constructor that takes pointers to the NeuralNetwork and DataSet objects:
// Initialize training strategy
TrainingStrategy training_strategy(&neural_network, &data_set);
Once the object is created, the loss_method and optimization_method members are set by default to NORMALIZED_SQUARED_ERROR and QUASI_NEWTON_METHOD, respectively. The TrainingStrategy class implements get and set methods for each member, so different errors or optimization algorithms can be selected depending on the purpose of your project.
In this particular example, we will use QUASI_NEWTON_METHOD and NORMALIZED_SQUARED_ERROR, which are the defaults. They can still be set explicitly with the corresponding set methods, which is also how other methods would be selected:
// Configure training strategy (these values are already the defaults)
training_strategy.set_optimization_method(TrainingStrategy::OptimizationMethod::QUASI_NEWTON_METHOD);
training_strategy.set_loss_method(TrainingStrategy::LossMethod::NORMALIZED_SQUARED_ERROR);
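The default-plus-setter behaviour described above can be illustrated with a small configuration class. Everything here is hypothetical: the enums mimic TrainingStrategy::LossMethod and TrainingStrategy::OptimizationMethod but list only a few enumerators chosen for illustration, and TrainingConfiguration is not an OpenNN class. The point is that setting a member to its default value, as the code above does, leaves the configuration unchanged.

```cpp
#include <cassert>

// Hypothetical enums mimicking TrainingStrategy::LossMethod and
// TrainingStrategy::OptimizationMethod; only a few enumerators are listed.
enum class LossMethod { SUM_SQUARED_ERROR, MEAN_SQUARED_ERROR, NORMALIZED_SQUARED_ERROR };
enum class OptimizationMethod { GRADIENT_DESCENT, QUASI_NEWTON_METHOD };

// Hypothetical configuration holder: each member has a default value and a
// get/set pair, as the text describes for TrainingStrategy.
class TrainingConfiguration {
public:
    LossMethod get_loss_method() const { return loss_method; }
    OptimizationMethod get_optimization_method() const { return optimization_method; }

    void set_loss_method(LossMethod new_loss_method) { loss_method = new_loss_method; }
    void set_optimization_method(OptimizationMethod new_method) { optimization_method = new_method; }

    // True when both members still hold their default values.
    bool uses_defaults() const {
        return loss_method == LossMethod::NORMALIZED_SQUARED_ERROR &&
               optimization_method == OptimizationMethod::QUASI_NEWTON_METHOD;
    }

private:
    LossMethod loss_method = LossMethod::NORMALIZED_SQUARED_ERROR;
    OptimizationMethod optimization_method = OptimizationMethod::QUASI_NEWTON_METHOD;
};
```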
Parameters of the optimization algorithm are modified through a pointer to the algorithm object, which is obtained from the training strategy. Some of the most useful methods are:
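// Configure Quasi-Newton optimization parameters
QuasiNewtonMethod* quasi_newton_method = training_strategy.get_quasi_Newton_method();
// Stop when the loss improves by less than this amount between successive iterations
quasi_newton_method->set_minimum_loss_decrease(type(1.0e-6));
// Stop when the loss reaches this value
quasi_newton_method->set_loss_goal(type(1.0e-3));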
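To show what the two stopping criteria mean, here is a small, self-contained BFGS quasi-Newton minimizer. It is a sketch, not OpenNN's QuasiNewtonMethod: the names QuasiNewtonSettings, QuasiNewtonResult and bfgs_minimize are hypothetical, and the backtracking (Armijo) line search is a simple choice that is not necessarily the one OpenNN uses. The loop stops when the loss reaches loss_goal, when a single iteration lowers the loss by less than minimum_loss_decrease, or when an iteration limit is hit.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<Vec>;

// Hypothetical settings mirroring the parameters set in the text.
struct QuasiNewtonSettings {
    double minimum_loss_decrease = 1.0e-6;  // stop if one iteration improves the loss by less than this
    double loss_goal = 1.0e-3;              // stop once the loss is at or below this value
    int maximum_iterations = 1000;          // safety limit
};

struct QuasiNewtonResult {
    Vec parameters;
    double loss;
    int iterations;
};

double dot(const Vec& a, const Vec& b) {
    assert(a.size() == b.size());
    double sum = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) sum += a[i] * b[i];
    return sum;
}

Vec multiply(const Mat& m, const Vec& v) {
    Vec result(m.size(), 0.0);
    for (std::size_t i = 0; i < m.size(); ++i) result[i] = dot(m[i], v);
    return result;
}

QuasiNewtonResult bfgs_minimize(const std::function<double(const Vec&)>& loss_function,
                                const std::function<Vec(const Vec&)>& gradient_function,
                                Vec x, const QuasiNewtonSettings& settings) {
    const std::size_t n = x.size();
    Mat h(n, Vec(n, 0.0));  // inverse-Hessian approximation, starts as the identity
    for (std::size_t i = 0; i < n; ++i) h[i][i] = 1.0;

    double loss = loss_function(x);
    Vec g = gradient_function(x);
    int iterations = 0;

    while (iterations < settings.maximum_iterations && loss > settings.loss_goal) {  // loss goal criterion
        ++iterations;

        // Quasi-Newton search direction d = -H g.
        Vec d = multiply(h, g);
        for (double& di : d) di = -di;

        // Backtracking line search: halve the step until the Armijo condition holds.
        const double slope = dot(g, d);
        double step = 1.0;
        Vec x_new(n);
        double loss_new = loss;
        for (;;) {
            for (std::size_t i = 0; i < n; ++i) x_new[i] = x[i] + step * d[i];
            loss_new = loss_function(x_new);
            if (loss_new <= loss + 1.0e-4 * step * slope || step < 1.0e-12) break;
            step *= 0.5;
        }
        const Vec g_new = gradient_function(x_new);

        // BFGS update of H using s = x_new - x and y = g_new - g.
        Vec s(n), y(n);
        for (std::size_t i = 0; i < n; ++i) {
            s[i] = x_new[i] - x[i];
            y[i] = g_new[i] - g[i];
        }
        const double sy = dot(s, y);
        if (sy > 1.0e-12) {  // skip the update if the curvature condition fails
            const Vec hy = multiply(h, y);
            const double yhy = dot(y, hy);
            for (std::size_t i = 0; i < n; ++i)
                for (std::size_t j = 0; j < n; ++j)
                    h[i][j] += (sy + yhy) * s[i] * s[j] / (sy * sy)
                             - (hy[i] * s[j] + s[i] * hy[j]) / sy;
        }

        const double decrease = loss - loss_new;
        x = x_new;
        g = g_new;
        loss = loss_new;
        if (decrease < settings.minimum_loss_decrease) break;  // minimum loss decrease criterion
    }
    return {x, loss, iterations};
}
```

For instance, minimizing f(x) = (x0 - 1)^2 + 10(x1 + 2)^2 from (0, 0) with the default settings should end once the loss drops to 1.0e-3 or below. A tighter loss goal or a smaller minimum loss decrease trades extra iterations for a more accurate minimum.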
A similar procedure is followed if the parameter we want to change belongs to the loss method:
// Configure normalized squared error loss
NormalizedSquaredError* normalized_squared_error = training_strategy.get_normalized_squared_error();
// Recompute the normalization coefficient from the data set targets
normalized_squared_error->set_normalization_coefficient();
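The normalized squared error divides the sum of squared errors by a normalization coefficient. Assuming the usual definition, in which the coefficient is the sum of squared deviations of the targets from their mean, both quantities can be computed as below. The helper names normalization_coefficient and normalized_squared_error are hypothetical, and this is not OpenNN's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Rows are samples, columns are target (or output) variables.
using Matrix = std::vector<std::vector<double>>;

// Sum over samples and variables of (target - column mean)^2.
double normalization_coefficient(const Matrix& targets) {
    assert(!targets.empty());
    const std::size_t rows = targets.size();
    const std::size_t cols = targets[0].size();

    std::vector<double> mean(cols, 0.0);
    for (const auto& row : targets)
        for (std::size_t j = 0; j < cols; ++j) mean[j] += row[j];
    for (double& m : mean) m /= static_cast<double>(rows);

    double coefficient = 0.0;
    for (const auto& row : targets)
        for (std::size_t j = 0; j < cols; ++j) {
            const double d = row[j] - mean[j];
            coefficient += d * d;
        }
    return coefficient;
}

// Sum of squared errors divided by the normalization coefficient.
double normalized_squared_error(const Matrix& outputs, const Matrix& targets, double coefficient) {
    assert(outputs.size() == targets.size());
    assert(coefficient > 0.0);
    double sum_squared_error = 0.0;
    for (std::size_t i = 0; i < outputs.size(); ++i)
        for (std::size_t j = 0; j < outputs[i].size(); ++j) {
            const double d = outputs[i][j] - targets[i][j];
            sum_squared_error += d * d;
        }
    return sum_squared_error / coefficient;
}
```

With this normalization, a model that always predicts the target mean scores 1 and a perfect model scores 0, which makes the value comparable across data sets whose targets have different scales.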
Lastly, to perform the training, we use the following command:
// Run training process
TrainingResults results = training_strategy.perform_training();
For more information on the TrainingStrategy class, visit the TrainingStrategy Class Reference.