Training strategy class

The TrainingStrategy class represents the concept of a training strategy. As mentioned in previous chapters, this class carries out the training of a neural network. It is composed of two abstract classes: LossIndex and TrainingAlgorithm. LossIndex refers to the type of error, and TrainingAlgorithm refers to the type of training. The choice of a suitable error and training algorithm depends on the particular application.
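The composition described above can be illustrated with a small, self-contained sketch. Everything in it is hypothetical: the Toy* classes are simplified stand-ins written for this example, not the library's LossIndex, TrainingAlgorithm, or TrainingStrategy, and the model being fitted is a one-parameter line y = w * x rather than a neural network. It only shows how a strategy object can delegate "which error" and "which algorithm" to two abstract interfaces.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for an abstract error term (not the library class).
struct ToyLossIndex {
    virtual ~ToyLossIndex() = default;
    virtual double error(const std::vector<double>& outputs,
                         const std::vector<double>& targets) const = 0;
};

// One concrete error: the plain sum of squared differences.
struct ToySumSquaredError : ToyLossIndex {
    double error(const std::vector<double>& outputs,
                 const std::vector<double>& targets) const override {
        assert(outputs.size() == targets.size());
        double sum = 0.0;
        for (std::size_t i = 0; i < outputs.size(); ++i) {
            const double difference = outputs[i] - targets[i];
            sum += difference * difference;
        }
        return sum;
    }
};

// Hypothetical stand-in for an abstract training algorithm.
struct ToyTrainingAlgorithm {
    virtual ~ToyTrainingAlgorithm() = default;
    // Returns the parameter value the algorithm ends at.
    virtual double minimize(const std::function<double(double)>& loss,
                            double parameter) const = 0;
};

// One concrete algorithm: fixed-step gradient descent with a
// central-difference estimate of the derivative.
struct ToyGradientDescent : ToyTrainingAlgorithm {
    double rate = 0.01;
    int iterations = 200;

    double minimize(const std::function<double(double)>& loss,
                    double parameter) const override {
        const double h = 1.0e-6;
        for (int i = 0; i < iterations; ++i) {
            const double gradient =
                (loss(parameter + h) - loss(parameter - h)) / (2.0 * h);
            parameter -= rate * gradient;
        }
        return parameter;
    }
};

// The strategy owns one error term and one algorithm and combines them.
class ToyTrainingStrategy {
public:
    ToyTrainingStrategy(std::unique_ptr<ToyLossIndex> loss_index,
                        std::unique_ptr<ToyTrainingAlgorithm> algorithm)
        : loss_index_(std::move(loss_index)),
          algorithm_(std::move(algorithm)) {}

    // Fits the toy model y = w * x and returns w.
    double fit(const std::vector<double>& inputs,
               const std::vector<double>& targets) const {
        assert(inputs.size() == targets.size());
        const auto loss = [&](double w) {
            std::vector<double> outputs(inputs.size());
            for (std::size_t i = 0; i < inputs.size(); ++i) {
                outputs[i] = w * inputs[i];
            }
            return loss_index_->error(outputs, targets);
        };
        return algorithm_->minimize(loss, 0.0);
    }

private:
    std::unique_ptr<ToyLossIndex> loss_index_;
    std::unique_ptr<ToyTrainingAlgorithm> algorithm_;
};
```

Swapping the error or the algorithm in this sketch means passing a different derived object to the constructor; the fit method itself does not change.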

This tutorial uses the iris data set to show how to use some of the main methods of the TrainingStrategy class. Before continuing, it is advisable to read the previous chapter, NeuralNetwork class.

The most common way to create a training strategy object is with the constructor that takes pointers to the NeuralNetwork and DataSet objects:

TrainingStrategy training_strategy(&neural_network, &data_set);

Once the object is created, the loss_method and training_method members are set by default to NORMALIZED_SQUARED_ERROR and QUASI_NEWTON_METHOD, respectively.

The TrainingStrategy class implements get and set methods for each member. Depending on the purpose of your project, you can set different errors or training algorithms.

This example uses QUASI_NEWTON_METHOD as the training algorithm. There is no need to set it, because it is the default. However, if you want to try another training algorithm, you can change it by passing a different value to the following method:

training_strategy.set_training_method(TrainingStrategy::QUASI_NEWTON_METHOD);

To modify a parameter of the training algorithm, get a pointer to the algorithm object and then use its set methods. Some of the most useful ones are:

QuasiNewtonMethod* quasi_Newton_method_pointer = training_strategy.get_quasi_Newton_method_pointer();

quasi_Newton_method_pointer->set_minimum_loss_increase(1.0e-6);
quasi_Newton_method_pointer->set_loss_goal(1.0e-3);
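To make the role of these two values concrete, here is a minimal, self-contained sketch of a training loop with two comparable stopping criteria. It rests on assumptions: that the minimum loss increase is a lower bound on the improvement in loss between two successive iterations, and that the loss goal is a target loss value at which training stops. The class reference is the authority on the exact semantics. All names (ToyStoppingCriteria, toy_train, and so on) are hypothetical, and plain gradient descent on a one-parameter function stands in for the quasi-Newton method.

```cpp
#include <cassert>
#include <functional>

// Hypothetical names throughout; this is not the library's implementation.
enum class ToyStopReason { LossGoal, MinimumLossIncrease, MaximumIterations };

struct ToyStoppingCriteria {
    double minimum_loss_increase = 1.0e-6;  // assumed: smallest useful improvement per iteration
    double loss_goal = 1.0e-3;              // assumed: stop once the loss is this low
    int maximum_iterations = 1000;          // safety cap added for this sketch
};

struct ToyResult {
    double parameter;
    double loss;
    int iterations;
    ToyStopReason reason;
};

// Gradient descent on a single parameter, checking the criteria after each step.
ToyResult toy_train(const std::function<double(double)>& loss,
                    const std::function<double(double)>& gradient,
                    double parameter,
                    double rate,
                    const ToyStoppingCriteria& criteria) {
    assert(rate > 0.0 && criteria.maximum_iterations > 0);

    double old_loss = loss(parameter);
    for (int iteration = 1; iteration <= criteria.maximum_iterations; ++iteration) {
        parameter -= rate * gradient(parameter);
        const double new_loss = loss(parameter);

        // Criterion 1: the loss has reached the goal.
        if (new_loss <= criteria.loss_goal) {
            return {parameter, new_loss, iteration, ToyStopReason::LossGoal};
        }
        // Criterion 2: the last step improved the loss by too little.
        if (old_loss - new_loss < criteria.minimum_loss_increase) {
            return {parameter, new_loss, iteration, ToyStopReason::MinimumLossIncrease};
        }
        old_loss = new_loss;
    }
    return {parameter, old_loss, criteria.maximum_iterations,
            ToyStopReason::MaximumIterations};
}
```

With a loss whose minimum is zero, such as (w - 3)^2, the loop ends on the loss goal. With a loss whose minimum lies above the goal, such as (w - 3)^2 + 1, it ends when the improvement per step becomes too small.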

To change the error method and modify its parameters, follow the same procedure as for the training algorithm:

training_strategy.set_loss_method(TrainingStrategy::NORMALIZED_SQUARED_ERROR);
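For readers unfamiliar with the normalized squared error, the following self-contained sketch shows one common way to compute it. It assumes the usual definition, in which the sum of squared differences between outputs and targets is divided by the sum of squared deviations of the targets from their mean. The function name toy_normalized_squared_error is hypothetical, and the library's own computation may differ in detail.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper, not a library function.
// Assumed definition: sum((output - target)^2) / sum((target - mean(target))^2).
double toy_normalized_squared_error(const std::vector<double>& outputs,
                                    const std::vector<double>& targets) {
    assert(!targets.empty() && outputs.size() == targets.size());

    // Mean of the targets, used for the normalization coefficient.
    double mean = 0.0;
    for (const double target : targets) {
        mean += target;
    }
    mean /= static_cast<double>(targets.size());

    double squared_error = 0.0;
    double normalization = 0.0;
    for (std::size_t i = 0; i < targets.size(); ++i) {
        const double difference = outputs[i] - targets[i];
        const double deviation = targets[i] - mean;
        squared_error += difference * difference;
        normalization += deviation * deviation;
    }

    // The ratio is undefined when every target has the same value.
    assert(normalization > 0.0);
    return squared_error / normalization;
}
```

Under this definition, a value of 0 means the outputs match the targets exactly, and a value of 1 means the model does no better than always predicting the mean of the targets.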

Lastly, to perform the training of the network, use the method below:

TrainingStrategy::Results results = training_strategy.perform_training();

For more information about the TrainingStrategy class, see the TrainingStrategy Class Reference.