9#include "learning_rate_algorithm.h"
19 : loss_index_pointer(nullptr)
31 : loss_index_pointer(new_loss_index_pointer)
41 delete non_blocking_thread_pool;
42 delete thread_pool_device;
58 buffer <<
"OpenNN Exception: LearningRateAlgorithm class.\n"
59 <<
"LossIndex* get_loss_index_pointer() const method.\n"
60 <<
"Loss index pointer is nullptr.\n";
62 throw logic_error(buffer.str());
101 case LearningRateMethod::GoldenSection:
102 return "GoldenSection";
104 case LearningRateMethod::BrentMethod:
105 return "BrentMethod";
112const type& LearningRateAlgorithm::get_learning_rate_tolerance()
const
154 delete non_blocking_thread_pool;
155 delete thread_pool_device;
157 const int n = omp_get_max_threads();
158 non_blocking_thread_pool =
new NonBlockingThreadPool(n);
159 thread_pool_device =
new ThreadPoolDevice(non_blocking_thread_pool, n);
168 loss_tolerance = numeric_limits<type>::epsilon();
181void LearningRateAlgorithm::set_threads_number(
const int& new_threads_number)
183 if(non_blocking_thread_pool !=
nullptr)
delete this->non_blocking_thread_pool;
184 if(thread_pool_device !=
nullptr)
delete this->thread_pool_device;
186 non_blocking_thread_pool =
new NonBlockingThreadPool(new_threads_number);
187 thread_pool_device =
new ThreadPoolDevice(non_blocking_thread_pool, new_threads_number);
206 if(new_learning_rate_method ==
"GoldenSection")
210 else if(new_learning_rate_method ==
"BrentMethod")
216 ostringstream buffer;
218 buffer <<
"OpenNN Exception: LearningRateAlgorithm class.\n"
219 <<
"void set_method(const string&) method.\n"
220 <<
"Unknown learning rate method: " << new_learning_rate_method <<
".\n";
222 throw logic_error(buffer.str());
234 if(new_learning_rate_tolerance <=
static_cast<type
>(0.0))
236 ostringstream buffer;
238 buffer <<
"OpenNN Exception: LearningRateAlgorithm class.\n"
239 <<
"void set_learning_rate_tolerance(const type&) method.\n"
240 <<
"Tolerance must be greater than 0.\n";
242 throw logic_error(buffer.str());
280 ostringstream buffer;
282 buffer <<
"OpenNN Error: LearningRateAlgorithm class.\n"
283 <<
"pair<type, 1> calculate_directional_point() const method.\n"
284 <<
"Pointer to loss index is nullptr.\n";
286 throw logic_error(buffer.str());
289 if(neural_network_pointer ==
nullptr)
291 ostringstream buffer;
293 buffer <<
"OpenNN Error: LearningRateAlgorithm class.\n"
294 <<
"Tensor<type, 1> calculate_directional_point() const method.\n"
295 <<
"Pointer to neural network is nullptr.\n";
297 throw logic_error(buffer.str());
300 if(thread_pool_device ==
nullptr)
302 ostringstream buffer;
304 buffer <<
"OpenNN Error: LearningRateAlgorithm class.\n"
305 <<
"pair<type, 1> calculate_directional_point() const method.\n"
306 <<
"Pointer to thread pool device is nullptr.\n";
308 throw logic_error(buffer.str());
313 ostringstream buffer;
326 catch(
const logic_error& error)
334 return triplet.minimum();
344 ||
abs(triplet.
A.second - triplet.
B.second) > loss_tolerance)
355 catch(
const logic_error& error)
357 cout << error.what() << endl;
359 return triplet.minimum();
364 optimization_data.potential_parameters.device(*thread_pool_device)
365 = back_propagation.parameters + optimization_data.training_direction*V.first;
367 neural_network_pointer->
forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);
374 V.second = back_propagation.error + regularization_weight*regularization;
378 if(V.first <= triplet.
U.first)
380 if(V.second >= triplet.
U.second)
384 else if(V.second <= triplet.
U.second)
386 triplet.
B = triplet.
U;
390 else if(V.first >= triplet.
U.first)
392 if(V.second >= triplet.
U.second)
396 else if(V.second <= triplet.
U.second)
398 triplet.
A = triplet.
U;
404 buffer <<
"OpenNN Exception: LearningRateAlgorithm class.\n"
405 <<
"Tensor<type, 1> calculate_Brent_method_directional_point() const method.\n"
407 <<
"A = (" << triplet.
A.first <<
"," << triplet.
A.second <<
")\n"
408 <<
"B = (" << triplet.
B.first <<
"," << triplet.
B.second <<
")\n"
409 <<
"U = (" << triplet.
U.first <<
"," << triplet.
U.second <<
")\n"
410 <<
"V = (" << V.first <<
"," << V.second <<
")\n";
412 throw logic_error(buffer.str());
421 catch(
const logic_error& error)
423 return triplet.minimum();
444 ostringstream buffer;
448 buffer <<
"OpenNN Error: LearningRateAlgorithm class.\n"
449 <<
"Triplet calculate_bracketing_triplet() const method.\n"
450 <<
"Pointer to loss index is nullptr.\n";
452 throw logic_error(buffer.str());
460 if(neural_network_pointer ==
nullptr)
462 buffer <<
"OpenNN Error: LearningRateAlgorithm class.\n"
463 <<
"Triplet calculate_bracketing_triplet() const method.\n"
464 <<
"Pointer to neural network is nullptr.\n";
466 throw logic_error(buffer.str());
469 if(thread_pool_device ==
nullptr)
471 buffer <<
"OpenNN Error: LearningRateAlgorithm class.\n"
472 <<
"Triplet calculate_bracketing_triplet() const method.\n"
473 <<
"Pointer to thread pool device is nullptr.\n";
475 throw logic_error(buffer.str());
478 if(is_zero(optimization_data.training_direction))
480 buffer <<
"OpenNN Error: LearningRateAlgorithm class.\n"
481 <<
"Triplet calculate_bracketing_triplet() const method.\n"
482 <<
"Training direction is zero.\n";
484 throw logic_error(buffer.str());
487 if(optimization_data.initial_learning_rate < type(NUMERIC_LIMITS_MIN))
489 buffer <<
"OpenNN Error: LearningRateAlgorithm class.\n"
490 <<
"Triplet calculate_bracketing_triplet() const method.\n"
491 <<
"Initial learning rate is zero.\n";
493 throw logic_error(buffer.str());
498 const type loss = back_propagation.loss;
504 triplet.
A.first = type(0);
505 triplet.
A.second = loss;
515 triplet.
B.first = optimization_data.initial_learning_rate*
static_cast<type
>(count);
517 optimization_data.potential_parameters.device(*thread_pool_device)
518 = back_propagation.parameters + optimization_data.training_direction*triplet.
B.first;
520 neural_network_pointer->
forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);
527 triplet.
B.second = back_propagation.error + regularization_weight*regularization;
529 }
while(
abs(triplet.
A.second - triplet.
B.second) < loss_tolerance);
532 if(triplet.
A.second > triplet.
B.second)
534 triplet.
U = triplet.
B;
536 triplet.
B.first *= golden_ratio;
538 optimization_data.potential_parameters.device(*thread_pool_device)
539 = back_propagation.parameters + optimization_data.training_direction*triplet.
B.first;
541 neural_network_pointer->
forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);
548 triplet.
B.second = back_propagation.error + regularization_weight*regularization;
550 while(triplet.
U.second > triplet.
B.second)
552 triplet.
A = triplet.
U;
553 triplet.
U = triplet.
B;
555 triplet.
B.first *= golden_ratio;
557 optimization_data.potential_parameters.device(*thread_pool_device)
558 = back_propagation.parameters + optimization_data.training_direction*triplet.
B.first;
560 neural_network_pointer->
forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);
567 triplet.
B.second = back_propagation.error + regularization_weight*regularization;
570 else if(triplet.
A.second < triplet.
B.second)
572 triplet.
U.first = triplet.
A.first + (triplet.
B.first - triplet.
A.first)*
static_cast<type
>(0.382);
574 optimization_data.potential_parameters.device(*thread_pool_device)
575 = back_propagation.parameters + optimization_data.training_direction*triplet.
U.first;
577 neural_network_pointer->
forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);
584 triplet.
U.second = back_propagation.error + regularization_weight*regularization;
586 while(triplet.
A.second < triplet.
U.second)
588 triplet.
B = triplet.
U;
590 triplet.
U.first = triplet.
A.first + (triplet.
B.first-triplet.
A.first)*
static_cast<type
>(0.382);
592 optimization_data.potential_parameters.device(*thread_pool_device)
593 = back_propagation.parameters + optimization_data.training_direction*triplet.
U.first;
595 neural_network_pointer->
forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);
602 triplet.
U.second = back_propagation.error + regularization_weight*regularization;
606 triplet.
U = triplet.
A;
607 triplet.
B = triplet.
A;
625 const type middle = triplet.
A.first +
static_cast<type
>(0.5)*(triplet.
B.first - triplet.
A.first);
627 if(triplet.
U.first < middle)
629 learning_rate = triplet.
A.first +
static_cast<type
>(0.618)*(triplet.
B.first - triplet.
A.first);
633 learning_rate = triplet.
A.first +
static_cast<type
>(0.382)*(triplet.
B.first - triplet.
A.first);
638 if(learning_rate < triplet.
A.first)
640 ostringstream buffer;
642 buffer <<
"OpenNN Error: LearningRateAlgorithm class.\n"
643 <<
"type calculate_golden_section_learning_rate(const Triplet&) const method.\n"
644 <<
"Learning rate(" << learning_rate <<
") is less than left point("
645 << triplet.
A.first <<
").\n";
647 throw logic_error(buffer.str());
650 if(learning_rate > triplet.
B.first)
652 ostringstream buffer;
654 buffer <<
"OpenNN Error: LearningRateAlgorithm class.\n"
655 <<
"type calculate_golden_section_learning_rate(const Triplet&) const method.\n"
656 <<
"Learning rate(" << learning_rate <<
") is greater than right point("
657 << triplet.
B.first <<
").\n";
659 throw logic_error(buffer.str());
664 return learning_rate;
673 const type a = triplet.
A.first;
674 const type u = triplet.
U.first;
675 const type b = triplet.
B.first;
677 const type fa = triplet.
A.second;
678 const type fu = triplet.
U.second;
679 const type fb = triplet.
B.second;
681 const type numerator = (u-a)*(u-a)*(fu-fb) - (u-b)*(u-b)*(fu-fa);
683 const type denominator = (u-a)*(fu-fb) - (u-b)*(fu-fa);
685 return u -
static_cast<type
>(0.5)*numerator/denominator;
695 ostringstream buffer;
699 file_stream.OpenElement(
"LearningRateAlgorithm");
703 file_stream.OpenElement(
"LearningRateMethod");
711 file_stream.OpenElement(
"LearningRateTolerance");
716 file_stream.
PushText(buffer.str().c_str());
736 ostringstream buffer;
738 buffer <<
"OpenNN Exception: LearningRateAlgorithm class.\n"
739 <<
"void from_XML(const tinyxml2::XMLDocument&) method.\n"
740 <<
"Learning rate algorithm element is nullptr.\n";
742 throw logic_error(buffer.str());
751 string new_learning_rate_method = element->GetText();
757 catch(
const logic_error& e)
759 cerr << e.what() << endl;
770 const type new_learning_rate_tolerance =
static_cast<type
>(atof(element->GetText()));
776 catch(
const logic_error& e)
778 cerr << e.what() << endl;
789 const string new_display = element->GetText();
795 catch(
const logic_error& e)
797 cerr << e.what() << endl;
void set_loss_index_pointer(LossIndex *)
type calculate_Brent_method_learning_rate(const Triplet &) const
const bool & get_display() const
void from_XML(const tinyxml2::XMLDocument &)
void set_default()
Sets the members of the learning rate algorithm to their default values.
void set_learning_rate_tolerance(const type &)
LossIndex * loss_index_pointer
Pointer to an external loss index object.
bool display
Display messages to screen.
void set_learning_rate_method(const LearningRateMethod &)
string write_learning_rate_method() const
Returns a string with the name of the learning rate method to be used.
LearningRateMethod
Available training operators for obtaining the learning rate.
bool has_loss_index() const
LearningRateMethod learning_rate_method
Variable containing the actual method used to obtain a suitable learning rate.
const LearningRateMethod & get_learning_rate_method() const
Returns the learning rate method used for training.
type calculate_golden_section_learning_rate(const Triplet &) const
virtual ~LearningRateAlgorithm()
Destructor.
Triplet calculate_bracketing_triplet(const DataSetBatch &, NeuralNetworkForwardPropagation &, LossIndexBackPropagation &, OptimizationAlgorithmData &) const
pair< type, type > calculate_directional_point(const DataSetBatch &, NeuralNetworkForwardPropagation &, LossIndexBackPropagation &, OptimizationAlgorithmData &) const
void set_display(const bool &)
void write_XML(tinyxml2::XMLPrinter &) const
type learning_rate_tolerance
Maximum interval length for the learning rate.
LossIndex * get_loss_index_pointer() const
This abstract class represents the concept of loss index composed of an error term and a regularization term.
NeuralNetwork * get_neural_network_pointer() const
Returns a pointer to the neural network object associated to the error term.
type calculate_regularization(const Tensor< type, 1 > &) const
const type & get_regularization_weight() const
Returns regularization weight.
void forward_propagate(const DataSetBatch &, NeuralNetworkForwardPropagation &) const
Calculate forward propagation in neural network.
void PushText(const char *text, bool cdata=false)
Add a text node.
virtual void CloseElement(bool compactMode=false)
If streaming, close the Element.
HALF_CONSTEXPR half abs(half arg)
Defines a set of three points (A, U, B) for bracketing a directional minimum.
pair< type, type > B
Right point of the triplet.
pair< type, type > A
Left point of the triplet.
pair< type, type > U
Interior point of the triplet.