Training strategy class

The TrainingStrategy class represents the training strategy concept. As described in previous chapters, this class carries out the training of a neural network. It is composed of two abstract classes: LossIndex and OptimizationAlgorithm. LossIndex refers to the type of error, and OptimizationAlgorithm refers to the optimization algorithm used to train the neural network. The choice of a suitable error and optimization algorithm depends on the particular application.

As in previous chapters, we will use the iris data set to show how to use some of the main methods of the TrainingStrategy class. Therefore, it is advisable to read the NeuralNetwork class and the DataSet class sections before continuing.

The most common way to create a training strategy object is with the constructor that takes pointers to the NeuralNetwork and DataSet objects:

TrainingStrategy training_strategy(&neural_network, &data_set);

Once the object is created, the loss_method and training_method members are set by default to NORMALIZED_SQUARED_ERROR and QUASI_NEWTON_METHOD, respectively.

The TrainingStrategy class implements get and set methods for each member. Depending on the purpose of your project, you can set different errors or optimization algorithms.
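
The composition and the defaults described above can be pictured with a small stand-alone sketch. It is not OpenNN code: every class name ending in Sketch, the set_loss and set_optimizer methods, and the two alternative choices (SUM_SQUARED_ERROR and GRADIENT_DESCENT) are hypothetical stand-ins chosen for illustration. It only shows the idea of a strategy that owns one loss object and one optimizer object behind abstract interfaces, starts with default choices, and lets either one be swapped.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical stand-in for an abstract loss interface (not OpenNN's LossIndex).
class LossIndexSketch {
public:
    virtual ~LossIndexSketch() = default;
    virtual std::string name() const = 0;
};

class NormalizedSquaredErrorSketch : public LossIndexSketch {
public:
    std::string name() const override { return "NORMALIZED_SQUARED_ERROR"; }
};

// Illustrative alternative loss, assumed only for this sketch.
class SumSquaredErrorSketch : public LossIndexSketch {
public:
    std::string name() const override { return "SUM_SQUARED_ERROR"; }
};

// Hypothetical stand-in for an abstract optimizer interface.
class OptimizationAlgorithmSketch {
public:
    virtual ~OptimizationAlgorithmSketch() = default;
    virtual std::string name() const = 0;
};

class QuasiNewtonMethodSketch : public OptimizationAlgorithmSketch {
public:
    std::string name() const override { return "QUASI_NEWTON_METHOD"; }
};

// Illustrative alternative optimizer, assumed only for this sketch.
class GradientDescentSketch : public OptimizationAlgorithmSketch {
public:
    std::string name() const override { return "GRADIENT_DESCENT"; }
};

// The strategy owns one loss and one optimizer and starts with default choices.
class TrainingStrategySketch {
public:
    TrainingStrategySketch()
        : loss_(std::make_unique<NormalizedSquaredErrorSketch>()),
          optimizer_(std::make_unique<QuasiNewtonMethodSketch>()) {}

    void set_loss(std::unique_ptr<LossIndexSketch> loss) {
        assert(loss != nullptr);
        loss_ = std::move(loss);
    }

    void set_optimizer(std::unique_ptr<OptimizationAlgorithmSketch> optimizer) {
        assert(optimizer != nullptr);
        optimizer_ = std::move(optimizer);
    }

    std::string get_loss_name() const { return loss_->name(); }
    std::string get_optimizer_name() const { return optimizer_->name(); }

private:
    std::unique_ptr<LossIndexSketch> loss_;
    std::unique_ptr<OptimizationAlgorithmSketch> optimizer_;
};
```

In this sketch, a freshly constructed TrainingStrategySketch reports the two default names, and a call such as set_loss(std::make_unique<SumSquaredErrorSketch>()) replaces only the loss while leaving the optimizer untouched.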

In this particular example, we use the QUASI_NEWTON_METHOD and the NORMALIZED_SQUARED_ERROR, which are set by default. If other methods are preferred, they can be changed through the corresponding set methods.

The following statement returns a pointer to the optimization algorithm, through which any of its parameters can be modified:

QuasiNewtonMethod* quasi_Newton_method_pointer = training_strategy.get_quasi_Newton_method_pointer();

To change any parameter of the loss method, follow the same procedure as for the optimization algorithm:

NormalizedSquaredError* normalized_squared_error_pointer = training_strategy.get_normalized_squared_error_pointer();
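
The pointer-based access used in the two statements above can be illustrated with a stand-alone sketch. The classes QuasiNewtonSketch and StrategySketch, the parameter maximum_iterations, its default of 1000, and the helper configure_and_query are all hypothetical and do not correspond to OpenNN members. The point is only that the getter hands back a non-owning pointer to an object stored inside the strategy, so a change made through that pointer is seen by the strategy afterwards.

```cpp
#include <cassert>

// Hypothetical optimizer with a single tunable parameter.
class QuasiNewtonSketch {
public:
    void set_maximum_iterations(int maximum_iterations) {
        assert(maximum_iterations > 0);
        maximum_iterations_ = maximum_iterations;
    }

    int get_maximum_iterations() const { return maximum_iterations_; }

private:
    int maximum_iterations_ = 1000;  // arbitrary default chosen for this sketch
};

// Hypothetical strategy that stores the optimizer as a member.
class StrategySketch {
public:
    // Returns a non-owning pointer to the optimizer held by this strategy.
    QuasiNewtonSketch* get_quasi_newton_pointer() { return &quasi_newton_; }

    // Stand-in for training: reports how many iterations would be allowed.
    int iterations_allowed() const {
        return quasi_newton_.get_maximum_iterations();
    }

private:
    QuasiNewtonSketch quasi_newton_;
};

// Changes the parameter through the pointer, then reads it back via the strategy.
inline int configure_and_query(StrategySketch& strategy, int maximum_iterations) {
    QuasiNewtonSketch* optimizer = strategy.get_quasi_newton_pointer();
    optimizer->set_maximum_iterations(maximum_iterations);
    return strategy.iterations_allowed();
}
```

Because the pointer refers to the strategy's own member rather than a copy, no second call is needed to hand the modified optimizer back to the strategy.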

Lastly, to train the network, call the following method:

OptimizationAlgorithm::Results results = training_strategy.perform_training();

For more information about the TrainingStrategy class, see the TrainingStrategy Class Reference.
