Training strategy class

The TrainingStrategy class represents the concept of training strategy. As described in previous chapters, this class carries out the training of a neural network. It is composed of two abstract classes, LossIndex and OptimizationAlgorithm. LossIndex defines the type of error to be minimized, and OptimizationAlgorithm defines the algorithm that adjusts the parameters of the neural network to minimize that error. The choice of a suitable error and optimization algorithm depends on the particular application.

As in previous chapters, we will use the iris data set to show how to apply some of the main methods of the TrainingStrategy class. Therefore, it is advisable to read the NeuralNetwork class and the DataSet class sections before continuing.

The most common way to create a training strategy object is with the constructor that takes pointers to the NeuralNetwork and DataSet objects:

TrainingStrategy training_strategy(&neural_network,&data_set);

Once the object is created, the loss method and the training method are set by default to NORMALIZED_SQUARED_ERROR and QUASI_NEWTON_METHOD, respectively.

The TrainingStrategy class implements get and set methods for each of its members. Depending on the purpose of your project, you can select different errors or optimization algorithms.

In this particular example, we will use the QUASI_NEWTON_METHOD and the NORMALIZED_SQUARED_ERROR, which are set by default. If a different error or optimization algorithm is preferred, it can be selected with the corresponding set methods of TrainingStrategy.

The following statement returns a pointer to the QuasiNewtonMethod object, through which any parameter of the optimization algorithm can be modified:

QuasiNewtonMethod* quasi_Newton_method_pointer = training_strategy.get_quasi_Newton_method_pointer();

To change any parameter of the error, follow the same procedure as for the optimization algorithm and obtain a pointer to the NormalizedSquaredError object:

NormalizedSquaredError* normalized_squared_error_pointer = training_strategy.get_normalized_squared_error_pointer();

Finally, the training of the neural network is carried out by the perform_training method, which returns the results of the process:

TrainingResults results = training_strategy.perform_training();

For more information about the TrainingStrategy class, see the TrainingStrategy Class Reference.