learning_rate_algorithm.h
1// OpenNN: Open Neural Networks Library
2// www.opennn.net
3//
4// L E A R N I N G R A T E A L G O R I T H M C L A S S H E A D E R
5//
6// Artificial Intelligence Techniques SL
7// artelnics@artelnics.com
8
9#ifndef LEARNINGRATEALGORITHM_H
10#define LEARNINGRATEALGORITHM_H
11
12// System includes
13
14#include <iostream>
15#include <fstream>
16#include <algorithm>
17#include <functional>
18#include <limits>
19#include <cmath>
20#include <ctime>
21#include <cstdlib>
22
23// OpenNN includes
24
25#include "config.h"
26#include "neural_network.h"
27#include "loss_index.h"
28#include "optimization_algorithm.h"
29
30namespace OpenNN
31{
32
34
39
41{
42
43public:
44
45 // Enumerations
46
48
/// Methods available for obtaining the learning rate along a training direction:
/// golden-section search or Brent's method.
49 enum class LearningRateMethod{GoldenSection, BrentMethod};
50
51 // Constructors
52
/// Default constructor.
53 explicit LearningRateAlgorithm();
54
56
57 // Destructor
58
/// Virtual destructor, so objects are destroyed correctly when deleted through a base pointer.
59 virtual ~LearningRateAlgorithm();
60
62
63 struct Triplet
64 {
66
// NOTE(review): the signature line of this body (presumably the default constructor
// `Triplet()`) is missing from this listing — confirm against the original header.
// Sets every coordinate of A, U and B to numeric_limits<type>::max(), presumably so that
// a triplet that has not been filled in yet can be recognized.
68 {
69 A = make_pair(numeric_limits<type>::max(), numeric_limits<type>::max());
70 U = make_pair(numeric_limits<type>::max(), numeric_limits<type>::max());
71 B = make_pair(numeric_limits<type>::max(), numeric_limits<type>::max());
72 }
73
75
76 virtual ~Triplet()
77 {
78 }
79
84
85 inline bool operator == (const Triplet& other_triplet) const
86 {
87 if(A == other_triplet.A
88 && U == other_triplet.U
89 && B == other_triplet.B)
90 {
91 return true;
92 }
93 else
94 {
95 return false;
96 }
97 }
98
99 inline type get_length() const
100 {
101 return abs(B.first - A.first);
102 }
103
104
105 inline pair<type,type> minimum() const
106 {
107 Tensor<type, 1> losses(3);
108
109 losses.setValues({A.second, U.second, B.second});
110
111 const Index minimal_index = OpenNN::minimal_index(losses);
112
113 if(minimal_index == 0) return A;
114 else if(minimal_index == 1) return U;
115 else return B;
116 }
117
119
120 inline string struct_to_string() const
121 {
122 ostringstream buffer;
123
124 buffer << "A = (" << A.first << "," << A.second << ")\n"
125 << "U = (" << U.first << "," << U.second << ")\n"
126 << "B = (" << B.first << "," << B.second << ")" << endl;
127
128 return buffer.str();
129 }
130
132
133 inline void print() const
134 {
135 cout << struct_to_string();
136 cout << "Lenght: " << get_length() << endl;
137 }
138
142
143 inline void check() const
144 {
145 ostringstream buffer;
146
147 if(U.first < A.first)
148 {
149 buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
150 << "void check() const method.\n"
151 << "U is less than A:\n"
152 << struct_to_string();
153
154 throw logic_error(buffer.str());
155 }
156
157 if(U.first > B.first)
158 {
159 buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
160 << "void check() const method.\n"
161 << "U is greater than A:\n"
162 << struct_to_string();
163
164 throw logic_error(buffer.str());
165 }
166
167 if(U.second >= A.second)
168 {
169 buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
170 << "void check() const method.\n"
171 << "fU is equal or greater than fA:\n"
172 << struct_to_string();
173
174 throw logic_error(buffer.str());
175 }
176
177 if(U.second >= B.second)
178 {
179 buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
180 << "void check() const method.\n"
181 << "fU is equal or greater than fB:\n"
182 << struct_to_string();
183
184 throw logic_error(buffer.str());
185 }
186 }
187
189
/// Left point of the triplet. In every point, `first` is the position along the search
/// direction (get_length() measures |B.first - A.first|), and `second` is the loss at
/// that position (minimum() compares the `second` values as losses).
190 pair<type, type> A;
191
193
/// Interior point of the triplet; check() requires A.first <= U.first <= B.first.
194 pair<type, type> U;
195
197
/// Right point of the triplet.
198 pair<type, type> B;
199 };
200
201 // Get methods
202
204
/// Presumably reports whether a loss index pointer has been set — confirm in the .cpp.
205 bool has_loss_index() const;
206
207 // Training operators
208
/// Returns a string with the name of the learning rate method to be used.
210 string write_learning_rate_method() const;
211
212 // Training parameters
213
/// Returns the learning rate tolerance (maximum interval length for the learning rate).
214 const type& get_learning_rate_tolerance() const;
215
216 // Utilities
217
/// Returns whether messages are displayed to screen.
218 const bool& get_display() const;
219
220 // Set methods
221
222 void set();
/// NOTE(review): presumably stores the given pointer as the loss index — confirm.
223 void set(LossIndex*);
224
226 void set_threads_number(const int&);
227
228 // Training operators
229
/// Sets the learning rate method from its name; presumably accepts the same strings
/// that write_learning_rate_method() produces — confirm the accepted spellings.
231 void set_learning_rate_method(const string&);
232
233 // Training parameters
234
235 void set_learning_rate_tolerance(const type&);
236
237 // Utilities
238
239 void set_display(const bool&);
240
/// Sets the members of the learning rate algorithm to their default values.
241 void set_default();
242
243 // Learning rate methods
244
247
252
253 pair<type, type> calculate_directional_point(const DataSetBatch&,
257
258 // Serialization methods
259
/// Presumably loads the algorithm settings from an XML document, in the format
/// written by write_XML — confirm.
260 void from_XML(const tinyxml2::XMLDocument&);
261
/// Presumably writes the algorithm settings as XML through the given printer — confirm.
262 void write_XML(tinyxml2::XMLPrinter&) const;
263
264protected:
265
266 // FIELDS
267
269
271
272 // TRAINING OPERATORS
273
275
277
279
281
// NOTE(review): no in-class initializer; presumably assigned in set_default() — confirm.
282 type loss_tolerance;
283
284 // UTILITIES
285
287
/// Display messages to screen.
288 bool display = true;
289
// NOTE(review): 1.618 truncates the golden ratio (1 + sqrt(5)) / 2 = 1.6180339887...
// Also, a non-static `const` data member implicitly deletes this class's copy-assignment
// operator. Consider `static constexpr` with a fuller value; that changes numerical
// results, so check callers first.
290 const type golden_ratio = static_cast<type>(1.618);
291
// NOTE(review): raw pointers whose ownership is not visible here. If set_threads_number()
// allocates them and the destructor deletes them, the implicit copy constructor would
// share them and cause a double delete — confirm, and consider unique_ptr or deleted copies.
292 NonBlockingThreadPool* non_blocking_thread_pool = nullptr;
293 ThreadPoolDevice* thread_pool_device = nullptr;
294};
295
296}
297
298#endif
299
300
301// OpenNN: Open Neural Networks Library.
302// Copyright(C) 2005-2021 Artificial Intelligence Techniques, SL.
303//
304// This library is free software; you can redistribute it and/or
305// modify it under the terms of the GNU Lesser General Public
306// License as published by the Free Software Foundation; either
307// version 2.1 of the License, or any later version.
308//
309// This library is distributed in the hope that it will be useful,
310// but WITHOUT ANY WARRANTY; without even the implied warranty of
311// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
312// Lesser General Public License for more details.
313
314// You should have received a copy of the GNU Lesser General Public
315// License along with this library; if not, write to the Free Software
316// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
A learning rate that is adjusted according to an algorithm during training to minimize training time.
type calculate_Brent_method_learning_rate(const Triplet &) const
void from_XML(const tinyxml2::XMLDocument &)
void set_default()
Sets the members of the learning rate algorithm to their default values.
LossIndex * loss_index_pointer
Pointer to an external loss index object.
bool display
Display messages to screen.
void set_learning_rate_method(const LearningRateMethod &)
string write_learning_rate_method() const
Returns a string with the name of the learning rate method to be used.
LearningRateMethod
Available training operators for obtaining the learning rate.
LearningRateMethod learning_rate_method
Variable containing the actual method used to obtain a suitable learning rate.
const LearningRateMethod & get_learning_rate_method() const
Returns the learning rate method used for training.
type calculate_golden_section_learning_rate(const Triplet &) const
Triplet calculate_bracketing_triplet(const DataSetBatch &, NeuralNetworkForwardPropagation &, LossIndexBackPropagation &, OptimizationAlgorithmData &) const
pair< type, type > calculate_directional_point(const DataSetBatch &, NeuralNetworkForwardPropagation &, LossIndexBackPropagation &, OptimizationAlgorithmData &) const
void write_XML(tinyxml2::XMLPrinter &) const
type learning_rate_tolerance
Maximum interval length for the learning rate.
This abstract class represents the concept of a loss index composed of an error term and a regularization term.
Definition: loss_index.h:48
HALF_CONSTEXPR half abs(half arg)
Definition: half.hpp:2735
Defines a set of three points (A, U, B) for bracketing a directional minimum.
bool operator==(const Triplet &other_triplet) const
pair< type, type > B
Right point of the triplet.
string struct_to_string() const
Writes a string with the values of A, U and B.
void print() const
Prints the triplet points to the standard output.
pair< type, type > A
Left point of the triplet.
pair< type, type > U
Interior point of the triplet.