9#ifndef LEARNINGRATEALGORITHM_H
10#define LEARNINGRATEALGORITHM_H
26#include "neural_network.h"
27#include "loss_index.h"
28#include "optimization_algorithm.h"
69 A = make_pair(numeric_limits<type>::max(), numeric_limits<type>::max());
70 U = make_pair(numeric_limits<type>::max(), numeric_limits<type>::max());
71 B = make_pair(numeric_limits<type>::max(), numeric_limits<type>::max());
87 if(
A == other_triplet.
A
88 &&
U == other_triplet.
U
89 &&
B == other_triplet.
B)
99 inline type get_length()
const
101 return abs(
B.first -
A.first);
105 inline pair<type,type> minimum()
const
107 Tensor<type, 1> losses(3);
109 losses.setValues({
A.second,
U.second,
B.second});
111 const Index minimal_index = OpenNN::minimal_index(losses);
113 if(minimal_index == 0)
return A;
114 else if(minimal_index == 1)
return U;
122 ostringstream buffer;
124 buffer <<
"A = (" <<
A.first <<
"," <<
A.second <<
")\n"
125 <<
"U = (" <<
U.first <<
"," <<
U.second <<
")\n"
126 <<
"B = (" <<
B.first <<
"," <<
B.second <<
")" << endl;
136 cout <<
"Lenght: " << get_length() << endl;
145 ostringstream buffer;
147 if(
U.first <
A.first)
149 buffer <<
"OpenNN Exception: LearningRateAlgorithm class.\n"
150 <<
"void check() const method.\n"
151 <<
"U is less than A:\n"
154 throw logic_error(buffer.str());
157 if(
U.first >
B.first)
159 buffer <<
"OpenNN Exception: LearningRateAlgorithm class.\n"
160 <<
"void check() const method.\n"
161 <<
"U is greater than A:\n"
164 throw logic_error(buffer.str());
167 if(
U.second >=
A.second)
169 buffer <<
"OpenNN Exception: LearningRateAlgorithm class.\n"
170 <<
"void check() const method.\n"
171 <<
"fU is equal or greater than fA:\n"
174 throw logic_error(buffer.str());
177 if(
U.second >=
B.second)
179 buffer <<
"OpenNN Exception: LearningRateAlgorithm class.\n"
180 <<
"void check() const method.\n"
181 <<
"fU is equal or greater than fB:\n"
184 throw logic_error(buffer.str());
214 const type& get_learning_rate_tolerance()
const;
226 void set_threads_number(
const int&);
290 const type golden_ratio =
static_cast<type
>(1.618);
292 NonBlockingThreadPool* non_blocking_thread_pool =
nullptr;
293 ThreadPoolDevice* thread_pool_device =
nullptr;
A learning rate that is adjusted according to an algorithm during training to minimize training time.
void set_loss_index_pointer(LossIndex *)
type calculate_Brent_method_learning_rate(const Triplet &) const
const bool & get_display() const
void from_XML(const tinyxml2::XMLDocument &)
void set_default()
Sets the members of the learning rate algorithm to their default values.
void set_learning_rate_tolerance(const type &)
LossIndex * loss_index_pointer
Pointer to an external loss index object.
bool display
Display messages to screen.
void set_learning_rate_method(const LearningRateMethod &)
string write_learning_rate_method() const
Returns a string with the name of the learning rate method to be used.
LearningRateMethod
Available training operators for obtaining the learning rate.
bool has_loss_index() const
LearningRateMethod learning_rate_method
Variable containing the actual method used to obtain a suitable learning rate.
const LearningRateMethod & get_learning_rate_method() const
Returns the learning rate method used for training.
type calculate_golden_section_learning_rate(const Triplet &) const
virtual ~LearningRateAlgorithm()
Destructor.
Triplet calculate_bracketing_triplet(const DataSetBatch &, NeuralNetworkForwardPropagation &, LossIndexBackPropagation &, OptimizationAlgorithmData &) const
pair< type, type > calculate_directional_point(const DataSetBatch &, NeuralNetworkForwardPropagation &, LossIndexBackPropagation &, OptimizationAlgorithmData &) const
void set_display(const bool &)
void write_XML(tinyxml2::XMLPrinter &) const
type learning_rate_tolerance
Maximum interval length for the learning rate.
LossIndex * get_loss_index_pointer() const
This abstract class represents the concept of loss index composed of an error term and a regularization term.
HALF_CONSTEXPR half abs(half arg)
Defines a set of three points (A, U, B) for bracketing a directional minimum.
bool operator==(const Triplet &other_triplet) const
Triplet()
Default constructor.
pair< type, type > B
Right point of the triplet.
virtual ~Triplet()
Destructor.
string struct_to_string() const
Writes a string with the values of A, U and B.
void print() const
Prints the triplet points to the standard output.
pair< type, type > A
Left point of the triplet.
pair< type, type > U
Interior point of the triplet.