TrainingStrategy Class Reference

This class represents the concept of training strategy for a neural network in OpenNN.

#include <training_strategy.h>

Public Types

enum class  LossMethod {
  SUM_SQUARED_ERROR , MEAN_SQUARED_ERROR , NORMALIZED_SQUARED_ERROR , MINKOWSKI_ERROR ,
  WEIGHTED_SQUARED_ERROR , CROSS_ENTROPY_ERROR
}
 Enumeration of available error terms in OpenNN.
 
enum class  OptimizationMethod {
  GRADIENT_DESCENT , CONJUGATE_GRADIENT , QUASI_NEWTON_METHOD , LEVENBERG_MARQUARDT_ALGORITHM ,
  STOCHASTIC_GRADIENT_DESCENT , ADAPTIVE_MOMENT_ESTIMATION
}
 Enumeration of all the available types of optimization algorithms.
 

Public Member Functions

 TrainingStrategy ()
 
 TrainingStrategy (NeuralNetwork *, DataSet *)
 
virtual ~TrainingStrategy ()
 
DataSet * get_data_set_pointer ()
 Returns a pointer to the DataSet class.
 
NeuralNetwork * get_neural_network_pointer () const
 Returns a pointer to the NeuralNetwork class.
 
LossIndex * get_loss_index_pointer ()
 Returns a pointer to the LossIndex class.
 
OptimizationAlgorithm * get_optimization_algorithm_pointer ()
 Returns a pointer to the OptimizationAlgorithm class.
 
bool has_neural_network () const
 
bool has_data_set () const
 
SumSquaredError * get_sum_squared_error_pointer ()
 
MeanSquaredError * get_mean_squared_error_pointer ()
 
NormalizedSquaredError * get_normalized_squared_error_pointer ()
 
MinkowskiError * get_Minkowski_error_pointer ()
 
CrossEntropyError * get_cross_entropy_error_pointer ()
 
WeightedSquaredError * get_weighted_squared_error_pointer ()
 
GradientDescent * get_gradient_descent_pointer ()
 
ConjugateGradient * get_conjugate_gradient_pointer ()
 
QuasiNewtonMethod * get_quasi_Newton_method_pointer ()
 
LevenbergMarquardtAlgorithm * get_Levenberg_Marquardt_algorithm_pointer ()
 
StochasticGradientDescent * get_stochastic_gradient_descent_pointer ()
 
AdaptiveMomentEstimation * get_adaptive_moment_estimation_pointer ()
 
const LossMethod & get_loss_method () const
 Returns the type of the main loss algorithm composing this training strategy object.
 
const OptimizationMethod & get_optimization_method () const
 Returns the type of the main optimization algorithm composing this training strategy object.
 
string write_loss_method () const
 Returns a string with the type of the main loss algorithm composing this training strategy object.
 
string write_optimization_method () const
 
string write_optimization_method_text () const
 
string write_loss_method_text () const
 Returns a string with the main loss method type in text format.
 
const bool & get_display () const
 
void set ()
 
void set (NeuralNetwork *, DataSet *)
 
void set_default ()
 
void set_threads_number (const int &)
 
void set_data_set_pointer (DataSet *)
 
void set_neural_network_pointer (NeuralNetwork *)
 
void set_loss_index_threads_number (const int &)
 
void set_optimization_algorithm_threads_number (const int &)
 
void set_loss_index_pointer (LossIndex *)
 
void set_loss_index_data_set_pointer (DataSet *)
 
void set_loss_index_neural_network_pointer (NeuralNetwork *)
 
void set_loss_method (const LossMethod &)
 
void set_optimization_method (const OptimizationMethod &)
 
void set_loss_method (const string &)
 
void set_optimization_method (const string &)
 
void set_display (const bool &)
 
void set_loss_goal (const type &)
 
void set_maximum_selection_failures (const Index &)
 
void set_maximum_epochs_number (const int &)
 
void set_display_period (const int &)
 
void set_maximum_time (const type &)
 
TrainingResults perform_training ()
 
void fix_forecasting ()
 
void print () const
 Prints to the screen the string representation of the training strategy object.
 
void from_XML (const tinyxml2::XMLDocument &)
 
void write_XML (tinyxml2::XMLPrinter &) const
 
void save (const string &) const
 
void load (const string &)
 

Private Attributes

DataSet * data_set_pointer = nullptr
 
NeuralNetwork * neural_network_pointer = nullptr
 
SumSquaredError sum_squared_error
 Sum squared error object, which can be used as the error term.
 
MeanSquaredError mean_squared_error
 Mean squared error object, which can be used as the error term.
 
NormalizedSquaredError normalized_squared_error
 Normalized squared error object, which can be used as the error term.
 
MinkowskiError Minkowski_error
 Minkowski error object, which can be used as the error term.
 
CrossEntropyError cross_entropy_error
 Cross-entropy error object, which can be used as the error term.
 
WeightedSquaredError weighted_squared_error
 Weighted squared error object, which can be used as the error term.
 
LossMethod loss_method
 Type of loss method.
 
GradientDescent gradient_descent
 Gradient descent object to be used as a main optimization algorithm.
 
ConjugateGradient conjugate_gradient
 Conjugate gradient object to be used as a main optimization algorithm.
 
QuasiNewtonMethod quasi_Newton_method
 Quasi-Newton method object to be used as a main optimization algorithm.
 
LevenbergMarquardtAlgorithm Levenberg_Marquardt_algorithm
 Levenberg-Marquardt algorithm object to be used as a main optimization algorithm.
 
StochasticGradientDescent stochastic_gradient_descent
 Stochastic gradient descent algorithm object to be used as a main optimization algorithm.
 
AdaptiveMomentEstimation adaptive_moment_estimation
 Adaptive moment estimation algorithm object to be used as a main optimization algorithm.
 
OptimizationMethod optimization_method
 Type of main optimization algorithm.
 
bool display = true
 Display messages to screen.
 

Detailed Description

This class represents the concept of training strategy for a neural network in OpenNN.

A training strategy is composed of two objects:

  • Loss index.
  • Optimization algorithm.


Definition at line 55 of file training_strategy.h.
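The composition described above can be pictured with the following minimal C++ sketch. It assumes, based only on the private attributes listed on this page, that every candidate loss object is held by value and that the loss_method enumerator selects which one get_loss_index_pointer() returns. The classes LossIndex, SumSquaredError, MeanSquaredError and NormalizedSquaredError below are simplified stand-ins, and TrainingStrategySketch and name() are hypothetical; this is not OpenNN's implementation.

```cpp
#include <cassert>
#include <string>

// Simplified, hypothetical stand-ins for OpenNN's loss classes.
struct LossIndex {
    virtual ~LossIndex() = default;
    virtual std::string name() const = 0;
};

struct SumSquaredError : LossIndex {
    std::string name() const override { return "SUM_SQUARED_ERROR"; }
};

struct MeanSquaredError : LossIndex {
    std::string name() const override { return "MEAN_SQUARED_ERROR"; }
};

struct NormalizedSquaredError : LossIndex {
    std::string name() const override { return "NORMALIZED_SQUARED_ERROR"; }
};

// Reduced copy of the LossMethod enumeration.
enum class LossMethod { SUM_SQUARED_ERROR, MEAN_SQUARED_ERROR, NORMALIZED_SQUARED_ERROR };

// Hypothetical class: all candidate losses live inside the strategy by value,
// and an enumerator selects the active one.
class TrainingStrategySketch {
public:
    void set_loss_method(const LossMethod& new_loss_method) { loss_method = new_loss_method; }

    // Returns a base-class pointer to the currently selected loss object.
    LossIndex* get_loss_index_pointer() {
        switch (loss_method) {
            case LossMethod::SUM_SQUARED_ERROR: return &sum_squared_error;
            case LossMethod::MEAN_SQUARED_ERROR: return &mean_squared_error;
            case LossMethod::NORMALIZED_SQUARED_ERROR: return &normalized_squared_error;
        }
        assert(false && "unhandled LossMethod");
        return nullptr;
    }

private:
    SumSquaredError sum_squared_error;
    MeanSquaredError mean_squared_error;
    NormalizedSquaredError normalized_squared_error;
    LossMethod loss_method = LossMethod::NORMALIZED_SQUARED_ERROR;  // arbitrary default for this sketch
};
```

Because the loss objects are members rather than heap allocations in this sketch, switching the loss method never allocates, and the returned pointer stays valid for as long as the strategy object lives.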

Member Enumeration Documentation

◆ LossMethod

enum class LossMethod
strong

Enumeration of available error terms in OpenNN.

Definition at line 74 of file training_strategy.h.

◆ OptimizationMethod

enum class OptimizationMethod
strong

Enumeration of all the available types of optimization algorithms.

Definition at line 86 of file training_strategy.h.

Constructor & Destructor Documentation

◆ TrainingStrategy() [1/2]

TrainingStrategy ( )
explicit

Default constructor. It creates a training strategy object not associated with any loss index object. It also constructs the main optimization algorithm object.

Definition at line 19 of file training_strategy.cpp.

◆ TrainingStrategy() [2/2]

TrainingStrategy ( NeuralNetwork * new_neural_network_pointer,
DataSet * new_data_set_pointer 
)
explicit

Pointer constructor. It creates a training strategy object associated with the given neural network and data set, and initializes the members of this object from the NeuralNetwork and DataSet objects.

Definition at line 42 of file training_strategy.cpp.

◆ ~TrainingStrategy()

~TrainingStrategy ( )
virtual

Destructor. The loss index and optimization algorithm objects are held by value, so they are destroyed together with this object.

Definition at line 61 of file training_strategy.cpp.

Member Function Documentation

◆ fix_forecasting()

void fix_forecasting ( )

Checks the number of time steps and the batch size in forecasting problems. The batch size must be a multiple of the number of time steps. If it is not, the batch size is reduced to the largest multiple of the number of time steps that is lower than it.

Definition at line 791 of file training_strategy.cpp.
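The batch-size adjustment described for fix_forecasting() amounts to rounding down to a multiple of the number of time steps. The free function adjusted_batch_size below is a hypothetical helper that isolates that arithmetic; it does not reproduce how OpenNN reads the time steps or stores the batch size.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: returns batch_size unchanged if it is already a multiple
// of time_steps, otherwise the largest multiple of time_steps below batch_size.
std::size_t adjusted_batch_size(std::size_t batch_size, std::size_t time_steps) {
    assert(time_steps > 0);
    // Subtracting the remainder rounds down; a remainder of 0 leaves the value as is.
    // Note: this yields 0 when batch_size < time_steps, a case the description
    // above does not address.
    return batch_size - batch_size % time_steps;
}
```

For example, with 7 time steps a batch size of 100 becomes 98, while a batch size of 98 is left unchanged.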

◆ from_XML()

void from_XML ( const tinyxml2::XMLDocument & document)

Loads the members of this training strategy object from an XML document.

Parameters
document: XML document of the TinyXML library.

Definition at line 914 of file training_strategy.cpp.

◆ get_adaptive_moment_estimation_pointer()

AdaptiveMomentEstimation * get_adaptive_moment_estimation_pointer ( )

Returns a pointer to the adaptive moment estimation main algorithm. It also throws an exception if that pointer is nullptr.

Definition at line 192 of file training_strategy.cpp.

◆ get_conjugate_gradient_pointer()

ConjugateGradient * get_conjugate_gradient_pointer ( )

Returns a pointer to the conjugate gradient main algorithm. It also throws an exception if that pointer is nullptr.

Definition at line 156 of file training_strategy.cpp.

◆ get_cross_entropy_error_pointer()

CrossEntropyError * get_cross_entropy_error_pointer ( )

Returns a pointer to the cross-entropy error object, which can be used as the error term. If that object does not exist, an exception is thrown.

Definition at line 240 of file training_strategy.cpp.

◆ get_data_set_pointer()

DataSet * get_data_set_pointer ( )

Returns a pointer to the DataSet class.

Definition at line 68 of file training_strategy.cpp.

◆ get_display()

const bool & get_display ( ) const

Returns true if messages from this class are displayed on the screen, and false otherwise.

Definition at line 416 of file training_strategy.cpp.

◆ get_gradient_descent_pointer()

GradientDescent * get_gradient_descent_pointer ( )

Returns a pointer to the gradient descent main algorithm. It also throws an exception if that pointer is nullptr.

Definition at line 147 of file training_strategy.cpp.

◆ get_Levenberg_Marquardt_algorithm_pointer()

LevenbergMarquardtAlgorithm * get_Levenberg_Marquardt_algorithm_pointer ( )

Returns a pointer to the Levenberg-Marquardt main algorithm. It also throws an exception if that pointer is nullptr.

Definition at line 174 of file training_strategy.cpp.

◆ get_loss_index_pointer()

LossIndex * get_loss_index_pointer ( )

Returns a pointer to the LossIndex class.

Definition at line 84 of file training_strategy.cpp.

◆ get_loss_method()

const TrainingStrategy::LossMethod & get_loss_method ( ) const

Returns the type of the main loss algorithm composing this training strategy object.

Definition at line 257 of file training_strategy.cpp.

◆ get_mean_squared_error_pointer()

MeanSquaredError * get_mean_squared_error_pointer ( )

Returns a pointer to the mean squared error object, which can be used as the error term. If that object does not exist, an exception is thrown.

Definition at line 210 of file training_strategy.cpp.

◆ get_Minkowski_error_pointer()

MinkowskiError * get_Minkowski_error_pointer ( )

Returns a pointer to the Minkowski error object, which can be used as the error term. If that object does not exist, an exception is thrown.

Definition at line 230 of file training_strategy.cpp.

◆ get_neural_network_pointer()

NeuralNetwork * get_neural_network_pointer ( ) const

Returns a pointer to the NeuralNetwork class.

Definition at line 76 of file training_strategy.cpp.

◆ get_normalized_squared_error_pointer()

NormalizedSquaredError * get_normalized_squared_error_pointer ( )

Returns a pointer to the normalized squared error object, which can be used as the error term. If that object does not exist, an exception is thrown.

Definition at line 219 of file training_strategy.cpp.

◆ get_optimization_algorithm_pointer()

OptimizationAlgorithm * get_optimization_algorithm_pointer ( )

Returns a pointer to the OptimizationAlgorithm class.

Definition at line 107 of file training_strategy.cpp.

◆ get_optimization_method()

const TrainingStrategy::OptimizationMethod & get_optimization_method ( ) const

Returns the type of the main optimization algorithm composing this training strategy object.

Definition at line 265 of file training_strategy.cpp.

◆ get_quasi_Newton_method_pointer()

QuasiNewtonMethod * get_quasi_Newton_method_pointer ( )

Returns a pointer to the quasi-Newton method main algorithm. It also throws an exception if that pointer is nullptr.

Definition at line 165 of file training_strategy.cpp.

◆ get_stochastic_gradient_descent_pointer()

StochasticGradientDescent * get_stochastic_gradient_descent_pointer ( )

Returns a pointer to the stochastic gradient descent main algorithm. It also throws an exception if that pointer is nullptr.

Definition at line 183 of file training_strategy.cpp.

◆ get_sum_squared_error_pointer()

SumSquaredError * get_sum_squared_error_pointer ( )

Returns a pointer to the sum squared error object, which can be used as the error term. If that object does not exist, an exception is thrown.

Definition at line 201 of file training_strategy.cpp.

◆ get_weighted_squared_error_pointer()

WeightedSquaredError * get_weighted_squared_error_pointer ( )

Returns a pointer to the weighted squared error object, which can be used as the error term. If that object does not exist, an exception is thrown.

Definition at line 249 of file training_strategy.cpp.

◆ has_data_set()

bool has_data_set ( ) const

Definition at line 136 of file training_strategy.cpp.

◆ has_neural_network()

bool has_neural_network ( ) const

Definition at line 128 of file training_strategy.cpp.

◆ load()

void load ( const string & file_name)

Loads the members of this training strategy object from an XML file. The file format is specified in the User's Guide.

Parameters
file_name: Name of the training strategy XML file.

Definition at line 1205 of file training_strategy.cpp.

◆ perform_training()

TrainingResults perform_training ( )

This is the most important method of this class. It optimizes the loss index of the neural network using the selected optimization algorithm, and returns a TrainingResults structure with the results of the training.

Definition at line 719 of file training_strategy.cpp.
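To illustrate what optimizing the loss index involves, the sketch below trains a one-input linear model y = w·x + b by plain gradient descent on the mean squared error, and stops on criteria loosely modeled on the setters listed on this page (set_loss_goal, set_maximum_epochs_number, set_maximum_selection_failures, set_maximum_time). Everything in it (StoppingCriteria, TrainingResultsSketch, perform_training_sketch, and the rule for counting selection failures) is hypothetical; OpenNN's optimization algorithms compute gradients through the NeuralNetwork and LossIndex classes and record richer results in TrainingResults.

```cpp
#include <cassert>
#include <chrono>
#include <cstddef>
#include <vector>

// Hypothetical stopping criteria, loosely mirroring the TrainingStrategy setters.
struct StoppingCriteria {
    double loss_goal = 1.0e-6;
    int maximum_epochs_number = 1000;
    int maximum_selection_failures = 100;
    double maximum_time = 3600.0;  // seconds
};

// Hypothetical, much-reduced counterpart of TrainingResults.
struct TrainingResultsSketch {
    double weight = 0.0;
    double bias = 0.0;
    double training_error = 0.0;
    int epochs = 0;
};

// Mean squared error of the model y = weight * x + bias.
double mean_squared_error(const std::vector<double>& x, const std::vector<double>& y,
                          double weight, double bias) {
    double sum = 0.0;
    for (std::size_t i = 0; i < x.size(); ++i) {
        const double error = weight * x[i] + bias - y[i];
        sum += error * error;
    }
    return sum / static_cast<double>(x.size());
}

TrainingResultsSketch perform_training_sketch(const std::vector<double>& x_train,
                                              const std::vector<double>& y_train,
                                              const std::vector<double>& x_selection,
                                              const std::vector<double>& y_selection,
                                              const StoppingCriteria& criteria,
                                              double learning_rate) {
    assert(!x_train.empty() && x_train.size() == y_train.size());
    assert(!x_selection.empty() && x_selection.size() == y_selection.size());

    TrainingResultsSketch results;
    results.training_error = mean_squared_error(x_train, y_train, results.weight, results.bias);
    double best_selection_error = mean_squared_error(x_selection, y_selection, results.weight, results.bias);
    int selection_failures = 0;
    const auto start = std::chrono::steady_clock::now();
    const double n = static_cast<double>(x_train.size());

    while (results.epochs < criteria.maximum_epochs_number) {
        ++results.epochs;

        // Gradient of the mean squared error with respect to weight and bias.
        double weight_gradient = 0.0;
        double bias_gradient = 0.0;
        for (std::size_t i = 0; i < x_train.size(); ++i) {
            const double error = results.weight * x_train[i] + results.bias - y_train[i];
            weight_gradient += 2.0 * error * x_train[i] / n;
            bias_gradient += 2.0 * error / n;
        }
        results.weight -= learning_rate * weight_gradient;
        results.bias -= learning_rate * bias_gradient;
        results.training_error = mean_squared_error(x_train, y_train, results.weight, results.bias);

        // Stop when the loss goal is reached.
        if (results.training_error <= criteria.loss_goal) break;

        // Count epochs in which the selection error does not improve on its best value.
        const double selection_error = mean_squared_error(x_selection, y_selection, results.weight, results.bias);
        if (selection_error < best_selection_error) {
            best_selection_error = selection_error;
        } else if (++selection_failures >= criteria.maximum_selection_failures) {
            break;
        }

        // Stop when the time budget is exhausted.
        const std::chrono::duration<double> elapsed = std::chrono::steady_clock::now() - start;
        if (elapsed.count() >= criteria.maximum_time) break;
    }
    return results;
}
```

With exact data on y = 2x + 1 and a learning rate of 0.05, this loop reaches a loss goal of 1e-6 with the weight close to 2; lowering maximum_epochs_number instead makes it stop on the epoch limit.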

◆ print()

void print ( ) const

Prints to the screen the string representation of the training strategy object.

Definition at line 845 of file training_strategy.cpp.

◆ save()

void save ( const string &  file_name) const

Saves the members of this training strategy object to an XML file.

Parameters
file_name: Name of the training strategy XML file.

Definition at line 1189 of file training_strategy.cpp.

◆ set() [1/2]

void set ( )

Sets the loss index pointer to nullptr. It also destroys the loss index and the optimization algorithm. Finally, it sets the rest of the members to their default values.

Definition at line 426 of file training_strategy.cpp.

◆ set() [2/2]

void set ( NeuralNetwork * new_neural_network_pointer,
DataSet * new_data_set_pointer 
)

Definition at line 434 of file training_strategy.cpp.

◆ set_data_set_pointer()

void set_data_set_pointer ( DataSet * new_data_set_pointer)

Definition at line 556 of file training_strategy.cpp.

◆ set_default()

void set_default ( )

Sets the members of the training strategy object to their default values:

  • Display: true.

Definition at line 710 of file training_strategy.cpp.

◆ set_display()

void set_display ( const bool &  new_display)

Sets a new display value. If it is set to true, messages from this class are displayed on the screen; if it is set to false, they are not.

Parameters
new_display: Display value.

Definition at line 635 of file training_strategy.cpp.

◆ set_display_period()

void set_display_period ( const int &  display_period)

Definition at line 688 of file training_strategy.cpp.

◆ set_loss_goal()

void set_loss_goal ( const type &  new_loss_goal)

Definition at line 659 of file training_strategy.cpp.

◆ set_loss_index_data_set_pointer()

void set_loss_index_data_set_pointer ( DataSet * new_data_set_pointer)

Definition at line 608 of file training_strategy.cpp.

◆ set_loss_index_neural_network_pointer()

void set_loss_index_neural_network_pointer ( NeuralNetwork * new_neural_network_pointer)

Definition at line 619 of file training_strategy.cpp.

◆ set_loss_index_pointer()

void set_loss_index_pointer ( LossIndex * new_loss_index_pointer)

Sets a pointer to a loss index object to be associated with the training strategy.

Parameters
new_loss_index_pointer: Pointer to a loss index object.

Definition at line 597 of file training_strategy.cpp.

◆ set_loss_index_threads_number()

void set_loss_index_threads_number ( const int &  new_threads_number)

Definition at line 572 of file training_strategy.cpp.

◆ set_loss_method() [1/2]

void set_loss_method ( const LossMethod & new_loss_method)

Sets the loss index method. If that method does not exist, an exception is thrown.

Parameters
new_loss_method: New method type.

Definition at line 489 of file training_strategy.cpp.

◆ set_loss_method() [2/2]

void set_loss_method ( const string & new_loss_method)

Sets the loss index method. If that method does not exist, an exception is thrown.

Parameters
new_loss_method: String with the name of the new method.

Definition at line 446 of file training_strategy.cpp.
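The string overload can be read as a lookup from a name to a LossMethod enumerator, with an exception for unknown names. The sketch below assumes the accepted strings are the enumerator names themselves (for example "MEAN_SQUARED_ERROR"); that spelling, and the helper name loss_method_from_string, are assumptions rather than documented behavior.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Copy of the LossMethod enumerators listed on this page.
enum class LossMethod {
    SUM_SQUARED_ERROR, MEAN_SQUARED_ERROR, NORMALIZED_SQUARED_ERROR, MINKOWSKI_ERROR,
    WEIGHTED_SQUARED_ERROR, CROSS_ENTROPY_ERROR
};

// Hypothetical helper: maps a name to its enumerator, throwing for unknown names.
LossMethod loss_method_from_string(const std::string& name) {
    if (name == "SUM_SQUARED_ERROR") return LossMethod::SUM_SQUARED_ERROR;
    if (name == "MEAN_SQUARED_ERROR") return LossMethod::MEAN_SQUARED_ERROR;
    if (name == "NORMALIZED_SQUARED_ERROR") return LossMethod::NORMALIZED_SQUARED_ERROR;
    if (name == "MINKOWSKI_ERROR") return LossMethod::MINKOWSKI_ERROR;
    if (name == "WEIGHTED_SQUARED_ERROR") return LossMethod::WEIGHTED_SQUARED_ERROR;
    if (name == "CROSS_ENTROPY_ERROR") return LossMethod::CROSS_ENTROPY_ERROR;
    throw std::invalid_argument("Unknown loss method: " + name);
}
```

Throwing std::invalid_argument keeps a misspelled name from silently falling back to some default loss.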

◆ set_maximum_epochs_number()

void set_maximum_epochs_number ( const int &  maximum_epochs_number)

Definition at line 677 of file training_strategy.cpp.

◆ set_maximum_selection_failures()

void set_maximum_selection_failures ( const Index &  maximum_selection_failures)

Definition at line 668 of file training_strategy.cpp.

◆ set_maximum_time()

void set_maximum_time ( const type &  maximum_time)

Definition at line 694 of file training_strategy.cpp.

◆ set_neural_network_pointer()

void set_neural_network_pointer ( NeuralNetwork * new_neural_network_pointer)

Definition at line 564 of file training_strategy.cpp.

◆ set_optimization_algorithm_threads_number()

void set_optimization_algorithm_threads_number ( const int &  new_threads_number)

Definition at line 583 of file training_strategy.cpp.

◆ set_optimization_method() [1/2]

void set_optimization_method ( const OptimizationMethod & new_optimization_method)

Sets a new type of main optimization algorithm.

Parameters
new_optimization_method: Type of main optimization algorithm.

Definition at line 500 of file training_strategy.cpp.

◆ set_optimization_method() [2/2]

void set_optimization_method ( const string &  new_optimization_method)

Sets a new main optimization algorithm from a string containing the type.

Parameters
new_optimization_method: String with the type of main optimization algorithm.

Definition at line 509 of file training_strategy.cpp.

◆ set_threads_number()

void set_threads_number ( const int &  new_threads_number)

Definition at line 548 of file training_strategy.cpp.

◆ write_loss_method()

string write_loss_method ( ) const

Returns a string with the type of the main loss algorithm composing this training strategy object.

Definition at line 273 of file training_strategy.cpp.

◆ write_loss_method_text()

string write_loss_method_text ( ) const

Returns a string with the main loss method type in text format.

Definition at line 386 of file training_strategy.cpp.

◆ write_optimization_method()

string write_optimization_method ( ) const

Returns a string with the type of the main optimization algorithm composing this training strategy object. If that object does not exist, an exception is thrown.

Definition at line 303 of file training_strategy.cpp.
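write_optimization_method() can be pictured as a switch from the OptimizationMethod enumerator to its name. This sketch assumes the returned string is the enumerator name itself; the free function optimization_method_to_string is hypothetical, and throwing std::logic_error for out-of-range values stands in for the exception mentioned above.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Copy of the OptimizationMethod enumerators listed on this page.
enum class OptimizationMethod {
    GRADIENT_DESCENT, CONJUGATE_GRADIENT, QUASI_NEWTON_METHOD, LEVENBERG_MARQUARDT_ALGORITHM,
    STOCHASTIC_GRADIENT_DESCENT, ADAPTIVE_MOMENT_ESTIMATION
};

// Hypothetical helper: returns the enumerator name, throwing for values outside
// the enumeration (for example, produced by an invalid static_cast).
std::string optimization_method_to_string(OptimizationMethod method) {
    switch (method) {
        case OptimizationMethod::GRADIENT_DESCENT: return "GRADIENT_DESCENT";
        case OptimizationMethod::CONJUGATE_GRADIENT: return "CONJUGATE_GRADIENT";
        case OptimizationMethod::QUASI_NEWTON_METHOD: return "QUASI_NEWTON_METHOD";
        case OptimizationMethod::LEVENBERG_MARQUARDT_ALGORITHM: return "LEVENBERG_MARQUARDT_ALGORITHM";
        case OptimizationMethod::STOCHASTIC_GRADIENT_DESCENT: return "STOCHASTIC_GRADIENT_DESCENT";
        case OptimizationMethod::ADAPTIVE_MOMENT_ESTIMATION: return "ADAPTIVE_MOMENT_ESTIMATION";
    }
    throw std::logic_error("Unknown optimization method");
}
```

Listing every enumerator without a default label lets the compiler warn when a new optimization method is added to the enumeration but not to the switch.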

◆ write_optimization_method_text()

string write_optimization_method_text ( ) const

Returns a string with the main optimization algorithm type in text format. If that object does not exist, an exception is thrown.

Definition at line 345 of file training_strategy.cpp.

◆ write_XML()

void write_XML ( tinyxml2::XMLPrinter & file_stream) const

Serializes the training strategy object into an XML document of the TinyXML library without keeping the DOM tree in memory. See the OpenNN manual for more information about the format of this document.

Definition at line 856 of file training_strategy.cpp.

Member Data Documentation

◆ adaptive_moment_estimation

AdaptiveMomentEstimation adaptive_moment_estimation
private

Adaptive moment estimation algorithm object to be used as a main optimization algorithm.

Definition at line 246 of file training_strategy.h.

◆ conjugate_gradient

ConjugateGradient conjugate_gradient
private

Conjugate gradient object to be used as a main optimization algorithm.

Definition at line 230 of file training_strategy.h.

◆ cross_entropy_error

CrossEntropyError cross_entropy_error
private

Cross-entropy error object, which can be used as the error term.

Definition at line 212 of file training_strategy.h.

◆ data_set_pointer

DataSet* data_set_pointer = nullptr
private

Definition at line 188 of file training_strategy.h.

◆ display

bool display = true
private

Display messages to screen.

Definition at line 254 of file training_strategy.h.

◆ gradient_descent

GradientDescent gradient_descent
private

Gradient descent object to be used as a main optimization algorithm.

Definition at line 226 of file training_strategy.h.

◆ Levenberg_Marquardt_algorithm

LevenbergMarquardtAlgorithm Levenberg_Marquardt_algorithm
private

Levenberg-Marquardt algorithm object to be used as a main optimization algorithm.

Definition at line 238 of file training_strategy.h.

◆ loss_method

LossMethod loss_method
private

Type of loss method.

Definition at line 220 of file training_strategy.h.

◆ mean_squared_error

MeanSquaredError mean_squared_error
private

Mean squared error object, which can be used as the error term.

Definition at line 200 of file training_strategy.h.

◆ Minkowski_error

MinkowskiError Minkowski_error
private

Minkowski error object, which can be used as the error term.

Definition at line 208 of file training_strategy.h.

◆ neural_network_pointer

NeuralNetwork* neural_network_pointer = nullptr
private

Definition at line 190 of file training_strategy.h.

◆ normalized_squared_error

NormalizedSquaredError normalized_squared_error
private

Normalized squared error object, which can be used as the error term.

Definition at line 204 of file training_strategy.h.

◆ optimization_method

OptimizationMethod optimization_method
private

Type of main optimization algorithm.

Definition at line 250 of file training_strategy.h.

◆ quasi_Newton_method

QuasiNewtonMethod quasi_Newton_method
private

Quasi-Newton method object to be used as a main optimization algorithm.

Definition at line 234 of file training_strategy.h.

◆ stochastic_gradient_descent

StochasticGradientDescent stochastic_gradient_descent
private

Stochastic gradient descent algorithm object to be used as a main optimization algorithm.

Definition at line 242 of file training_strategy.h.

◆ sum_squared_error

SumSquaredError sum_squared_error
private

Sum squared error object, which can be used as the error term.

Definition at line 196 of file training_strategy.h.

◆ weighted_squared_error

WeightedSquaredError weighted_squared_error
private

Weighted squared error object, which can be used as the error term.

Definition at line 216 of file training_strategy.h.


The documentation for this class was generated from the following files:

  • training_strategy.h
  • training_strategy.cpp