This class represents the concept of a training strategy for a neural network in OpenNN.
#include <training_strategy.h>
Public Types
enum class LossMethod { SUM_SQUARED_ERROR, MEAN_SQUARED_ERROR, NORMALIZED_SQUARED_ERROR, MINKOWSKI_ERROR, WEIGHTED_SQUARED_ERROR, CROSS_ENTROPY_ERROR }
Enumeration of available error terms in OpenNN.
enum class OptimizationMethod { GRADIENT_DESCENT, CONJUGATE_GRADIENT, QUASI_NEWTON_METHOD, LEVENBERG_MARQUARDT_ALGORITHM, STOCHASTIC_GRADIENT_DESCENT, ADAPTIVE_MOMENT_ESTIMATION }
Enumeration of all the available types of optimization algorithms.
Private Attributes
DataSet * data_set_pointer = nullptr
NeuralNetwork * neural_network_pointer = nullptr
SumSquaredError sum_squared_error
Sum squared error object which can be used as the error term.
MeanSquaredError mean_squared_error
Mean squared error object which can be used as the error term.
NormalizedSquaredError normalized_squared_error
Normalized squared error object which can be used as the error term.
MinkowskiError Minkowski_error
Minkowski error object which can be used as the error term.
CrossEntropyError cross_entropy_error
Cross-entropy error object which can be used as the error term.
WeightedSquaredError weighted_squared_error
Weighted squared error object which can be used as the error term.
LossMethod loss_method
Type of loss method.
GradientDescent gradient_descent
Gradient descent object to be used as a main optimization algorithm.
ConjugateGradient conjugate_gradient
Conjugate gradient object to be used as a main optimization algorithm.
QuasiNewtonMethod quasi_Newton_method
Quasi-Newton method object to be used as a main optimization algorithm.
LevenbergMarquardtAlgorithm Levenberg_Marquardt_algorithm
Levenberg-Marquardt algorithm object to be used as a main optimization algorithm.
StochasticGradientDescent stochastic_gradient_descent
Stochastic gradient descent algorithm object to be used as a main optimization algorithm.
AdaptiveMomentEstimation adaptive_moment_estimation
Adaptive moment estimation algorithm object to be used as a main optimization algorithm.
OptimizationMethod optimization_method
Type of main optimization algorithm.
bool display = true
Display messages to screen.
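This class represents the concept of a training strategy for a neural network in OpenNN.
A training strategy is composed of two objects: a loss index, which is the quantity to be minimized, and an optimization algorithm, which minimizes it.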
Definition at line 55 of file training_strategy.h.
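The composition can be pictured with a small, self-contained sketch. It is not OpenNN code: `LossTerm`, `SumSquaredLoss`, `MeanSquaredLoss` and `Strategy` are hypothetical stand-ins, and only two of the six loss methods are modeled. It mirrors what the member list suggests: every candidate loss object is held by value inside the strategy, and an enumeration selects which one is exposed through a base-class pointer, as `get_loss_index_pointer()` presumably does.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical base class standing in for OpenNN's LossIndex.
struct LossTerm {
    virtual ~LossTerm() = default;
    virtual double error(const std::vector<double>& outputs,
                         const std::vector<double>& targets) const = 0;
};

// Sum of squared differences between outputs and targets.
struct SumSquaredLoss : LossTerm {
    double error(const std::vector<double>& y, const std::vector<double>& t) const override {
        assert(y.size() == t.size());
        double sum = 0.0;
        for (std::size_t i = 0; i < y.size(); ++i) sum += (y[i] - t[i]) * (y[i] - t[i]);
        return sum;
    }
};

// Sum of squared differences divided by the number of samples.
struct MeanSquaredLoss : LossTerm {
    double error(const std::vector<double>& y, const std::vector<double>& t) const override {
        return y.empty() ? 0.0 : SumSquaredLoss{}.error(y, t) / static_cast<double>(y.size());
    }
};

// Hypothetical, reduced strategy: all candidate loss objects live inside it
// and an enum decides which one is active.
class Strategy {
public:
    enum class LossMethod { SUM_SQUARED_ERROR, MEAN_SQUARED_ERROR };

    void set_loss_method(LossMethod method) { loss_method = method; }

    // Returns the active loss object through the common base class.
    LossTerm* get_loss_index_pointer() {
        switch (loss_method) {
            case LossMethod::SUM_SQUARED_ERROR:  return &sum_squared_error;
            case LossMethod::MEAN_SQUARED_ERROR: return &mean_squared_error;
        }
        return nullptr;  // not reached for valid enum values
    }

private:
    SumSquaredLoss sum_squared_error;
    MeanSquaredLoss mean_squared_error;
    LossMethod loss_method = LossMethod::MEAN_SQUARED_ERROR;  // default chosen arbitrarily for this sketch
};
```

Holding every candidate object by value means that switching methods needs no allocation, at the cost of constructing all of them up front. The same pattern would apply to the optimization algorithms and `get_optimization_algorithm_pointer()`.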
enum class LossMethod [strong]
Enumeration of available error terms in OpenNN.
Definition at line 74 of file training_strategy.h.
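For reference, the following sketch computes textbook forms of four of these error terms on plain vectors. It is an illustration, not OpenNN's implementation: the function names are hypothetical, and OpenNN's classes may scale by the batch size or compute the normalization coefficient differently. The weighted squared and cross-entropy errors are omitted.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;

// Sum squared error: sum_i (y_i - t_i)^2
double sum_squared_error(const Vec& y, const Vec& t) {
    assert(y.size() == t.size());
    double s = 0.0;
    for (std::size_t i = 0; i < y.size(); ++i) s += (y[i] - t[i]) * (y[i] - t[i]);
    return s;
}

// Mean squared error: the sum squared error divided by the number of samples N.
double mean_squared_error(const Vec& y, const Vec& t) {
    assert(!y.empty());
    return sum_squared_error(y, t) / static_cast<double>(y.size());
}

// Normalized squared error: SSE divided by sum_i (t_i - mean(t))^2,
// so a model that always predicts the target mean scores exactly 1.
double normalized_squared_error(const Vec& y, const Vec& t) {
    assert(!t.empty());
    double mean = 0.0;
    for (double v : t) mean += v;
    mean /= static_cast<double>(t.size());
    double norm = 0.0;
    for (double v : t) norm += (v - mean) * (v - mean);
    assert(norm > 0.0);  // constant targets would make the ratio undefined
    return sum_squared_error(y, t) / norm;
}

// Minkowski error with exponent p: sum_i |y_i - t_i|^p (p = 2 gives the SSE).
double minkowski_error(const Vec& y, const Vec& t, double p) {
    assert(y.size() == t.size() && p > 0.0);
    double s = 0.0;
    for (std::size_t i = 0; i < y.size(); ++i) s += std::pow(std::fabs(y[i] - t[i]), p);
    return s;
}
```

An exponent p below 2 makes the Minkowski error grow more slowly than the squared error for large residuals, which is why it is often chosen for data with outliers.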
enum class OptimizationMethod [strong]
Enumeration of all the available types of optimization algorithms.
Definition at line 86 of file training_strategy.h.
TrainingStrategy() [explicit]
Default constructor. It creates a training strategy object that is not associated with any loss index object. It also constructs the main optimization algorithm object.
Definition at line 19 of file training_strategy.cpp.
TrainingStrategy(NeuralNetwork *, DataSet *) [explicit]
Pointer constructor. It creates a training strategy object that is not associated with any loss index object. It also loads the members of this object from the given NeuralNetwork and DataSet objects.
Definition at line 42 of file training_strategy.cpp.
~TrainingStrategy() [virtual]
Destructor. This destructor deletes the loss index and optimization algorithm objects.
Definition at line 61 of file training_strategy.cpp.
void fix_forecasting()
Checks the time steps and the batch size in forecasting problems. The batch size must be a multiple of the time steps. If it is not, the batch size is reduced to the largest multiple of the time steps that is lower than it.
Definition at line 791 of file training_strategy.cpp.
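The batch-size rule amounts to rounding down to a multiple of the time steps. A minimal sketch, using a hypothetical free function rather than OpenNN's member function:

```cpp
#include <cassert>

// Hypothetical helper mirroring the rule described for fix_forecasting():
// a batch size that is already a multiple of the time steps is kept;
// otherwise it is rounded down to the largest multiple below it.
long adjusted_batch_size(long batch_size, long time_steps) {
    assert(batch_size > 0 && time_steps > 0);
    return batch_size - batch_size % time_steps;
}
```

For example, a batch size of 1000 with 3 time steps becomes 999, while 64 with 8 time steps is left unchanged. The description above does not say what happens when the batch size is smaller than the time steps; this sketch would return 0 in that case, so a caller would need to guard against it.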
void from_XML(const tinyxml2::XMLDocument & document)
Loads the members of this training strategy object from an XML document.
document: XML document of the TinyXML library.
Definition at line 914 of file training_strategy.cpp.
AdaptiveMomentEstimation * get_adaptive_moment_estimation_pointer()
Returns a pointer to the adaptive moment estimation main algorithm. It also throws an exception if that pointer is nullptr.
Definition at line 192 of file training_strategy.cpp.
ConjugateGradient * get_conjugate_gradient_pointer()
Returns a pointer to the conjugate gradient main algorithm. It also throws an exception if that pointer is nullptr.
Definition at line 156 of file training_strategy.cpp.
CrossEntropyError * get_cross_entropy_error_pointer()
Returns a pointer to the cross-entropy error object used as the error term. If that object does not exist, an exception is thrown.
Definition at line 240 of file training_strategy.cpp.
DataSet * get_data_set_pointer()
Returns a pointer to the associated DataSet object.
Definition at line 68 of file training_strategy.cpp.
const bool & get_display() const
Returns true if messages from this class are displayed on the screen, and false otherwise.
Definition at line 416 of file training_strategy.cpp.
GradientDescent * get_gradient_descent_pointer()
Returns a pointer to the gradient descent main algorithm. It also throws an exception if that pointer is nullptr.
Definition at line 147 of file training_strategy.cpp.
LevenbergMarquardtAlgorithm * get_Levenberg_Marquardt_algorithm_pointer()
Returns a pointer to the Levenberg-Marquardt main algorithm. It also throws an exception if that pointer is nullptr.
Definition at line 174 of file training_strategy.cpp.
LossIndex * get_loss_index_pointer()
Returns a pointer to the LossIndex object used by this training strategy.
Definition at line 84 of file training_strategy.cpp.
const TrainingStrategy::LossMethod & get_loss_method() const
Returns the type of the loss method composing this training strategy object.
Definition at line 257 of file training_strategy.cpp.
MeanSquaredError * get_mean_squared_error_pointer()
Returns a pointer to the mean squared error object used as the error term. If that object does not exist, an exception is thrown.
Definition at line 210 of file training_strategy.cpp.
MinkowskiError * get_Minkowski_error_pointer()
Returns a pointer to the Minkowski error object used as the error term. If that object does not exist, an exception is thrown.
Definition at line 230 of file training_strategy.cpp.
NeuralNetwork * get_neural_network_pointer() const
Returns a pointer to the associated NeuralNetwork object.
Definition at line 76 of file training_strategy.cpp.
NormalizedSquaredError * get_normalized_squared_error_pointer()
Returns a pointer to the normalized squared error object used as the error term. If that object does not exist, an exception is thrown.
Definition at line 219 of file training_strategy.cpp.
OptimizationAlgorithm * get_optimization_algorithm_pointer()
Returns a pointer to the OptimizationAlgorithm object used by this training strategy.
Definition at line 107 of file training_strategy.cpp.
const TrainingStrategy::OptimizationMethod & get_optimization_method() const
Returns the type of the main optimization algorithm composing this training strategy object.
Definition at line 265 of file training_strategy.cpp.
QuasiNewtonMethod * get_quasi_Newton_method_pointer()
Returns a pointer to the quasi-Newton method main algorithm. It also throws an exception if that pointer is nullptr.
Definition at line 165 of file training_strategy.cpp.
StochasticGradientDescent * get_stochastic_gradient_descent_pointer()
Returns a pointer to the stochastic gradient descent main algorithm. It also throws an exception if that pointer is nullptr.
Definition at line 183 of file training_strategy.cpp.
SumSquaredError * get_sum_squared_error_pointer()
Returns a pointer to the sum squared error object used as the error term. If that object does not exist, an exception is thrown.
Definition at line 201 of file training_strategy.cpp.
WeightedSquaredError * get_weighted_squared_error_pointer()
Returns a pointer to the weighted squared error object used as the error term. If that object does not exist, an exception is thrown.
Definition at line 249 of file training_strategy.cpp.
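Several getters above are documented to throw when the underlying pointer is nullptr. A sketch of that guard, written as a hypothetical helper; the exception type `std::logic_error` is an assumption, not something this page specifies:

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical helper: returns the pointer unchanged, or throws if it is null.
template <typename T>
T* checked_pointer(T* pointer, const std::string& name) {
    if (pointer == nullptr) {
        throw std::logic_error(name + " pointer is nullptr.");
    }
    return pointer;
}
```

Note that the member list shows the loss and optimization objects held by value, so their addresses cannot be null; a check like this only matters for pointer members such as `data_set_pointer`.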
bool has_data_set() const
Definition at line 136 of file training_strategy.cpp.
bool has_neural_network() const
Definition at line 128 of file training_strategy.cpp.
void load(const string & file_name)
Loads the members of this training strategy object from an XML file. The file format is specified in the User's Guide.
file_name: Name of the training strategy XML file.
Definition at line 1205 of file training_strategy.cpp.
TrainingResults perform_training()
This is the most important method of this class. It optimizes the loss index of a neural network and returns a structure with the results of the training.
Definition at line 719 of file training_strategy.cpp.
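To make the idea concrete, the sketch below is a toy stand-in for `perform_training()`: plain gradient descent fits a one-parameter model y = w * x by minimizing the mean squared error, and returns a small results structure. `ToyTrainingResults`, `toy_perform_training` and the learning rate are hypothetical; the maximum-epochs and loss-goal stopping criteria are modeled on `set_maximum_epochs_number()` and `set_loss_goal()`, whose exact semantics this page does not describe.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical, much-reduced stand-in for the TrainingResults structure.
struct ToyTrainingResults {
    double final_loss = 0.0;
    int epochs = 0;
    std::string stopping_condition;
};

// Fits y = w * x by gradient descent on the mean squared error.
// w is updated in place; training stops when the loss reaches loss_goal
// or when max_epochs updates have been made.
ToyTrainingResults toy_perform_training(const std::vector<double>& x,
                                        const std::vector<double>& t,
                                        double& w,
                                        int max_epochs,
                                        double loss_goal,
                                        double learning_rate = 0.05) {
    assert(x.size() == t.size() && !x.empty());
    const double n = static_cast<double>(x.size());
    ToyTrainingResults results;
    for (int epoch = 0; ; ++epoch) {
        // Loss and gradient of MSE = (1/n) * sum_i (w * x_i - t_i)^2
        double loss = 0.0;
        double gradient = 0.0;
        for (std::size_t i = 0; i < x.size(); ++i) {
            const double e = w * x[i] - t[i];
            loss += e * e / n;
            gradient += 2.0 * e * x[i] / n;
        }
        results.final_loss = loss;
        results.epochs = epoch;
        if (loss <= loss_goal) { results.stopping_condition = "loss goal"; break; }
        if (epoch >= max_epochs) { results.stopping_condition = "maximum epochs"; break; }
        w -= learning_rate * gradient;  // gradient descent step
    }
    return results;
}
```

Calling it with x = {1, 2, 3}, t = {2, 4, 6} and w = 0 drives w toward 2 and stops on the loss goal; with a small epoch limit it stops on the maximum number of epochs instead.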
void print() const
Prints the string representation of the training strategy object to the screen.
Definition at line 845 of file training_strategy.cpp.
void save(const string & file_name) const
Saves the members of this training strategy object to an XML file.
file_name: Name of the training strategy XML file.
Definition at line 1189 of file training_strategy.cpp.
void set()
Sets the loss index pointer to nullptr, destructs the loss index and the optimization algorithm, and sets the remaining members to their default values.
Definition at line 426 of file training_strategy.cpp.
void set(NeuralNetwork * new_neural_network_pointer, DataSet * new_data_set_pointer)
Definition at line 434 of file training_strategy.cpp.
void set_data_set_pointer(DataSet * new_data_set_pointer)
Definition at line 556 of file training_strategy.cpp.
void set_default()
Sets the members of the training strategy object to their default values:
Definition at line 710 of file training_strategy.cpp.
void set_display(const bool & new_display)
Sets a new display value. If it is set to true, messages from this class are displayed on the screen; if it is set to false, they are not.
new_display: Display value.
Definition at line 635 of file training_strategy.cpp.
void set_display_period(const int & display_period)
Definition at line 688 of file training_strategy.cpp.
void set_loss_goal(const type & new_loss_goal)
Definition at line 659 of file training_strategy.cpp.
void set_loss_index_data_set_pointer(DataSet * new_data_set_pointer)
Definition at line 608 of file training_strategy.cpp.
void set_loss_index_neural_network_pointer(NeuralNetwork * new_neural_network_pointer)
Definition at line 619 of file training_strategy.cpp.
void set_loss_index_pointer(LossIndex * new_loss_index_pointer)
Sets a pointer to a loss index object to be associated with the training strategy.
new_loss_index_pointer: Pointer to a loss index object.
Definition at line 597 of file training_strategy.cpp.
void set_loss_index_threads_number(const int & new_threads_number)
Definition at line 572 of file training_strategy.cpp.
void set_loss_method(const LossMethod & new_loss_method)
Sets the loss index method. If that object does not exist, an exception is thrown.
new_loss_method: New method type.
Definition at line 489 of file training_strategy.cpp.
void set_loss_method(const string & new_loss_method)
Sets the loss index method from a string with its name. If that object does not exist, an exception is thrown.
new_loss_method: String with the name of the new method.
Definition at line 446 of file training_strategy.cpp.
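Both overloads can be backed by one lookup table, which also keeps the string conversion in `write_loss_method()` consistent with the parser. A sketch under two assumptions: that the accepted strings are the enumerator names (for example `"MEAN_SQUARED_ERROR"`), and that an unknown name is reported with `std::invalid_argument`. Neither is confirmed by this page, and the function names are hypothetical.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

enum class LossMethod {
    SUM_SQUARED_ERROR, MEAN_SQUARED_ERROR, NORMALIZED_SQUARED_ERROR,
    MINKOWSKI_ERROR, WEIGHTED_SQUARED_ERROR, CROSS_ENTROPY_ERROR
};

// One table drives both directions, so the two conversions cannot drift apart.
struct LossMethodName {
    LossMethod method;
    const char* name;  // assumed spelling: the enumerator name itself
};

const LossMethodName loss_method_names[] = {
    {LossMethod::SUM_SQUARED_ERROR, "SUM_SQUARED_ERROR"},
    {LossMethod::MEAN_SQUARED_ERROR, "MEAN_SQUARED_ERROR"},
    {LossMethod::NORMALIZED_SQUARED_ERROR, "NORMALIZED_SQUARED_ERROR"},
    {LossMethod::MINKOWSKI_ERROR, "MINKOWSKI_ERROR"},
    {LossMethod::WEIGHTED_SQUARED_ERROR, "WEIGHTED_SQUARED_ERROR"},
    {LossMethod::CROSS_ENTROPY_ERROR, "CROSS_ENTROPY_ERROR"},
};

// Enum to string, in the spirit of write_loss_method().
std::string loss_method_to_string(LossMethod method) {
    for (const LossMethodName& entry : loss_method_names)
        if (entry.method == method) return entry.name;
    throw std::logic_error("Unknown loss method.");
}

// String to enum, in the spirit of set_loss_method(const string&):
// an unrecognized name is reported with an exception.
LossMethod loss_method_from_string(const std::string& name) {
    for (const LossMethodName& entry : loss_method_names)
        if (name == entry.name) return entry.method;
    throw std::invalid_argument("Unknown loss method: " + name);
}
```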
void set_maximum_epochs_number(const int & maximum_epochs_number)
Definition at line 677 of file training_strategy.cpp.
void set_maximum_selection_failures(const Index & maximum_selection_failures)
Definition at line 668 of file training_strategy.cpp.
void set_maximum_time(const type & maximum_time)
Definition at line 694 of file training_strategy.cpp.
void set_neural_network_pointer(NeuralNetwork * new_neural_network_pointer)
Definition at line 564 of file training_strategy.cpp.
void set_optimization_algorithm_threads_number(const int & new_threads_number)
Definition at line 583 of file training_strategy.cpp.
void set_optimization_method(const OptimizationMethod & new_optimization_method)
Sets a new type of main optimization algorithm.
new_optimization_method: Type of main optimization algorithm.
Definition at line 500 of file training_strategy.cpp.
void set_optimization_method(const string & new_optimization_method)
Sets a new main optimization algorithm from a string containing the type.
new_optimization_method: String with the type of main optimization algorithm.
Definition at line 509 of file training_strategy.cpp.
void set_threads_number(const int & new_threads_number)
Definition at line 548 of file training_strategy.cpp.
string write_loss_method() const
Returns a string with the type of the loss method composing this training strategy object.
Definition at line 273 of file training_strategy.cpp.
string write_loss_method_text() const
Returns a string with the main loss method type in text format.
Definition at line 386 of file training_strategy.cpp.
string write_optimization_method() const
Returns a string with the type of the main optimization algorithm composing this training strategy object. If that object does not exist, an exception is thrown.
Definition at line 303 of file training_strategy.cpp.
string write_optimization_method_text() const
Returns a string with the type of the main optimization algorithm in text format. If that object does not exist, an exception is thrown.
Definition at line 345 of file training_strategy.cpp.
void write_XML(tinyxml2::XMLPrinter & file_stream) const
Serializes the training strategy object into an XML document of the TinyXML library without keeping the DOM tree in memory. See the OpenNN manual for more information about the format of this document.
Definition at line 856 of file training_strategy.cpp.
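Writing XML "without keeping the DOM tree in memory" means emitting each element as soon as it is produced, as tinyxml2's `XMLPrinter` does through `OpenElement()`, `PushText()` and `CloseElement()`. The sketch below reproduces that idea with the standard library only; the class name and the tag names in the usage are hypothetical and do not describe OpenNN's actual file format, and text escaping is omitted.

```cpp
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical streaming writer: output goes straight to the stream and
// no tree of nodes is ever built.
class StreamingXmlWriter {
public:
    explicit StreamingXmlWriter(std::ostream& out) : out_(out) {}

    // Writes the start tag immediately and remembers only its name.
    void open(const std::string& tag) {
        out_ << '<' << tag << '>';
        open_tags_.push_back(tag);
    }

    // Writes character data verbatim (no escaping in this sketch).
    void text(const std::string& value) { out_ << value; }

    // Writes the end tag of the most recently opened element.
    void close() {
        assert(!open_tags_.empty());
        out_ << "</" << open_tags_.back() << '>';
        open_tags_.pop_back();
    }

private:
    std::ostream& out_;
    std::vector<std::string> open_tags_;  // stack of currently open tag names
};
```

Opening `TrainingStrategy` and `LossMethod`, pushing the text `MEAN_SQUARED_ERROR` and closing twice yields `<TrainingStrategy><LossMethod>MEAN_SQUARED_ERROR</LossMethod></TrainingStrategy>`. Only the stack of open tag names is kept, so memory use grows with nesting depth rather than with document size.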
AdaptiveMomentEstimation adaptive_moment_estimation [private]
Adaptive moment estimation algorithm object to be used as a main optimization algorithm.
Definition at line 246 of file training_strategy.h.
ConjugateGradient conjugate_gradient [private]
Conjugate gradient object to be used as a main optimization algorithm.
Definition at line 230 of file training_strategy.h.
CrossEntropyError cross_entropy_error [private]
Cross-entropy error object which can be used as the error term.
Definition at line 212 of file training_strategy.h.
DataSet * data_set_pointer = nullptr [private]
Definition at line 188 of file training_strategy.h.
bool display = true [private]
Display messages to screen.
Definition at line 254 of file training_strategy.h.
GradientDescent gradient_descent [private]
Gradient descent object to be used as a main optimization algorithm.
Definition at line 226 of file training_strategy.h.
LevenbergMarquardtAlgorithm Levenberg_Marquardt_algorithm [private]
Levenberg-Marquardt algorithm object to be used as a main optimization algorithm.
Definition at line 238 of file training_strategy.h.
LossMethod loss_method [private]
Type of loss method.
Definition at line 220 of file training_strategy.h.
MeanSquaredError mean_squared_error [private]
Mean squared error object which can be used as the error term.
Definition at line 200 of file training_strategy.h.
MinkowskiError Minkowski_error [private]
Minkowski error object which can be used as the error term.
Definition at line 208 of file training_strategy.h.
NeuralNetwork * neural_network_pointer = nullptr [private]
Definition at line 190 of file training_strategy.h.
NormalizedSquaredError normalized_squared_error [private]
Normalized squared error object which can be used as the error term.
Definition at line 204 of file training_strategy.h.
OptimizationMethod optimization_method [private]
Type of main optimization algorithm.
Definition at line 250 of file training_strategy.h.
QuasiNewtonMethod quasi_Newton_method [private]
Quasi-Newton method object to be used as a main optimization algorithm.
Definition at line 234 of file training_strategy.h.
StochasticGradientDescent stochastic_gradient_descent [private]
Stochastic gradient descent algorithm object to be used as a main optimization algorithm.
Definition at line 242 of file training_strategy.h.
SumSquaredError sum_squared_error [private]
Sum squared error object which can be used as the error term.
Definition at line 196 of file training_strategy.h.
WeightedSquaredError weighted_squared_error [private]
Weighted squared error object which can be used as the error term.
Definition at line 216 of file training_strategy.h.