9#include "training_strategy.h"
10#include "optimization_algorithm.h"
21 data_set_pointer =
nullptr;
23 neural_network_pointer =
nullptr;
44 data_set_pointer = new_data_set_pointer;
46 neural_network_pointer = new_neural_network_pointer;
51 set_loss_index_neural_network_pointer(neural_network_pointer);
52 set_loss_index_data_set_pointer(data_set_pointer);
70 return data_set_pointer;
78 return neural_network_pointer;
128bool TrainingStrategy::has_neural_network()
const
130 if(neural_network_pointer ==
nullptr)
return false;
136bool TrainingStrategy::has_data_set()
const
138 if(data_set_pointer ==
nullptr)
return false;
277 case LossMethod::SUM_SQUARED_ERROR:
278 return "SUM_SQUARED_ERROR";
280 case LossMethod::MEAN_SQUARED_ERROR:
281 return "MEAN_SQUARED_ERROR";
283 case LossMethod::NORMALIZED_SQUARED_ERROR:
284 return "NORMALIZED_SQUARED_ERROR";
286 case LossMethod::MINKOWSKI_ERROR:
287 return "MINKOWSKI_ERROR";
289 case LossMethod::WEIGHTED_SQUARED_ERROR:
290 return "WEIGHTED_SQUARED_ERROR";
292 case LossMethod::CROSS_ENTROPY_ERROR:
293 return "CROSS_ENTROPY_ERROR";
307 return "GRADIENT_DESCENT";
311 return "CONJUGATE_GRADIENT";
315 return "QUASI_NEWTON_METHOD";
319 return "LEVENBERG_MARQUARDT_ALGORITHM";
323 return "STOCHASTIC_GRADIENT_DESCENT";
327 return "ADAPTIVE_MOMENT_ESTIMATION";
331 ostringstream buffer;
333 buffer <<
"OpenNN Exception: TrainingStrategy class.\n"
334 <<
"string write_optimization_method() const method.\n"
335 <<
"Unknown main type.\n";
337 throw logic_error(buffer.str());
349 return "gradient descent";
353 return "conjugate gradient";
357 return "quasi-Newton method";
361 return "Levenberg-Marquardt algorithm";
365 return "stochastic gradient descent";
369 return "adaptive moment estimation";
373 ostringstream buffer;
375 buffer <<
"OpenNN Exception: TrainingStrategy class.\n"
376 <<
"string write_optimization_method_text() const method.\n"
377 <<
"Unknown main type.\n";
379 throw logic_error(buffer.str());
390 case LossMethod::SUM_SQUARED_ERROR:
391 return "Sum squared error";
393 case LossMethod::MEAN_SQUARED_ERROR:
394 return "Mean squared error";
396 case LossMethod::NORMALIZED_SQUARED_ERROR:
397 return "Normalized squared error";
399 case LossMethod::MINKOWSKI_ERROR:
400 return "Minkowski error";
402 case LossMethod::WEIGHTED_SQUARED_ERROR:
403 return "Weighted squared error";
405 case LossMethod::CROSS_ENTROPY_ERROR:
406 return "Cross entropy error";
436 set_neural_network_pointer(new_neural_network_pointer);
438 set_data_set_pointer(new_data_set_pointer);
448 if(new_loss_method ==
"SUM_SQUARED_ERROR")
452 else if(new_loss_method ==
"MEAN_SQUARED_ERROR")
456 else if(new_loss_method ==
"NORMALIZED_SQUARED_ERROR")
460 else if(new_loss_method ==
"MINKOWSKI_ERROR")
464 else if(new_loss_method ==
"WEIGHTED_SQUARED_ERROR")
468 else if(new_loss_method ==
"CROSS_ENTROPY_ERROR")
474 ostringstream buffer;
476 buffer <<
"OpenNN Exception: TrainingStrategy class.\n"
477 <<
"void set_loss_method(const string&) method.\n"
478 <<
"Unknown loss method: " << new_loss_method <<
".\n";
480 throw logic_error(buffer.str());
511 if(new_optimization_method ==
"GRADIENT_DESCENT")
515 else if(new_optimization_method ==
"CONJUGATE_GRADIENT")
519 else if(new_optimization_method ==
"QUASI_NEWTON_METHOD")
523 else if(new_optimization_method ==
"LEVENBERG_MARQUARDT_ALGORITHM")
527 else if(new_optimization_method ==
"STOCHASTIC_GRADIENT_DESCENT")
531 else if(new_optimization_method ==
"ADAPTIVE_MOMENT_ESTIMATION")
537 ostringstream buffer;
539 buffer <<
"OpenNN Exception: TrainingStrategy class.\n"
540 <<
"void set_optimization_method(const string&) method.\n"
541 <<
"Unknown main type: " << new_optimization_method <<
".\n";
543 throw logic_error(buffer.str());
548void TrainingStrategy::set_threads_number(
const int& new_threads_number)
550 set_loss_index_threads_number(new_threads_number);
552 set_optimization_algorithm_threads_number(new_threads_number);
556void TrainingStrategy::set_data_set_pointer(DataSet* new_data_set_pointer)
558 data_set_pointer = new_data_set_pointer;
560 set_loss_index_data_set_pointer(data_set_pointer);
564void TrainingStrategy::set_neural_network_pointer(NeuralNetwork* new_neural_network_pointer)
566 neural_network_pointer = new_neural_network_pointer;
568 set_loss_index_neural_network_pointer(neural_network_pointer);
572void TrainingStrategy::set_loss_index_threads_number(
const int& new_threads_number)
583void TrainingStrategy::set_optimization_algorithm_threads_number(
const int& new_threads_number)
608void TrainingStrategy::set_loss_index_data_set_pointer(
DataSet* new_data_set_pointer)
619void TrainingStrategy::set_loss_index_neural_network_pointer(NeuralNetwork* new_neural_network_pointer)
659void TrainingStrategy::set_loss_goal(
const type& new_loss_goal)
668void TrainingStrategy::set_maximum_selection_failures(
const Index& maximum_selection_failures)
677void TrainingStrategy::set_maximum_epochs_number(
const int & maximum_epochs_number)
688void TrainingStrategy::set_display_period(
const int & display_period)
694void TrainingStrategy::set_maximum_time(
const type& maximum_time)
728 ostringstream buffer;
730 buffer <<
"OpenNN Exception: TrainingStrategy class.\n"
731 <<
"TrainingResults perform_training() const method.\n"
732 <<
"Convolutional Layer is not available yet. It will be included in future versions.\n";
734 throw logic_error(buffer.str());
739 case OptimizationMethod::GRADIENT_DESCENT:
746 case OptimizationMethod::CONJUGATE_GRADIENT:
753 case OptimizationMethod::QUASI_NEWTON_METHOD:
761 case OptimizationMethod::LEVENBERG_MARQUARDT_ALGORITHM:
768 case OptimizationMethod::STOCHASTIC_GRADIENT_DESCENT:
775 case OptimizationMethod::ADAPTIVE_MOMENT_ESTIMATION:
808 Index batch_samples_number = 0;
823 if(batch_samples_number%timesteps == 0)
829 const Index constant =
static_cast<Index
>(batch_samples_number/timesteps);
847 cout <<
"Training strategy object" << endl;
858 file_stream.OpenElement(
"TrainingStrategy");
862 file_stream.OpenElement(
"LossIndex");
866 file_stream.OpenElement(
"LossMethod");
878 case LossMethod::MEAN_SQUARED_ERROR :
mean_squared_error.write_regularization_XML(file_stream);
break;
880 case LossMethod::MINKOWSKI_ERROR :
Minkowski_error.write_regularization_XML(file_stream);
break;
881 case LossMethod::CROSS_ENTROPY_ERROR :
cross_entropy_error.write_regularization_XML(file_stream);
break;
882 case LossMethod::WEIGHTED_SQUARED_ERROR :
weighted_squared_error.write_regularization_XML(file_stream);
break;
883 case LossMethod::SUM_SQUARED_ERROR :
sum_squared_error.write_regularization_XML(file_stream);
break;
890 file_stream.OpenElement(
"OptimizationAlgorithm");
892 file_stream.OpenElement(
"OptimizationMethod");
920 ostringstream buffer;
922 buffer <<
"OpenNN Exception: TrainingStrategy class.\n"
923 <<
"void from_XML(const tinyxml2::XMLDocument&) method.\n"
924 <<
"Training strategy element is nullptr.\n";
926 throw logic_error(buffer.str());
933 if(loss_index_element)
935 const tinyxml2::XMLElement* loss_method_element = loss_index_element->FirstChildElement(
"LossMethod");
941 const tinyxml2::XMLElement* Minkowski_error_element = loss_index_element->FirstChildElement(
"MinkowskiError");
943 if(Minkowski_error_element)
952 Minkowski_error_element_copy->InsertEndChild(copy );
955 new_document.InsertEndChild(Minkowski_error_element_copy);
966 const tinyxml2::XMLElement* cross_entropy_element = loss_index_element->FirstChildElement(
"CrossEntropyError");
968 if(cross_entropy_element)
972 tinyxml2::XMLElement* cross_entropy_error_element_copy = new_document.NewElement(
"CrossEntropyError");
977 cross_entropy_error_element_copy->InsertEndChild(copy );
980 new_document.InsertEndChild(cross_entropy_error_element_copy);
987 const tinyxml2::XMLElement* weighted_squared_error_element = loss_index_element->FirstChildElement(
"WeightedSquaredError");
989 if(weighted_squared_error_element)
993 tinyxml2::XMLElement* weighted_squared_error_element_copy = new_document.NewElement(
"WeightedSquaredError");
998 weighted_squared_error_element_copy->InsertEndChild(copy );
1001 new_document.InsertEndChild(weighted_squared_error_element_copy);
1013 const tinyxml2::XMLElement* regularization_element = loss_index_element->FirstChildElement(
"Regularization");
1015 if(regularization_element)
1020 element_clone = regularization_element->DeepClone(&regularization_document);
1022 regularization_document.InsertFirstChild(element_clone);
1030 const tinyxml2::XMLElement* optimization_algorithm_element = root_element->FirstChildElement(
"OptimizationAlgorithm");
1032 if(optimization_algorithm_element)
1034 const tinyxml2::XMLElement* optimization_method_element = optimization_algorithm_element->FirstChildElement(
"OptimizationMethod");
1040 const tinyxml2::XMLElement* gradient_descent_element = optimization_algorithm_element->FirstChildElement(
"GradientDescent");
1042 if(gradient_descent_element)
1046 tinyxml2::XMLElement* gradient_descent_element_copy = gradient_descent_document.NewElement(
"GradientDescent");
1051 gradient_descent_element_copy->InsertEndChild(copy );
1054 gradient_descent_document.InsertEndChild(gradient_descent_element_copy);
1061 const tinyxml2::XMLElement* conjugate_gradient_element = optimization_algorithm_element->FirstChildElement(
"ConjugateGradient");
1063 if(conjugate_gradient_element)
1067 tinyxml2::XMLElement* conjugate_gradient_element_copy = conjugate_gradient_document.NewElement(
"ConjugateGradient");
1072 conjugate_gradient_element_copy->InsertEndChild(copy );
1075 conjugate_gradient_document.InsertEndChild(conjugate_gradient_element_copy);
1082 const tinyxml2::XMLElement* stochastic_gradient_descent_element = optimization_algorithm_element->FirstChildElement(
"StochasticGradientDescent");
1084 if(stochastic_gradient_descent_element)
1088 tinyxml2::XMLElement* stochastic_gradient_descent_element_copy = stochastic_gradient_descent_document.NewElement(
"StochasticGradientDescent");
1092 tinyxml2::XMLNode* copy = nodeFor->DeepClone(&stochastic_gradient_descent_document );
1093 stochastic_gradient_descent_element_copy->InsertEndChild(copy );
1096 stochastic_gradient_descent_document.InsertEndChild(stochastic_gradient_descent_element_copy);
1103 const tinyxml2::XMLElement* adaptive_moment_estimation_element = optimization_algorithm_element->FirstChildElement(
"AdaptiveMomentEstimation");
1105 if(adaptive_moment_estimation_element)
1109 tinyxml2::XMLElement* adaptive_moment_estimation_element_copy = adaptive_moment_estimation_document.NewElement(
"AdaptiveMomentEstimation");
1113 tinyxml2::XMLNode* copy = nodeFor->DeepClone(&adaptive_moment_estimation_document );
1114 adaptive_moment_estimation_element_copy->InsertEndChild(copy );
1117 adaptive_moment_estimation_document.InsertEndChild(adaptive_moment_estimation_element_copy);
1124 const tinyxml2::XMLElement* quasi_Newton_method_element = optimization_algorithm_element->FirstChildElement(
"QuasiNewtonMethod");
1126 if(quasi_Newton_method_element)
1130 tinyxml2::XMLElement* quasi_newton_method_element_copy = quasi_Newton_document.NewElement(
"QuasiNewtonMethod");
1135 quasi_newton_method_element_copy->InsertEndChild(copy );
1138 quasi_Newton_document.InsertEndChild(quasi_newton_method_element_copy);
1145 const tinyxml2::XMLElement* Levenberg_Marquardt_element = optimization_algorithm_element->FirstChildElement(
"LevenbergMarquardt");
1147 if(Levenberg_Marquardt_element)
1151 tinyxml2::XMLElement* levenberg_marquardt_algorithm_element_copy = Levenberg_Marquardt_document.NewElement(
"LevenbergMarquardt");
1156 levenberg_marquardt_algorithm_element_copy->InsertEndChild(copy );
1159 Levenberg_Marquardt_document.InsertEndChild(levenberg_marquardt_algorithm_element_copy);
1171 const string new_display = element->GetText();
1177 catch(
const logic_error& e)
1179 cerr << e.what() << endl;
1191 FILE * file = fopen(file_name.c_str(),
"w");
1211 if(document.LoadFile(file_name.c_str()))
1213 ostringstream buffer;
1215 buffer <<
"OpenNN Exception: TrainingStrategy class.\n"
1216 <<
"void load(const string&) method.\n"
1217 <<
"Cannot load XML file " << file_name <<
".\n";
1219 throw logic_error(buffer.str());
TrainingResults perform_training()
void set_loss_index_pointer(LossIndex *)
void from_XML(const tinyxml2::XMLDocument &)
void set_maximum_time(const type &)
void set_batch_samples_number(const Index &new_batch_samples_number)
Set number of samples in each batch. Default 1000.
void set_maximum_epochs_number(const Index &)
void write_XML(tinyxml2::XMLPrinter &) const
TrainingResults perform_training()
void set_maximum_selection_failures(const Index &)
void set_loss_index_pointer(LossIndex *)
void from_XML(const tinyxml2::XMLDocument &)
void set_maximum_time(const type &)
void set_loss_goal(const type &)
void set_maximum_epochs_number(const Index &)
void write_XML(tinyxml2::XMLPrinter &) const
This class represents the cross entropy error term, used for predicting probabilities.
void from_XML(const tinyxml2::XMLDocument &)
void write_XML(tinyxml2::XMLPrinter &) const
This class represents the concept of data set for data modelling problems, such as approximation,...
TrainingResults perform_training()
void set_maximum_selection_failures(const Index &)
void set_loss_index_pointer(LossIndex *)
void from_XML(const tinyxml2::XMLDocument &)
void set_maximum_time(const type &)
void set_loss_goal(const type &)
void set_maximum_epochs_number(const Index &)
void write_XML(tinyxml2::XMLPrinter &) const
Levenberg-Marquardt Algorithm will always compute the approximate Hessian matrix, which has dimension...
TrainingResults perform_training()
void set_maximum_selection_failures(const Index &)
void from_XML(const tinyxml2::XMLDocument &)
void set_maximum_time(const type &)
void set_loss_goal(const type &)
void set_maximum_epochs_number(const Index &)
void write_XML(tinyxml2::XMLPrinter &) const
Index get_timesteps() const
Returns the number of timesteps.
This abstract class represents the concept of loss index composed of an error term and a regularizati...
virtual void set_data_set_pointer(DataSet *)
Sets a new data set on which the error term is to be measured.
void set_neural_network_pointer(NeuralNetwork *)
void set_display(const bool &)
This class represents the mean squared error term.
void write_XML(tinyxml2::XMLPrinter &) const
This class represents the Minkowski error term.
void from_XML(const tinyxml2::XMLDocument &)
void set_Minkowski_parameter(const type &)
void write_XML(tinyxml2::XMLPrinter &) const
bool has_long_short_term_memory_layer() const
LongShortTermMemoryLayer * get_long_short_term_memory_layer_pointer() const
Returns a pointer to the long short term memory layer of this neural network, if it exists.
bool has_recurrent_layer() const
bool has_convolutional_layer() const
RecurrentLayer * get_recurrent_layer_pointer() const
Returns a pointer to the recurrent layer of this neural network, if it exists.
This class represents the normalized squared error term.
void set_data_set_pointer(DataSet *new_data_set_pointer)
set_data_set_pointer
void write_XML(tinyxml2::XMLPrinter &) const
void set_display_period(const Index &)
virtual void set_loss_index_pointer(LossIndex *)
virtual void set_display(const bool &)
TrainingResults perform_training()
void set_maximum_selection_failures(const Index &)
void set_loss_index_pointer(LossIndex *)
void from_XML(const tinyxml2::XMLDocument &)
void set_maximum_time(const type &)
void set_loss_goal(const type &)
void set_maximum_epochs_number(const Index &)
void set_display(const bool &)
void write_XML(tinyxml2::XMLPrinter &) const
This concrete class represents the stochastic gradient descent optimization algorithm[1] for a loss i...
TrainingResults perform_training()
void set_loss_index_pointer(LossIndex *)
void from_XML(const tinyxml2::XMLDocument &)
void set_maximum_time(const type &)
void set_maximum_epochs_number(const Index &)
void write_XML(tinyxml2::XMLPrinter &) const
This class represents the sum squared performance term functional.
AdaptiveMomentEstimation * get_adaptive_moment_estimation_pointer()
TrainingResults perform_training()
QuasiNewtonMethod quasi_Newton_method
Quasi-Newton method object to be used as a main optimization algorithm.
LevenbergMarquardtAlgorithm * get_Levenberg_Marquardt_algorithm_pointer()
NormalizedSquaredError * get_normalized_squared_error_pointer()
LossMethod loss_method
Type of loss method.
GradientDescent * get_gradient_descent_pointer()
virtual ~TrainingStrategy()
CrossEntropyError * get_cross_entropy_error_pointer()
AdaptiveMomentEstimation adaptive_moment_estimation
Adaptive moment estimation algorithm object to be used as a main optimization algorithm.
void set_loss_index_pointer(LossIndex *)
const bool & get_display() const
LossIndex * get_loss_index_pointer()
Returns a pointer to the LossIndex class.
void from_XML(const tinyxml2::XMLDocument &)
string write_loss_method() const
Returns a string with the type of the main loss algorithm composing this training strategy object.
void load(const string &)
bool display
Display messages to screen.
WeightedSquaredError weighted_squared_error
Weighted squared error object which can be used as the error term.
StochasticGradientDescent stochastic_gradient_descent
Stochastic gradient descent algorithm object to be used as a main optimization algorithm.
ConjugateGradient * get_conjugate_gradient_pointer()
void set_loss_method(const LossMethod &)
WeightedSquaredError * get_weighted_squared_error_pointer()
OptimizationAlgorithm * get_optimization_algorithm_pointer()
Returns a pointer to the OptimizationAlgorithm class.
LevenbergMarquardtAlgorithm Levenberg_Marquardt_algorithm
Levenberg-Marquardt algorithm object to be used as a main optimization algorithm.
ConjugateGradient conjugate_gradient
Conjugate gradient object to be used as a main optimization algorithm.
MinkowskiError Minkowski_error
Minkowski error object which can be used as the error term.
MinkowskiError * get_Minkowski_error_pointer()
CrossEntropyError cross_entropy_error
Cross entropy error object which can be used as the error term.
GradientDescent gradient_descent
Gradient descent object to be used as a main optimization algorithm.
StochasticGradientDescent * get_stochastic_gradient_descent_pointer()
void save(const string &) const
string write_optimization_method() const
NeuralNetwork * get_neural_network_pointer() const
Returns a pointer to the NeuralNetwork class.
void set_optimization_method(const OptimizationMethod &)
MeanSquaredError * get_mean_squared_error_pointer()
QuasiNewtonMethod * get_quasi_Newton_method_pointer()
NormalizedSquaredError normalized_squared_error
Normalized squared error object which can be used as the error term.
LossMethod
Enumeration of available error terms in OpenNN.
void print() const
Prints to the screen the string representation of the training strategy object.
void set_display(const bool &)
SumSquaredError sum_squared_error
Sum squared error object which can be used as the error term.
const LossMethod & get_loss_method() const
Returns the type of the main loss algorithm composing this training strategy object.
void write_XML(tinyxml2::XMLPrinter &) const
SumSquaredError * get_sum_squared_error_pointer()
OptimizationMethod
Enumeration of all the available types of optimization algorithms.
string write_optimization_method_text() const
MeanSquaredError mean_squared_error
Mean squared error object which can be used as the error term.
DataSet * get_data_set_pointer()
Returns a pointer to the DataSet class.
const OptimizationMethod & get_optimization_method() const
Returns the type of the main optimization algorithm composing this training strategy object.
string write_loss_method_text() const
Returns a string with the main loss method type in text format.
OptimizationMethod optimization_method
Type of main optimization algorithm.
This class represents the weighted squared error term.
void set_data_set_pointer(DataSet *)
set_data_set_pointer
void from_XML(const tinyxml2::XMLDocument &)
void set_negatives_weight(const type &)
void set_positives_weight(const type &)
void write_XML(tinyxml2::XMLPrinter &) const
const XMLNode * NextSibling() const
Get the next(right) sibling node of this node.
const XMLNode * FirstChild() const
Get the first child node, or null if none exists.
void PushText(const char *text, bool cdata=false)
Add a text node.
virtual void CloseElement(bool compactMode=false)
If streaming, close the Element.
This structure contains the optimization algorithm results.