TrainingStrategy Class Reference

This class represents the concept of training strategy for a neural network in OpenNN.

#include <training_strategy.h>

Public Types

enum  LossMethod {
  SUM_SQUARED_ERROR, MEAN_SQUARED_ERROR, NORMALIZED_SQUARED_ERROR, MINKOWSKI_ERROR,
  WEIGHTED_SQUARED_ERROR, CROSS_ENTROPY_ERROR
}
 Enumeration of available error terms in OpenNN.
 
enum  OptimizationMethod {
  GRADIENT_DESCENT, CONJUGATE_GRADIENT, QUASI_NEWTON_METHOD, LEVENBERG_MARQUARDT_ALGORITHM,
  STOCHASTIC_GRADIENT_DESCENT, ADAPTIVE_MOMENT_ESTIMATION
}
 Enumeration of all the available types of optimization algorithms.
 

Public Member Functions

 TrainingStrategy ()
 
 TrainingStrategy (NeuralNetwork *, DataSet *)
 
 TrainingStrategy (const tinyxml2::XMLDocument &)
 
 TrainingStrategy (const string &)
 
virtual ~TrainingStrategy ()
 
NeuralNetwork * get_neural_network_pointer () const
 Returns a pointer to the NeuralNetwork class.
 
LossIndex * get_loss_index_pointer () const
 Returns a pointer to the LossIndex class.
 
OptimizationAlgorithm * get_optimization_algorithm_pointer () const
 Returns a pointer to the OptimizationAlgorithm class.
 
bool has_neural_network () const
 
bool has_data_set () const
 
bool has_loss_index () const
 Returns true if this object contains a loss index, and false otherwise.
 
bool has_optimization_algorithm () const
 
GradientDescent * get_gradient_descent_pointer () const
 
ConjugateGradient * get_conjugate_gradient_pointer () const
 
QuasiNewtonMethod * get_quasi_Newton_method_pointer () const
 
LevenbergMarquardtAlgorithm * get_Levenberg_Marquardt_algorithm_pointer () const
 
StochasticGradientDescent * get_stochastic_gradient_descent_pointer () const
 
AdaptiveMomentEstimation * get_adaptive_moment_estimation_pointer () const
 
SumSquaredError * get_sum_squared_error_pointer () const
 
MeanSquaredError * get_mean_squared_error_pointer () const
 
NormalizedSquaredError * get_normalized_squared_error_pointer () const
 
MinkowskiError * get_Minkowski_error_pointer () const
 
CrossEntropyError * get_cross_entropy_error_pointer () const
 
WeightedSquaredError * get_weighted_squared_error_pointer () const
 
const LossMethod get_loss_method () const
 Returns the type of the loss method used by this training strategy object.
 
const OptimizationMethod get_optimization_method () const
 Returns the type of the main optimization algorithm composing this training strategy object.
 
string write_loss_method () const
 Returns a string with the type of the loss method used by this training strategy object.
 
string write_optimization_method () const
 
string write_optimization_method_text () const
 
string write_loss_method_text () const
 Returns a string with the main loss method type in text format.
 
const bool & get_display () const
 
void set ()
 
void set_default ()
 
void set_loss_index_pointer (LossIndex *)
 
void set_loss_method (const LossMethod &)
 
void set_optimization_method (const OptimizationMethod &)
 
void set_loss_method (const string &)
 
void set_optimization_method (const string &)
 
void set_display (const bool &)
 
void destruct_optimization_algorithm ()
 This method deletes the optimization algorithm object which composes this training strategy object.
 
OptimizationAlgorithm::Results perform_training () const
 
void perform_training_void () const
 Performs the training with the selected method.
 
bool check_forecasting () const
 Checks the time steps and the batch size in forecasting problems.
 
string object_to_string () const
 Returns a string representation of the training strategy.
 
void print () const
 Prints to the screen the string representation of the training strategy object.
 
tinyxml2::XMLDocument * to_XML () const
 
void from_XML (const tinyxml2::XMLDocument &)
 
void write_XML (tinyxml2::XMLPrinter &) const
 
void save (const string &) const
 
void load (const string &)
 

Private Attributes

DataSet * data_set_pointer = nullptr
 
NeuralNetwork * neural_network_pointer = nullptr
 
SumSquaredError * sum_squared_error_pointer = nullptr
 Pointer to the sum squared error object which can be used as the error term.
 
MeanSquaredError * mean_squared_error_pointer = nullptr
 Pointer to the mean squared error object which can be used as the error term.
 
NormalizedSquaredError * normalized_squared_error_pointer = nullptr
 Pointer to the normalized squared error object which can be used as the error term.
 
MinkowskiError * Minkowski_error_pointer = nullptr
 Pointer to the Minkowski error object which can be used as the error term.
 
CrossEntropyError * cross_entropy_error_pointer = nullptr
 Pointer to the cross entropy error object which can be used as the error term.
 
WeightedSquaredError * weighted_squared_error_pointer = nullptr
 Pointer to the weighted squared error object which can be used as the error term.
 
LossMethod loss_method
 Type of loss method.
 
GradientDescent * gradient_descent_pointer = nullptr
 Pointer to a gradient descent object to be used as a main optimization algorithm.
 
ConjugateGradient * conjugate_gradient_pointer = nullptr
 Pointer to a conjugate gradient object to be used as a main optimization algorithm.
 
QuasiNewtonMethod * quasi_Newton_method_pointer = nullptr
 Pointer to a quasi-Newton method object to be used as a main optimization algorithm.
 
LevenbergMarquardtAlgorithm * Levenberg_Marquardt_algorithm_pointer = nullptr
 Pointer to a Levenberg-Marquardt algorithm object to be used as a main optimization algorithm.
 
StochasticGradientDescent * stochastic_gradient_descent_pointer = nullptr
 Pointer to a stochastic gradient descent algorithm object to be used as a main optimization algorithm.
 
AdaptiveMomentEstimation * adaptive_moment_estimation_pointer = nullptr
 Pointer to an adaptive moment estimation algorithm object to be used as a main optimization algorithm.
 
OptimizationMethod optimization_method
 Type of main optimization algorithm.
 
bool display
 Display messages to screen.
 

Detailed Description

This class represents the concept of training strategy for a neural network in OpenNN.

A training strategy is composed of two objects:

  • Loss index.
  • Optimization algorithm.


Definition at line 56 of file training_strategy.h.

Constructor & Destructor Documentation

◆ TrainingStrategy() [1/4]

TrainingStrategy ( )
explicit

Default constructor. It creates a training strategy object not associated with any loss index object. It also constructs the main optimization algorithm object.

Definition at line 18 of file training_strategy.cpp.

◆ TrainingStrategy() [2/4]

TrainingStrategy ( NeuralNetwork *  new_neural_network_pointer,
DataSet *  new_data_set_pointer 
)
explicit

Pointer constructor. It creates a training strategy object not associated with any loss index object, and sets up its members from the given NeuralNetwork and DataSet objects.

Definition at line 36 of file training_strategy.cpp.

◆ TrainingStrategy() [3/4]

TrainingStrategy ( const tinyxml2::XMLDocument &  document)
explicit

XML constructor. It creates a training strategy object not associated with any loss index object, and loads its members from an XML document.

Parameters
document: Document of the TinyXML library.

Definition at line 55 of file training_strategy.cpp.

◆ TrainingStrategy() [4/4]

TrainingStrategy ( const string &  file_name)
explicit

File constructor. It creates a training strategy object associated with a loss index object, and loads its members from an XML file.

Parameters
file_name: Name of the training strategy XML file.

Definition at line 70 of file training_strategy.cpp.

◆ ~TrainingStrategy()

~TrainingStrategy ( )
virtual

Destructor. This destructor deletes the loss index and optimization algorithm objects.

Definition at line 83 of file training_strategy.cpp.
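The destructor above frees the loss index and optimization algorithm objects that the training strategy owns. As a design note, the following sketch expresses the same ownership with `std::unique_ptr`, so the objects are released without a hand-written destructor. `OwningStrategySketch`, `LossObjectSketch` and `OptimizerObjectSketch` are hypothetical stand-ins, not OpenNN classes, and OpenNN itself is not claimed to be implemented this way.

```cpp
#include <memory>

// Hypothetical stand-ins for a loss index and an optimization algorithm.
// The static counters record how many instances are currently alive.
struct LossObjectSketch {
    static int alive;
    LossObjectSketch() { ++alive; }
    ~LossObjectSketch() { --alive; }
};
int LossObjectSketch::alive = 0;

struct OptimizerObjectSketch {
    static int alive;
    OptimizerObjectSketch() { ++alive; }
    ~OptimizerObjectSketch() { --alive; }
};
int OptimizerObjectSketch::alive = 0;

// Hypothetical owner: the unique_ptr members delete both objects automatically
// when the strategy is destroyed, which ~TrainingStrategy() is documented to do by hand.
class OwningStrategySketch {
public:
    OwningStrategySketch()
        : loss_(std::make_unique<LossObjectSketch>()),
          optimizer_(std::make_unique<OptimizerObjectSketch>()) {}

    // Analogue of destruct_optimization_algorithm(): frees the optimizer early.
    void destruct_optimization_algorithm() { optimizer_.reset(); }

private:
    std::unique_ptr<LossObjectSketch> loss_;
    std::unique_ptr<OptimizerObjectSketch> optimizer_;
};
```

Declaring the members this way also makes the class non-copyable by default, which prevents two strategy objects from deleting the same optimizer.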

Member Function Documentation

◆ from_XML()

void from_XML ( const tinyxml2::XMLDocument &  document)

Loads the members of this training strategy object from an XML document.

Parameters
document: XML document of the TinyXML library.

Definition at line 1577 of file training_strategy.cpp.

◆ get_adaptive_moment_estimation_pointer()

AdaptiveMomentEstimation * get_adaptive_moment_estimation_pointer ( ) const

Returns a pointer to the adaptive moment estimation main algorithm. It also throws an exception if that pointer is nullptr.

Definition at line 290 of file training_strategy.cpp.
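Each `get_..._pointer()` accessor in this class is documented to throw when the requested object is `nullptr`. A minimal sketch of that checked-getter pattern follows. The class, the setter `set_adaptive_moment_estimation_pointer()`, and the choice of `std::logic_error` as the exception type are hypothetical choices for illustration, not taken from the OpenNN sources.

```cpp
#include <stdexcept>

// Hypothetical stand-in for an optimization algorithm class.
struct AdamSketch {};

class CheckedGetterSketch {
public:
    // Hypothetical setter; stores a non-owning pointer to keep the sketch short.
    void set_adaptive_moment_estimation_pointer(AdamSketch* new_pointer) {
        adam_pointer_ = new_pointer;
    }

    // Returns the stored pointer, or throws instead of handing back nullptr,
    // so callers cannot dereference a missing object by accident.
    AdamSketch* get_adaptive_moment_estimation_pointer() const {
        if (adam_pointer_ == nullptr) {
            throw std::logic_error(
                "get_adaptive_moment_estimation_pointer(): pointer is nullptr.");
        }
        return adam_pointer_;
    }

private:
    AdamSketch* adam_pointer_ = nullptr;
};
```

The same pattern applies to every other `get_..._pointer()` accessor listed in this reference.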

◆ get_conjugate_gradient_pointer()

ConjugateGradient * get_conjugate_gradient_pointer ( ) const

Returns a pointer to the conjugate gradient main algorithm. It also throws an exception if that pointer is nullptr.

Definition at line 210 of file training_strategy.cpp.

◆ get_cross_entropy_error_pointer()

CrossEntropyError * get_cross_entropy_error_pointer ( ) const

Returns a pointer to the cross entropy error object used as the error term. If that object does not exist, an exception is thrown.

Definition at line 412 of file training_strategy.cpp.

◆ get_display()

const bool & get_display ( ) const

Returns true if messages from this class are displayed on the screen, and false otherwise.

Definition at line 647 of file training_strategy.cpp.

◆ get_gradient_descent_pointer()

GradientDescent * get_gradient_descent_pointer ( ) const

Returns a pointer to the gradient descent main algorithm. It also throws an exception if that pointer is nullptr.

Definition at line 190 of file training_strategy.cpp.

◆ get_Levenberg_Marquardt_algorithm_pointer()

LevenbergMarquardtAlgorithm * get_Levenberg_Marquardt_algorithm_pointer ( ) const

Returns a pointer to the Levenberg-Marquardt main algorithm. It also throws an exception if that pointer is nullptr.

Definition at line 250 of file training_strategy.cpp.

◆ get_mean_squared_error_pointer()

MeanSquaredError * get_mean_squared_error_pointer ( ) const

Returns a pointer to the mean squared error object used as the error term. If that object does not exist, an exception is thrown.

Definition at line 334 of file training_strategy.cpp.

◆ get_Minkowski_error_pointer()

MinkowskiError * get_Minkowski_error_pointer ( ) const

Returns a pointer to the Minkowski error object used as the error term. If that object does not exist, an exception is thrown.

Definition at line 386 of file training_strategy.cpp.

◆ get_normalized_squared_error_pointer()

NormalizedSquaredError * get_normalized_squared_error_pointer ( ) const

Returns a pointer to the normalized squared error object used as the error term. If that object does not exist, an exception is thrown.

Definition at line 358 of file training_strategy.cpp.

◆ get_quasi_Newton_method_pointer()

QuasiNewtonMethod * get_quasi_Newton_method_pointer ( ) const

Returns a pointer to the quasi-Newton method main algorithm. It also throws an exception if that pointer is nullptr.

Definition at line 230 of file training_strategy.cpp.

◆ get_stochastic_gradient_descent_pointer()

StochasticGradientDescent * get_stochastic_gradient_descent_pointer ( ) const

Returns a pointer to the stochastic gradient descent main algorithm. It also throws an exception if that pointer is nullptr.

Definition at line 270 of file training_strategy.cpp.

◆ get_sum_squared_error_pointer()

SumSquaredError * get_sum_squared_error_pointer ( ) const

Returns a pointer to the sum squared error object used as the error term. If that object does not exist, an exception is thrown.

Definition at line 310 of file training_strategy.cpp.

◆ get_weighted_squared_error_pointer()

WeightedSquaredError * get_weighted_squared_error_pointer ( ) const

Returns a pointer to the weighted squared error object used as the error term. If that object does not exist, an exception is thrown.

Definition at line 440 of file training_strategy.cpp.

◆ load()

void load ( const string &  file_name)

Loads the members of this training strategy object from an XML file. The file format is specified in the User's Guide.

Parameters
file_name: Name of the training strategy XML file.

Definition at line 1875 of file training_strategy.cpp.

◆ perform_training()

OptimizationAlgorithm::Results perform_training ( ) const

This is the most important method of this class. It trains the neural network by minimizing the loss index with the selected optimization algorithm, and returns a structure with the training results.

Definition at line 1018 of file training_strategy.cpp.
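One way to picture perform_training() is a strategy that keeps one object per optimizer and forwards the call according to `optimization_method`. This is a hedged sketch, not OpenNN's implementation: the enum is trimmed to three of its six values for brevity, and `ResultsSketch` and the optimizer structs are hypothetical stand-ins for `OptimizationAlgorithm::Results` and the real algorithm classes.

```cpp
#include <stdexcept>
#include <string>

// Trimmed copy of the documented OptimizationMethod enum (three of its six values).
enum OptimizationMethod { GRADIENT_DESCENT, QUASI_NEWTON_METHOD, ADAPTIVE_MOMENT_ESTIMATION };

// Hypothetical stand-in for OptimizationAlgorithm::Results.
struct ResultsSketch {
    std::string trained_by;
};

// Hypothetical optimizers; each real algorithm would run its own training loop here.
struct GradientDescentSketch {
    ResultsSketch perform_training() const { return {"gradient descent"}; }
};
struct QuasiNewtonSketch {
    ResultsSketch perform_training() const { return {"quasi-Newton method"}; }
};
struct AdamSketch {
    ResultsSketch perform_training() const { return {"adaptive moment estimation"}; }
};

class DispatchingStrategySketch {
public:
    explicit DispatchingStrategySketch(OptimizationMethod method)
        : optimization_method_(method) {}

    // Forwards the call to whichever optimizer optimization_method_ selects.
    ResultsSketch perform_training() const {
        switch (optimization_method_) {
            case GRADIENT_DESCENT:           return gradient_descent_.perform_training();
            case QUASI_NEWTON_METHOD:        return quasi_newton_.perform_training();
            case ADAPTIVE_MOMENT_ESTIMATION: return adam_.perform_training();
        }
        throw std::logic_error("Unknown optimization method.");
    }

private:
    OptimizationMethod optimization_method_;
    GradientDescentSketch gradient_descent_;
    QuasiNewtonSketch quasi_newton_;
    AdamSketch adam_;
};
```

Like the documented method, the sketch's perform_training() is const and returns a results structure by value.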

◆ save()

void save ( const string &  file_name) const

Saves the members of this training strategy object to an XML file.

Parameters
file_name: Name of the training strategy XML file.

Definition at line 1861 of file training_strategy.cpp.

◆ set()

void set ( )

Sets the loss index pointer to nullptr. It also destructs the loss index and the optimization algorithm. Finally, it sets the remaining members to their default values.

Definition at line 657 of file training_strategy.cpp.

◆ set_default()

void set_default ( )

Sets the members of the training strategy object to their default values:

  • Display: true.


Definition at line 988 of file training_strategy.cpp.
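A minimal sketch of the reset behaviour described for set() and set_default(). It keeps only the members needed to show the effect: the loss index pointer becomes nullptr and display returns to its documented default of true. `ResetSketch` and `LossIndexStub` are hypothetical, and the sketch omits the step where set() destructs the loss index and optimization algorithm objects.

```cpp
// Hypothetical stand-in for a loss index object.
struct LossIndexStub {};

class ResetSketch {
public:
    void set_loss_index_pointer(LossIndexStub* new_pointer) { loss_index_pointer_ = new_pointer; }
    void set_display(bool new_display) { display_ = new_display; }

    // Mirrors the documented set(): drop the loss index pointer, then restore defaults.
    void set() {
        loss_index_pointer_ = nullptr;
        set_default();
    }

    // Mirrors the documented set_default(): the only listed default is display = true.
    void set_default() { display_ = true; }

    bool has_loss_index() const { return loss_index_pointer_ != nullptr; }
    const bool& get_display() const { return display_; }

private:
    LossIndexStub* loss_index_pointer_ = nullptr;  // non-owning in this sketch
    bool display_ = true;
};
```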

◆ set_display()

void set_display ( const bool &  new_display)

Sets a new display value. If it is true, messages from this class are displayed on the screen; if it is false, they are not.

Parameters
new_display: Display value.

Definition at line 938 of file training_strategy.cpp.

◆ set_loss_index_pointer()

void set_loss_index_pointer ( LossIndex *  new_loss_index_pointer)

Sets a pointer to a loss index object to be associated with the training strategy.

Parameters
new_loss_index_pointer: Pointer to a loss index object.

Definition at line 887 of file training_strategy.cpp.

◆ set_loss_method() [1/2]

void set_loss_method ( const LossMethod &  new_loss_method)

Sets the loss index method. If the corresponding loss object does not exist, an exception is thrown.

Parameters
new_loss_method: New method type.

Definition at line 712 of file training_strategy.cpp.

◆ set_loss_method() [2/2]

void set_loss_method ( const string &  new_loss_method)

Sets the loss index method from its name. If the corresponding loss object does not exist, an exception is thrown.

Parameters
new_loss_method: String with the name of the new method.

Definition at line 669 of file training_strategy.cpp.
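A minimal sketch of what the string overload has to do: map a name to a `LossMethod` value and throw on anything unrecognised. It assumes the accepted strings are the enumerator names themselves (`"MEAN_SQUARED_ERROR"` and so on); check write_loss_method() in your OpenNN version for the exact spellings. `parse_loss_method` is a hypothetical helper.

```cpp
#include <map>
#include <stdexcept>
#include <string>

// Copy of the documented LossMethod enumeration.
enum LossMethod {
    SUM_SQUARED_ERROR, MEAN_SQUARED_ERROR, NORMALIZED_SQUARED_ERROR, MINKOWSKI_ERROR,
    WEIGHTED_SQUARED_ERROR, CROSS_ENTROPY_ERROR
};

// Hypothetical helper: converts a loss name to its enum value.
// Assumes the names match the enumerators exactly (case-sensitive).
LossMethod parse_loss_method(const std::string& name) {
    static const std::map<std::string, LossMethod> table = {
        {"SUM_SQUARED_ERROR", SUM_SQUARED_ERROR},
        {"MEAN_SQUARED_ERROR", MEAN_SQUARED_ERROR},
        {"NORMALIZED_SQUARED_ERROR", NORMALIZED_SQUARED_ERROR},
        {"MINKOWSKI_ERROR", MINKOWSKI_ERROR},
        {"WEIGHTED_SQUARED_ERROR", WEIGHTED_SQUARED_ERROR},
        {"CROSS_ENTROPY_ERROR", CROSS_ENTROPY_ERROR},
    };
    const auto it = table.find(name);
    if (it == table.end()) {
        throw std::invalid_argument("Unknown loss method: " + name);
    }
    return it->second;
}
```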

◆ set_optimization_method() [1/2]

void set_optimization_method ( const OptimizationMethod &  new_optimization_method)

Sets a new type of main optimization algorithm.

Parameters
new_optimization_method: Type of main optimization algorithm.

Definition at line 795 of file training_strategy.cpp.

◆ set_optimization_method() [2/2]

void set_optimization_method ( const string &  new_optimization_method)

Sets a new main optimization algorithm from a string containing the type.

Parameters
new_optimization_method: String with the type of main optimization algorithm.

Definition at line 845 of file training_strategy.cpp.

◆ to_XML()

tinyxml2::XMLDocument * to_XML ( ) const

Returns an XML document of the TinyXML library that represents this training strategy object. It contains the training operators, the training parameters, the stopping criteria and other settings.

Definition at line 1268 of file training_strategy.cpp.

◆ write_optimization_method()

string write_optimization_method ( ) const

Returns a string with the type of the main optimization algorithm composing this training strategy object. If that object does not exist, an exception is thrown.

Definition at line 522 of file training_strategy.cpp.

◆ write_optimization_method_text()

string write_optimization_method_text ( ) const

Returns a string with the main optimization algorithm type in text format. If that object does not exist, an exception is thrown.

Definition at line 564 of file training_strategy.cpp.
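The two write functions differ only in output style: one returns an identifier-like name, the other a readable label. The sketch below shows both conversions for the documented OptimizationMethod values. The exact strings OpenNN returns are an assumption here, and both functions are hypothetical analogues rather than the library's code.

```cpp
#include <stdexcept>
#include <string>

// Copy of the documented OptimizationMethod enumeration.
enum OptimizationMethod {
    GRADIENT_DESCENT, CONJUGATE_GRADIENT, QUASI_NEWTON_METHOD, LEVENBERG_MARQUARDT_ALGORITHM,
    STOCHASTIC_GRADIENT_DESCENT, ADAPTIVE_MOMENT_ESTIMATION
};

// Hypothetical analogue of write_optimization_method(): returns the enumerator name.
std::string write_optimization_method_sketch(OptimizationMethod method) {
    switch (method) {
        case GRADIENT_DESCENT:              return "GRADIENT_DESCENT";
        case CONJUGATE_GRADIENT:            return "CONJUGATE_GRADIENT";
        case QUASI_NEWTON_METHOD:           return "QUASI_NEWTON_METHOD";
        case LEVENBERG_MARQUARDT_ALGORITHM: return "LEVENBERG_MARQUARDT_ALGORITHM";
        case STOCHASTIC_GRADIENT_DESCENT:   return "STOCHASTIC_GRADIENT_DESCENT";
        case ADAPTIVE_MOMENT_ESTIMATION:    return "ADAPTIVE_MOMENT_ESTIMATION";
    }
    // Reached only if method holds a value outside the enumeration.
    throw std::invalid_argument("Unknown optimization method.");
}

// Hypothetical analogue of write_optimization_method_text(): returns a readable label.
std::string write_optimization_method_text_sketch(OptimizationMethod method) {
    switch (method) {
        case GRADIENT_DESCENT:              return "gradient descent";
        case CONJUGATE_GRADIENT:            return "conjugate gradient";
        case QUASI_NEWTON_METHOD:           return "quasi-Newton method";
        case LEVENBERG_MARQUARDT_ALGORITHM: return "Levenberg-Marquardt algorithm";
        case STOCHASTIC_GRADIENT_DESCENT:   return "stochastic gradient descent";
        case ADAPTIVE_MOMENT_ESTIMATION:    return "adaptive moment estimation";
    }
    throw std::invalid_argument("Unknown optimization method.");
}
```

Leaving out a `default` case is deliberate: with `-Wswitch` (enabled by `-Wall` in GCC and Clang), the compiler warns if a new enumerator is added to the enum but not handled in these switches.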

◆ write_XML()

void write_XML ( tinyxml2::XMLPrinter &  file_stream) const

Serializes the training strategy object into an XML document of the TinyXML library without keeping the DOM tree in memory. See the OpenNN manual for more information about the format of this document.

Definition at line 1425 of file training_strategy.cpp.


The documentation for this class was generated from the following files: