This class is used by many different optimization algorithms to calculate the learning rate given a training direction.
#include <learning_rate_algorithm.h>
Classes | |
struct | Triplet |
This structure defines a set of three points (A, U, B) for bracketing a directional minimum. | |
Public Types | |
enum class | LearningRateMethod { GoldenSection , BrentMethod } |
Available methods for obtaining the learning rate. | |
Protected Attributes | |
LossIndex * | loss_index_pointer = nullptr |
Pointer to an external loss index object. | |
LearningRateMethod | learning_rate_method |
Method currently used to obtain a suitable learning rate. | |
type | learning_rate_tolerance |
Maximum interval length for the learning rate. | |
type | loss_tolerance |
bool | display = true |
Display messages to screen. | |
const type | golden_ratio = static_cast<type>(1.618) |
ThreadPool * | thread_pool = nullptr |
ThreadPoolDevice * | thread_pool_device = nullptr |
This class is used by many different optimization algorithms to calculate the learning rate given a training direction.
The learning rate is recomputed by a line minimization algorithm during training in order to reduce training time. The class implements the golden section method and Brent's method.
Definition at line 38 of file learning_rate_algorithm.h.
enum class LearningRateMethod | strong |
Available methods for obtaining the learning rate.
Definition at line 47 of file learning_rate_algorithm.h.
LearningRateAlgorithm | ( | ) | explicit |
Default constructor. It creates a learning rate algorithm object not associated with any loss index object. It also initializes the class members to their default values.
Definition at line 18 of file learning_rate_algorithm.cpp.
LearningRateAlgorithm | ( | LossIndex * | new_loss_index_pointer | ) | explicit |
Loss index constructor. It creates a learning rate algorithm associated with a loss index. It also initializes the class members to their default values.
new_loss_index_pointer | Pointer to a loss index object. |
Definition at line 30 of file learning_rate_algorithm.cpp.
~LearningRateAlgorithm | ( | ) | virtual |
Destructor.
Definition at line 39 of file learning_rate_algorithm.cpp.
LearningRateAlgorithm::Triplet calculate_bracketing_triplet | ( | const DataSetBatch & | batch, |
NeuralNetworkForwardPropagation & | forward_propagation, | ||
LossIndexBackPropagation & | back_propagation, | ||
OptimizationAlgorithmData & | optimization_data | ||
) | const |
Returns a bracketing triplet, that is, three points (A, U, B) that enclose a directional minimum. This algorithm is used by line minimization algorithms.
Definition at line 434 of file learning_rate_algorithm.cpp.
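As an illustration of what bracketing a directional minimum means, here is a minimal, self-contained sketch for a one-dimensional loss function. It is not the implementation in learning_rate_algorithm.cpp: the names `BracketTriplet` and `bracket_minimum`, the expansion rule based on the golden ratio, and the iteration cap are all assumptions made for the example.

```cpp
#include <functional>
#include <utility>

// Hypothetical stand-in for Triplet: each point is (learning rate, loss).
struct BracketTriplet
{
    std::pair<double, double> A, U, B;
};

// Searches for A < U < B with loss(U) <= loss(A) and loss(U) <= loss(B).
// If max_iterations is exhausted, the returned triplet may not bracket a minimum.
BracketTriplet bracket_minimum(const std::function<double(double)>& loss,
                               double initial_step,
                               int max_iterations = 100)
{
    const double golden_ratio = 1.618;

    BracketTriplet t;
    t.A = {0.0, loss(0.0)};
    t.U = {initial_step, loss(initial_step)};

    // Shrink the first step while it increases the loss.
    for (int i = 0; i < max_iterations && t.U.second > t.A.second; ++i)
    {
        t.U.first /= golden_ratio;
        t.U.second = loss(t.U.first);
    }

    // Place B beyond U, then keep stepping outward while the loss still decreases.
    t.B.first = t.U.first + golden_ratio * (t.U.first - t.A.first);
    t.B.second = loss(t.B.first);

    for (int i = 0; i < max_iterations && t.B.second < t.U.second; ++i)
    {
        t.A = t.U;
        t.U = t.B;
        t.B.first = t.U.first + golden_ratio * (t.U.first - t.A.first);
        t.B.second = loss(t.B.first);
    }

    return t;
}
```

For a unimodal loss along the training direction, the minimum then lies between A and B, which is the starting condition that the golden section and Brent steps rely on.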
type calculate_Brent_method_learning_rate | ( | const Triplet & | triplet | ) | const |
Returns the learning rate at the minimum of the parabola defined by three directional points.
triplet | Triplet containing a minimum. |
Definition at line 671 of file learning_rate_algorithm.cpp.
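The parabolic step described above can be sketched with the standard vertex formula for a parabola through three points (a, fa), (u, fu), (b, fb). This is only an illustration under assumptions: the function name `parabola_minimum`, the flat argument list, and the fallback for a degenerate denominator are hypothetical, and the library may guard or combine this step differently.

```cpp
#include <cmath>

// Abscissa of the vertex of the parabola through (a, fa), (u, fu), (b, fb).
// The vertex is a minimum when fu <= fa and fu <= fb with a < u < b.
double parabola_minimum(double a, double fa,
                        double u, double fu,
                        double b, double fb)
{
    const double numerator   = (u - a) * (u - a) * (fu - fb)
                             - (u - b) * (u - b) * (fu - fa);
    const double denominator = (u - a) * (fu - fb)
                             - (u - b) * (fu - fa);

    // Collinear points define no parabola; fall back to the interior point.
    if (std::abs(denominator) < 1e-12) return u;

    return u - 0.5 * numerator / denominator;
}
```

For example, the points (0, 4), (1, 1), (3, 1) lie on the parabola (x - 2)^2, and the formula returns 2.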
pair< type, type > calculate_directional_point | ( | const DataSetBatch & | batch, |
NeuralNetworkForwardPropagation & | forward_propagation, | ||
LossIndexBackPropagation & | back_propagation, | ||
OptimizationAlgorithmData & | optimization_data | ||
) | const |
Returns a pair with two elements: (i) the learning rate calculated by means of the corresponding algorithm, and (ii) the loss for that learning rate.
Definition at line 268 of file learning_rate_algorithm.cpp.
type calculate_golden_section_learning_rate | ( | const Triplet & | triplet | ) | const |
Calculates the golden section point within an interval, defined by three points, that brackets a minimum.
triplet | Triplet containing a minimum. |
Definition at line 621 of file learning_rate_algorithm.cpp.
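To make the golden section step concrete, the following sketch places a trial point in the larger of the two sub-intervals of a triplet and repeats until the interval is shorter than a tolerance. The function names, the 0.381966 fraction (2 minus the golden ratio), and the loop structure are assumptions for illustration; they are not taken from learning_rate_algorithm.cpp.

```cpp
#include <functional>

// Trial point inside (a, b), placed in the larger of [a, u] and [u, b]
// at the golden-section fraction measured from u.
double golden_section_point(double a, double u, double b)
{
    const double fraction = 0.381966; // 2 - golden ratio
    return (b - u > u - a) ? u + fraction * (b - u)
                           : u - fraction * (u - a);
}

// Shrinks a bracketing triplet (a < u < b) until b - a <= tolerance
// and returns the best interior point found.
double golden_section_search(const std::function<double(double)>& loss,
                             double a, double u, double b,
                             double tolerance)
{
    double fu = loss(u);

    while (b - a > tolerance)
    {
        const double x  = golden_section_point(a, u, b);
        const double fx = loss(x);

        if (fx < fu)
        {
            // x is the new best point; the old u becomes an endpoint.
            if (x > u) a = u; else b = u;
            u = x;
            fu = fx;
        }
        else
        {
            // u stays the best point; x becomes an endpoint.
            if (x > u) b = x; else a = x;
        }
    }

    return u;
}
```

In this sketch the `tolerance` argument plays the role that the page assigns to learning_rate_tolerance, the maximum interval length for the learning rate.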
void from_XML | ( | const tinyxml2::XMLDocument & | document | ) |
Loads a learning rate algorithm object from an XML document. Mind the file format, which is specified in the manual.
document | TinyXML document with the learning rate algorithm members. |
Definition at line 730 of file learning_rate_algorithm.cpp.
const bool & get_display | ( | ) | const |
Returns true if messages from this class can be displayed on the screen, and false otherwise.
Definition at line 121 of file learning_rate_algorithm.cpp.
const LearningRateAlgorithm::LearningRateMethod & get_learning_rate_method | ( | ) | const |
Returns the learning rate method used for training.
Definition at line 89 of file learning_rate_algorithm.cpp.
const type & get_learning_rate_tolerance | ( | ) | const |
Definition at line 112 of file learning_rate_algorithm.cpp.
LossIndex * get_loss_index_pointer | ( | ) | const |
Returns a pointer to the loss index object to which the learning rate algorithm is associated. If the loss index pointer is nullptr, this method throws an exception.
Definition at line 50 of file learning_rate_algorithm.cpp.
bool has_loss_index | ( | ) | const |
Returns true if this learning rate algorithm has an associated loss index, and false otherwise.
Definition at line 74 of file learning_rate_algorithm.cpp.
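The relationship between the two accessors above can be sketched as follows: check has_loss_index() first, or be prepared for get_loss_index_pointer() to throw. The types `LossIndexStub` and `RateAlgorithmSketch` are hypothetical placeholders, and the exception type is an assumption, since this page does not say which exception the library throws.

```cpp
#include <stdexcept>

struct LossIndexStub {}; // hypothetical placeholder for LossIndex

class RateAlgorithmSketch
{
public:
    explicit RateAlgorithmSketch(LossIndexStub* new_loss_index_pointer = nullptr)
        : loss_index_pointer(new_loss_index_pointer) {}

    // True when a loss index has been associated with this object.
    bool has_loss_index() const { return loss_index_pointer != nullptr; }

    // Throws instead of handing out a null pointer (exception type assumed).
    LossIndexStub* get_loss_index_pointer() const
    {
        if (!loss_index_pointer)
            throw std::logic_error("Loss index pointer is nullptr.");

        return loss_index_pointer;
    }

private:
    LossIndexStub* loss_index_pointer = nullptr;
};
```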
void set | ( | ) |
Sets the loss index pointer to nullptr. It also sets the remaining members to their default values.
Definition at line 130 of file learning_rate_algorithm.cpp.
void set | ( | LossIndex * | new_loss_index_pointer | ) |
Sets a new loss index pointer. It also sets the remaining members to their default values.
new_loss_index_pointer | Pointer to a loss index object. |
Definition at line 142 of file learning_rate_algorithm.cpp.
void set_default | ( | ) |
Sets the members of the learning rate algorithm to their default values.
Definition at line 152 of file learning_rate_algorithm.cpp.
void set_display | ( | const bool & | new_display | ) |
Sets a new display value. If it is true, messages from this class are displayed on the screen; if it is false, they are not.
new_display | Display value. |
Definition at line 258 of file learning_rate_algorithm.cpp.
void set_learning_rate_method | ( | const LearningRateMethod & | new_learning_rate_method | ) |
Sets a new learning rate method to be used for training.
new_learning_rate_method | Learning rate method. |
Definition at line 194 of file learning_rate_algorithm.cpp.
void set_learning_rate_method | ( | const string & | new_learning_rate_method | ) |
Sets the method for obtaining the learning rate from a string with the name of the method.
new_learning_rate_method | Name of the learning rate method ("Fixed", "GoldenSection", "BrentMethod"). |
Definition at line 204 of file learning_rate_algorithm.cpp.
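A string-based setter of this kind typically maps names onto the enumeration and rejects anything else. The sketch below shows one way to do that, together with the reverse mapping. The free functions `method_from_string` and `method_to_string` and the exception type are assumptions, not the library's code. The parameter description also lists "Fixed", which has no counterpart in the LearningRateMethod enumeration shown on this page, so the sketch treats it as unknown; that choice is an assumption as well.

```cpp
#include <stdexcept>
#include <string>

// Mirrors the two enumerators listed on this page.
enum class LearningRateMethod { GoldenSection, BrentMethod };

// Maps a method name to the enumeration; throws on unknown names.
LearningRateMethod method_from_string(const std::string& name)
{
    if (name == "GoldenSection") return LearningRateMethod::GoldenSection;
    if (name == "BrentMethod")   return LearningRateMethod::BrentMethod;

    throw std::invalid_argument("Unknown learning rate method: " + name);
}

// Reverse mapping, in the spirit of write_learning_rate_method().
std::string method_to_string(LearningRateMethod method)
{
    switch (method)
    {
        case LearningRateMethod::GoldenSection: return "GoldenSection";
        case LearningRateMethod::BrentMethod:   return "BrentMethod";
    }

    return std::string(); // not reached for valid enumerators
}
```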
void set_learning_rate_tolerance | ( | const type & | new_learning_rate_tolerance | ) |
Sets a new tolerance value to be used in line minimization.
new_learning_rate_tolerance | Tolerance value in line minimization. |
Definition at line 230 of file learning_rate_algorithm.cpp.
void set_loss_index_pointer | ( | LossIndex * | new_loss_index_pointer | ) |
Sets a pointer to a loss index object to be associated with the learning rate algorithm.
new_loss_index_pointer | Pointer to a loss index object. |
Definition at line 175 of file learning_rate_algorithm.cpp.
void set_threads_number | ( | const int & | new_threads_number | ) |
Definition at line 181 of file learning_rate_algorithm.cpp.
string write_learning_rate_method | ( | ) | const |
Returns a string with the name of the learning rate method to be used.
Definition at line 97 of file learning_rate_algorithm.cpp.
void write_XML | ( | tinyxml2::XMLPrinter & | file_stream | ) | const |
Serializes the learning rate algorithm object into an XML document of the TinyXML library without keeping the DOM tree in memory. See the OpenNN manual for more information about the format of this document.
Definition at line 693 of file learning_rate_algorithm.cpp.
bool display = true | protected |
Display messages to screen.
Definition at line 286 of file learning_rate_algorithm.h.
const type golden_ratio = static_cast<type>(1.618) | protected |
Definition at line 288 of file learning_rate_algorithm.h.
LearningRateMethod learning_rate_method | protected |
Method currently used to obtain a suitable learning rate.
Definition at line 274 of file learning_rate_algorithm.h.
type learning_rate_tolerance | protected |
Maximum interval length for the learning rate.
Definition at line 278 of file learning_rate_algorithm.h.
LossIndex * loss_index_pointer = nullptr | protected |
Pointer to an external loss index object.
Definition at line 268 of file learning_rate_algorithm.h.
type loss_tolerance | protected |
Definition at line 280 of file learning_rate_algorithm.h.
ThreadPool * thread_pool = nullptr | protected |
Definition at line 290 of file learning_rate_algorithm.h.
ThreadPoolDevice * thread_pool_device = nullptr | protected |
Definition at line 291 of file learning_rate_algorithm.h.