learning_rate_algorithm.cpp
// OpenNN: Open Neural Networks Library
// www.opennn.net
//
// L E A R N I N G   R A T E   A L G O R I T H M   C L A S S
//
// Artificial Intelligence Techniques SL
// artelnics@artelnics.com

#include "learning_rate_algorithm.h"

namespace OpenNN
{

/// Default constructor.
/// It creates a learning rate algorithm object not associated with any loss index object.

LearningRateAlgorithm::LearningRateAlgorithm()
    : loss_index_pointer(nullptr)
{
    set_default();
}


/// Loss index constructor.
/// It creates a learning rate algorithm object associated with a loss index object.
/// @param new_loss_index_pointer Pointer to a loss index object.

LearningRateAlgorithm::LearningRateAlgorithm(LossIndex* new_loss_index_pointer)
    : loss_index_pointer(new_loss_index_pointer)
{
    set_default();
}


/// Destructor.

LearningRateAlgorithm::~LearningRateAlgorithm()
{
    delete non_blocking_thread_pool;
    delete thread_pool_device;
}


/// Returns a pointer to the loss index object
/// to which the learning rate algorithm is associated.

LossIndex* LearningRateAlgorithm::get_loss_index_pointer() const
{
#ifdef OPENNN_DEBUG

    if(loss_index_pointer == nullptr)
    {
        ostringstream buffer;

        buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
               << "LossIndex* get_loss_index_pointer() const method.\n"
               << "Loss index pointer is nullptr.\n";

        throw logic_error(buffer.str());
    }

#endif

    return loss_index_pointer;
}


/// Returns true if this learning rate algorithm has an associated loss index, and false otherwise.

bool LearningRateAlgorithm::has_loss_index() const
{
    if(loss_index_pointer != nullptr)
    {
        return true;
    }
    else
    {
        return false;
    }
}


/// Returns the learning rate method used for training.

const LearningRateAlgorithm::LearningRateMethod& LearningRateAlgorithm::get_learning_rate_method() const
{
    return learning_rate_method;
}


/// Returns a string with the name of the learning rate method to be used.

string LearningRateAlgorithm::write_learning_rate_method() const
{
    switch(learning_rate_method)
    {
    case LearningRateMethod::GoldenSection:
        return "GoldenSection";

    case LearningRateMethod::BrentMethod:
        return "BrentMethod";
    }

    return string();
}


const type& LearningRateAlgorithm::get_learning_rate_tolerance() const
{
    return learning_rate_tolerance;
}


/// Returns true if messages from this class can be displayed on the screen,
/// or false if messages from this class are not to be displayed on the screen.

const bool& LearningRateAlgorithm::get_display() const
{
    return display;
}


/// Sets the loss index pointer to nullptr and the rest of the members to their default values.

void LearningRateAlgorithm::set()
{
    loss_index_pointer = nullptr;

    set_default();
}


/// Sets a new loss index pointer.
/// It also sets the rest of the members to their default values.
/// @param new_loss_index_pointer Pointer to a loss index object.

void LearningRateAlgorithm::set(LossIndex* new_loss_index_pointer)
{
    loss_index_pointer = new_loss_index_pointer;

    set_default();
}


/// Sets the members of the learning rate algorithm to their default values.

void LearningRateAlgorithm::set_default()
{
    delete non_blocking_thread_pool;
    delete thread_pool_device;

    const int n = omp_get_max_threads();
    non_blocking_thread_pool = new NonBlockingThreadPool(n);
    thread_pool_device = new ThreadPoolDevice(non_blocking_thread_pool, n);

    // TRAINING OPERATORS

    learning_rate_method = LearningRateMethod::BrentMethod;

    // TRAINING PARAMETERS

    learning_rate_tolerance = numeric_limits<type>::epsilon();
    loss_tolerance = numeric_limits<type>::epsilon();
}


/// Sets a pointer to a loss index object to be associated with the learning rate algorithm.
/// @param new_loss_index_pointer Pointer to a loss index object.

void LearningRateAlgorithm::set_loss_index_pointer(LossIndex* new_loss_index_pointer)
{
    loss_index_pointer = new_loss_index_pointer;
}


void LearningRateAlgorithm::set_threads_number(const int& new_threads_number)
{
    if(non_blocking_thread_pool != nullptr) delete non_blocking_thread_pool;
    if(thread_pool_device != nullptr) delete thread_pool_device;

    non_blocking_thread_pool = new NonBlockingThreadPool(new_threads_number);
    thread_pool_device = new ThreadPoolDevice(non_blocking_thread_pool, new_threads_number);
}


/// Sets a new learning rate method to be used for training.
/// @param new_learning_rate_method Learning rate method.

void LearningRateAlgorithm::set_learning_rate_method(
    const LearningRateAlgorithm::LearningRateMethod& new_learning_rate_method)
{
    learning_rate_method = new_learning_rate_method;
}


/// Sets the method for obtaining the learning rate from a string with the name of the method.
/// @param new_learning_rate_method Name of the learning rate method ("GoldenSection" or "BrentMethod").

void LearningRateAlgorithm::set_learning_rate_method(const string& new_learning_rate_method)
{
    if(new_learning_rate_method == "GoldenSection")
    {
        learning_rate_method = LearningRateMethod::GoldenSection;
    }
    else if(new_learning_rate_method == "BrentMethod")
    {
        learning_rate_method = LearningRateMethod::BrentMethod;
    }
    else
    {
        ostringstream buffer;

        buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
               << "void set_learning_rate_method(const string&) method.\n"
               << "Unknown learning rate method: " << new_learning_rate_method << ".\n";

        throw logic_error(buffer.str());
    }
}


/// Sets a new tolerance value to be used in line minimization.
/// @param new_learning_rate_tolerance Tolerance value in line minimization.

void LearningRateAlgorithm::set_learning_rate_tolerance(const type& new_learning_rate_tolerance)
{
#ifdef OPENNN_DEBUG

    if(new_learning_rate_tolerance <= static_cast<type>(0.0))
    {
        ostringstream buffer;

        buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
               << "void set_learning_rate_tolerance(const type&) method.\n"
               << "Tolerance must be greater than 0.\n";

        throw logic_error(buffer.str());
    }

#endif

    // Set learning rate tolerance

    learning_rate_tolerance = new_learning_rate_tolerance;
}


/// Sets a new display value.
/// If it is set to true messages from this class are displayed on the screen;
/// if it is set to false messages from this class are not displayed on the screen.
/// @param new_display Display value.

void LearningRateAlgorithm::set_display(const bool& new_display)
{
    display = new_display;
}


/// Returns the directional point, i.e. the pair (learning rate, loss) that
/// minimizes the loss along the training direction, obtained with the
/// selected one-dimensional minimization method.

pair<type, type> LearningRateAlgorithm::calculate_directional_point(
    const DataSetBatch& batch,
    NeuralNetworkForwardPropagation& forward_propagation,
    LossIndexBackPropagation& back_propagation,
    OptimizationAlgorithmData& optimization_data) const
{
    const NeuralNetwork* neural_network_pointer = loss_index_pointer->get_neural_network_pointer();

#ifdef OPENNN_DEBUG

    if(loss_index_pointer == nullptr)
    {
        ostringstream buffer;

        buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
               << "pair<type, type> calculate_directional_point() const method.\n"
               << "Pointer to loss index is nullptr.\n";

        throw logic_error(buffer.str());
    }

    if(neural_network_pointer == nullptr)
    {
        ostringstream buffer;

        buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
               << "pair<type, type> calculate_directional_point() const method.\n"
               << "Pointer to neural network is nullptr.\n";

        throw logic_error(buffer.str());
    }

    if(thread_pool_device == nullptr)
    {
        ostringstream buffer;

        buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
               << "pair<type, type> calculate_directional_point() const method.\n"
               << "Pointer to thread pool device is nullptr.\n";

        throw logic_error(buffer.str());
    }

#endif

    ostringstream buffer;

    // Bracket minimum

    Triplet triplet = calculate_bracketing_triplet(batch,
                                                   forward_propagation,
                                                   back_propagation,
                                                   optimization_data);

    try
    {
        triplet.check();
    }
    catch(const logic_error&)
    {
        return triplet.minimum();
    }

    const type regularization_weight = loss_index_pointer->get_regularization_weight();

    pair<type, type> V;

    // Reduce the interval

    while(abs(triplet.A.first - triplet.B.first) > learning_rate_tolerance
       || abs(triplet.A.second - triplet.B.second) > loss_tolerance)
    {
        try
        {
            switch(learning_rate_method)
            {
            case LearningRateMethod::GoldenSection: V.first = calculate_golden_section_learning_rate(triplet); break;

            case LearningRateMethod::BrentMethod: V.first = calculate_Brent_method_learning_rate(triplet); break;
            }
        }
        catch(const logic_error& error)
        {
            cout << error.what() << endl;

            return triplet.minimum();
        }

        // Calculate loss for V

        optimization_data.potential_parameters.device(*thread_pool_device)
            = back_propagation.parameters + optimization_data.training_direction*V.first;

        neural_network_pointer->forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);

        loss_index_pointer->calculate_errors(batch, forward_propagation, back_propagation);
        loss_index_pointer->calculate_error(batch, forward_propagation, back_propagation);

        const type regularization = loss_index_pointer->calculate_regularization(optimization_data.potential_parameters);

        V.second = back_propagation.error + regularization_weight*regularization;

        // Update points

        if(V.first <= triplet.U.first)
        {
            if(V.second >= triplet.U.second)
            {
                triplet.A = V;
            }
            else if(V.second <= triplet.U.second)
            {
                triplet.B = triplet.U;
                triplet.U = V;
            }
        }
        else if(V.first >= triplet.U.first)
        {
            if(V.second >= triplet.U.second)
            {
                triplet.B = V;
            }
            else if(V.second <= triplet.U.second)
            {
                triplet.A = triplet.U;
                triplet.U = V;
            }
        }
        else
        {
            buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
                   << "pair<type, type> calculate_directional_point() const method.\n"
                   << "Unknown set:\n"
                   << "A = (" << triplet.A.first << "," << triplet.A.second << ")\n"
                   << "B = (" << triplet.B.first << "," << triplet.B.second << ")\n"
                   << "U = (" << triplet.U.first << "," << triplet.U.second << ")\n"
                   << "V = (" << V.first << "," << V.second << ")\n";

            throw logic_error(buffer.str());
        }

        // Check triplet

        try
        {
            triplet.check();
        }
        catch(const logic_error&)
        {
            return triplet.minimum();
        }
    }

    return triplet.U;
}
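
// Illustrative usage sketch (hypothetical caller, not part of this file): an
// optimization algorithm such as conjugate gradient would set up the training
// direction and an initial learning rate, then ask for the one-dimensional
// minimum along that direction:
//
//     optimization_data.training_direction = ...;  // e.g. the descent direction
//     optimization_data.initial_learning_rate = ...;
//
//     const pair<type, type> directional_point
//         = learning_rate_algorithm.calculate_directional_point(
//               batch, forward_propagation, back_propagation, optimization_data);
//
//     // directional_point.first is the learning rate; .second is the loss there.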


/// Returns a bracketing triplet (A, U, B) whose interior point U has a loss
/// smaller than that of both endpoints, so that the interval [A, B] contains
/// a directional minimum. This triplet is used by the line minimization algorithms.

LearningRateAlgorithm::Triplet LearningRateAlgorithm::calculate_bracketing_triplet(
    const DataSetBatch& batch,
    NeuralNetworkForwardPropagation& forward_propagation,
    LossIndexBackPropagation& back_propagation,
    OptimizationAlgorithmData& optimization_data) const
{
    Triplet triplet;

#ifdef OPENNN_DEBUG

    ostringstream buffer;

    if(loss_index_pointer == nullptr)
    {
        buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
               << "Triplet calculate_bracketing_triplet() const method.\n"
               << "Pointer to loss index is nullptr.\n";

        throw logic_error(buffer.str());
    }
#endif

    const NeuralNetwork* neural_network_pointer = loss_index_pointer->get_neural_network_pointer();

#ifdef OPENNN_DEBUG

    if(neural_network_pointer == nullptr)
    {
        buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
               << "Triplet calculate_bracketing_triplet() const method.\n"
               << "Pointer to neural network is nullptr.\n";

        throw logic_error(buffer.str());
    }

    if(thread_pool_device == nullptr)
    {
        buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
               << "Triplet calculate_bracketing_triplet() const method.\n"
               << "Pointer to thread pool device is nullptr.\n";

        throw logic_error(buffer.str());
    }

    if(is_zero(optimization_data.training_direction))
    {
        buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
               << "Triplet calculate_bracketing_triplet() const method.\n"
               << "Training direction is zero.\n";

        throw logic_error(buffer.str());
    }

    if(optimization_data.initial_learning_rate < type(NUMERIC_LIMITS_MIN))
    {
        buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
               << "Triplet calculate_bracketing_triplet() const method.\n"
               << "Initial learning rate is zero.\n";

        throw logic_error(buffer.str());
    }

#endif

    const type loss = back_propagation.loss;

    const type regularization_weight = loss_index_pointer->get_regularization_weight();

    // Left point

    triplet.A.first = type(0);
    triplet.A.second = loss;

    // Right point

    Index count = 0;

    do
    {
        count++;

        triplet.B.first = optimization_data.initial_learning_rate*static_cast<type>(count);

        optimization_data.potential_parameters.device(*thread_pool_device)
            = back_propagation.parameters + optimization_data.training_direction*triplet.B.first;

        neural_network_pointer->forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);

        loss_index_pointer->calculate_errors(batch, forward_propagation, back_propagation);
        loss_index_pointer->calculate_error(batch, forward_propagation, back_propagation);

        const type regularization = loss_index_pointer->calculate_regularization(optimization_data.potential_parameters);

        triplet.B.second = back_propagation.error + regularization_weight*regularization;

    } while(abs(triplet.A.second - triplet.B.second) < loss_tolerance);

    if(triplet.A.second > triplet.B.second)
    {
        triplet.U = triplet.B;

        triplet.B.first *= golden_ratio;

        optimization_data.potential_parameters.device(*thread_pool_device)
            = back_propagation.parameters + optimization_data.training_direction*triplet.B.first;

        neural_network_pointer->forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);

        loss_index_pointer->calculate_errors(batch, forward_propagation, back_propagation);
        loss_index_pointer->calculate_error(batch, forward_propagation, back_propagation);

        const type regularization = loss_index_pointer->calculate_regularization(optimization_data.potential_parameters);

        triplet.B.second = back_propagation.error + regularization_weight*regularization;

        while(triplet.U.second > triplet.B.second)
        {
            triplet.A = triplet.U;
            triplet.U = triplet.B;

            triplet.B.first *= golden_ratio;

            optimization_data.potential_parameters.device(*thread_pool_device)
                = back_propagation.parameters + optimization_data.training_direction*triplet.B.first;

            neural_network_pointer->forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);

            loss_index_pointer->calculate_errors(batch, forward_propagation, back_propagation);
            loss_index_pointer->calculate_error(batch, forward_propagation, back_propagation);

            const type regularization = loss_index_pointer->calculate_regularization(optimization_data.potential_parameters);

            triplet.B.second = back_propagation.error + regularization_weight*regularization;
        }
    }
    else if(triplet.A.second < triplet.B.second)
    {
        triplet.U.first = triplet.A.first + (triplet.B.first - triplet.A.first)*static_cast<type>(0.382);

        optimization_data.potential_parameters.device(*thread_pool_device)
            = back_propagation.parameters + optimization_data.training_direction*triplet.U.first;

        neural_network_pointer->forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);

        loss_index_pointer->calculate_errors(batch, forward_propagation, back_propagation);
        loss_index_pointer->calculate_error(batch, forward_propagation, back_propagation);

        const type regularization = loss_index_pointer->calculate_regularization(optimization_data.potential_parameters);

        triplet.U.second = back_propagation.error + regularization_weight*regularization;

        while(triplet.A.second < triplet.U.second)
        {
            triplet.B = triplet.U;

            triplet.U.first = triplet.A.first + (triplet.B.first - triplet.A.first)*static_cast<type>(0.382);

            optimization_data.potential_parameters.device(*thread_pool_device)
                = back_propagation.parameters + optimization_data.training_direction*triplet.U.first;

            neural_network_pointer->forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);

            loss_index_pointer->calculate_errors(batch, forward_propagation, back_propagation);
            loss_index_pointer->calculate_error(batch, forward_propagation, back_propagation);

            const type regularization = loss_index_pointer->calculate_regularization(optimization_data.potential_parameters);

            triplet.U.second = back_propagation.error + regularization_weight*regularization;

            if(triplet.U.first - triplet.A.first <= learning_rate_tolerance)
            {
                triplet.U = triplet.A;
                triplet.B = triplet.A;

                return triplet;
            }
        }
    }

    return triplet;
}
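
// Worked example of the bracketing above (illustrative numbers): with an
// initial learning rate of 0.1, A is fixed at rate 0. B starts at 0.1 and, if
// the loss at B is still below the loss at A, B is expanded by the golden
// ratio (0.1, 0.162, 0.262, ...), the previous B becoming the interior point
// U, until the loss rises again. The loop then stops with U below both A and
// B, so the interval [A, B] is guaranteed to contain a directional minimum.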


/// Calculates the golden section point within a minimum interval defined by
/// the three points of a bracketing triplet.
/// @param triplet Triplet containing a minimum of the loss function.

type LearningRateAlgorithm::calculate_golden_section_learning_rate(const Triplet& triplet) const
{
    type learning_rate;

    const type middle = triplet.A.first + static_cast<type>(0.5)*(triplet.B.first - triplet.A.first);

    if(triplet.U.first < middle)
    {
        learning_rate = triplet.A.first + static_cast<type>(0.618)*(triplet.B.first - triplet.A.first);
    }
    else
    {
        learning_rate = triplet.A.first + static_cast<type>(0.382)*(triplet.B.first - triplet.A.first);
    }

#ifdef OPENNN_DEBUG

    if(learning_rate < triplet.A.first)
    {
        ostringstream buffer;

        buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
               << "type calculate_golden_section_learning_rate(const Triplet&) const method.\n"
               << "Learning rate(" << learning_rate << ") is less than left point("
               << triplet.A.first << ").\n";

        throw logic_error(buffer.str());
    }

    if(learning_rate > triplet.B.first)
    {
        ostringstream buffer;

        buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
               << "type calculate_golden_section_learning_rate(const Triplet&) const method.\n"
               << "Learning rate(" << learning_rate << ") is greater than right point("
               << triplet.B.first << ").\n";

        throw logic_error(buffer.str());
    }

#endif

    return learning_rate;
}
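
// Worked example (illustrative numbers): for A.first = 0, B.first = 1 and the
// interior point U.first = 0.3, the midpoint is 0.5. U lies in the left half,
// so the new trial rate goes into the larger right part, at 0.618*(1 - 0) =
// 0.618; for U.first = 0.7 it would instead be 0.382. Either way the interval
// shrinks by roughly the golden section factor at each iteration.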


/// Returns the learning rate at the vertex of the parabola interpolating the
/// three points of a bracketing triplet (the parabolic step of Brent's method).
/// @param triplet Triplet containing a minimum of the loss function.

type LearningRateAlgorithm::calculate_Brent_method_learning_rate(const Triplet& triplet) const
{
    const type a = triplet.A.first;
    const type u = triplet.U.first;
    const type b = triplet.B.first;

    const type fa = triplet.A.second;
    const type fu = triplet.U.second;
    const type fb = triplet.B.second;

    const type numerator = (u-a)*(u-a)*(fu-fb) - (u-b)*(u-b)*(fu-fa);

    const type denominator = (u-a)*(fu-fb) - (u-b)*(fu-fa);

    return u - static_cast<type>(0.5)*numerator/denominator;
}
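
// Worked example (illustrative numbers): for the triplet A = (0, 4),
// U = (1, 1), B = (2, 2), the interpolating parabola is f(x) = 2x^2 - 5x + 4,
// whose vertex lies at x = 5/4. The formula above reproduces it:
//
//     numerator   = (1-0)^2*(1-2) - (1-2)^2*(1-4) =  2
//     denominator = (1-0)*(1-2) - (1-2)*(1-4)     = -4
//     rate        = 1 - 0.5*(2/-4)                =  1.25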


/// Serializes the learning rate algorithm object into an XML document of the
/// TinyXML library without keeping the DOM tree in memory.
/// @param file_stream TinyXML printer used to write the document.

void LearningRateAlgorithm::write_XML(tinyxml2::XMLPrinter& file_stream) const
{
    ostringstream buffer;

    // Learning rate algorithm

    file_stream.OpenElement("LearningRateAlgorithm");

    // Learning rate method

    file_stream.OpenElement("LearningRateMethod");

    file_stream.PushText(write_learning_rate_method().c_str());

    file_stream.CloseElement();

    // Learning rate tolerance

    file_stream.OpenElement("LearningRateTolerance");

    buffer.str("");
    buffer << learning_rate_tolerance;

    file_stream.PushText(buffer.str().c_str());

    file_stream.CloseElement();

    // Learning rate algorithm (end tag)

    file_stream.CloseElement();
}
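
// For the default members set by set_default(), the method above emits a
// fragment of the form (the tolerance is numeric_limits<type>::epsilon(),
// e.g. 1.19209e-07 when type is float):
//
//     <LearningRateAlgorithm>
//         <LearningRateMethod>BrentMethod</LearningRateMethod>
//         <LearningRateTolerance>1.19209e-07</LearningRateTolerance>
//     </LearningRateAlgorithm>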


/// Loads a learning rate algorithm object from an XML document.
/// @param document TinyXML document containing the members of the object.

void LearningRateAlgorithm::from_XML(const tinyxml2::XMLDocument& document)
{
    const tinyxml2::XMLElement* root_element = document.FirstChildElement("LearningRateAlgorithm");

    if(!root_element)
    {
        ostringstream buffer;

        buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
               << "void from_XML(const tinyxml2::XMLDocument&) method.\n"
               << "Learning rate algorithm element is nullptr.\n";

        throw logic_error(buffer.str());
    }

    // Learning rate method
    {
        const tinyxml2::XMLElement* element = root_element->FirstChildElement("LearningRateMethod");

        if(element)
        {
            string new_learning_rate_method = element->GetText();

            try
            {
                set_learning_rate_method(new_learning_rate_method);
            }
            catch(const logic_error& e)
            {
                cerr << e.what() << endl;
            }
        }
    }

    // Learning rate tolerance
    {
        const tinyxml2::XMLElement* element = root_element->FirstChildElement("LearningRateTolerance");

        if(element)
        {
            const type new_learning_rate_tolerance = static_cast<type>(atof(element->GetText()));

            try
            {
                set_learning_rate_tolerance(new_learning_rate_tolerance);
            }
            catch(const logic_error& e)
            {
                cerr << e.what() << endl;
            }
        }
    }

    // Display warnings
    {
        const tinyxml2::XMLElement* element = root_element->FirstChildElement("Display");

        if(element)
        {
            const string new_display = element->GetText();

            try
            {
                set_display(new_display != "0");
            }
            catch(const logic_error& e)
            {
                cerr << e.what() << endl;
            }
        }
    }
}
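
// A minimal document accepted by from_XML(); the Display element is optional,
// and an unknown method name or non-positive tolerance is reported through
// the corresponding setter:
//
//     <LearningRateAlgorithm>
//         <LearningRateMethod>GoldenSection</LearningRateMethod>
//         <LearningRateTolerance>1.0e-6</LearningRateTolerance>
//         <Display>1</Display>
//     </LearningRateAlgorithm>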

}

// OpenNN: Open Neural Networks Library.
// Copyright(C) 2005-2021 Artificial Intelligence Techniques, SL.
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 2.1 of the License, or any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA