LearningRateAlgorithm Class Reference

This class is used by many different optimization algorithms to calculate the learning rate given a training direction.

#include <learning_rate_algorithm.h>

Classes

struct  Triplet
 This structure defines a set of three points (A, U, B) for bracketing a directional minimum.
 

Public Types

enum class  LearningRateMethod { GoldenSection , BrentMethod }
 Available methods for obtaining the learning rate.
 

Public Member Functions

 LearningRateAlgorithm ()
 
 LearningRateAlgorithm (LossIndex *)
 
virtual ~LearningRateAlgorithm ()
 Destructor.
 
LossIndex * get_loss_index_pointer () const
 
bool has_loss_index () const
 
const LearningRateMethod & get_learning_rate_method () const
 Returns the learning rate method used for training.
 
string write_learning_rate_method () const
 Returns a string with the name of the learning rate method to be used.
 
const type & get_learning_rate_tolerance () const
 
const bool & get_display () const
 
void set ()
 
void set (LossIndex *)
 
void set_loss_index_pointer (LossIndex *)
 
void set_threads_number (const int &)
 
void set_learning_rate_method (const LearningRateMethod &)
 
void set_learning_rate_method (const string &)
 
void set_learning_rate_tolerance (const type &)
 
void set_display (const bool &)
 
void set_default ()
 Sets the members of the learning rate algorithm to their default values.
 
type calculate_golden_section_learning_rate (const Triplet &) const
 
type calculate_Brent_method_learning_rate (const Triplet &) const
 
Triplet calculate_bracketing_triplet (const DataSetBatch &, NeuralNetworkForwardPropagation &, LossIndexBackPropagation &, OptimizationAlgorithmData &) const
 
pair< type, type > calculate_directional_point (const DataSetBatch &, NeuralNetworkForwardPropagation &, LossIndexBackPropagation &, OptimizationAlgorithmData &) const
 
void from_XML (const tinyxml2::XMLDocument &)
 
void write_XML (tinyxml2::XMLPrinter &) const
 

Protected Attributes

LossIndex * loss_index_pointer = nullptr
 Pointer to an external loss index object.
 
LearningRateMethod learning_rate_method
 Variable containing the method actually used to obtain a suitable learning rate.
 
type learning_rate_tolerance
 Maximum interval length for the learning rate.
 
type loss_tolerance
 
bool display = true
 Display messages to screen.
 
const type golden_ratio = static_cast<type>(1.618)
 
ThreadPool * thread_pool = nullptr
 
ThreadPoolDevice * thread_pool_device = nullptr
 

Detailed Description

This class is used by many different optimization algorithms to calculate the learning rate given a training direction.

The learning rate is adjusted during training, with the aim of reducing training time, by minimizing the loss along the training direction. The class implements the golden section method and Brent's method for this line minimization.

Definition at line 38 of file learning_rate_algorithm.h.
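For reference, the line minimization that these methods address can be written as a one-dimensional problem. The notation is assumed here rather than taken from the source: w denotes the network parameters, d the training direction, η the learning rate and L the loss. A directional point is then a pair (η, φ(η)), and a triplet (A, U, B) brackets the minimizer η*.

```latex
\eta^{*} = \arg\min_{\eta \ge 0} \; \phi(\eta),
\qquad
\phi(\eta) = L(\mathbf{w} + \eta\, \mathbf{d})
```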

Member Enumeration Documentation

◆ LearningRateMethod

enum class LearningRateMethod
strong

Available methods for obtaining the learning rate.

Definition at line 47 of file learning_rate_algorithm.h.

Constructor & Destructor Documentation

◆ LearningRateAlgorithm() [1/2]

LearningRateAlgorithm ( )
explicit

Default constructor. It creates a learning rate algorithm object not associated with any loss index object. It also initializes the class members to their default values.

Definition at line 18 of file learning_rate_algorithm.cpp.

◆ LearningRateAlgorithm() [2/2]

LearningRateAlgorithm ( LossIndex * new_loss_index_pointer)
explicit

Loss index constructor. It creates a learning rate algorithm associated with a loss index. It also initializes the class members to their default values.

Parameters
new_loss_index_pointer: Pointer to a loss index object.

Definition at line 30 of file learning_rate_algorithm.cpp.

◆ ~LearningRateAlgorithm()

~LearningRateAlgorithm ( )
virtual

Destructor.

Definition at line 39 of file learning_rate_algorithm.cpp.

Member Function Documentation

◆ calculate_bracketing_triplet()

LearningRateAlgorithm::Triplet calculate_bracketing_triplet ( const DataSetBatch & batch,
NeuralNetworkForwardPropagation & forward_propagation,
LossIndexBackPropagation & back_propagation,
OptimizationAlgorithmData & optimization_data
) const

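Returns a bracketing triplet, that is, three directional points (A, U, B) that enclose a minimum of the loss along the training direction. This algorithm is used by line minimization algorithms.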

Definition at line 434 of file learning_rate_algorithm.cpp.
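The bracketing idea can be sketched in standalone C++. This is not OpenNN's implementation. The sketch replaces DataSetBatch, the propagation structures and OptimizationAlgorithmData with a hypothetical callable `phi` that returns the loss for a given learning rate along the training direction. `Triplet`, `bracket_minimum`, `initial_rate` and `max_iterations` are names invented for the sketch. Points are (learning rate, loss) pairs, and the step grows by the golden ratio until the loss rises again.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Hypothetical stand-in for LearningRateAlgorithm::Triplet.
// Each point is a (learning rate, loss) pair with A.first < U.first < B.first.
struct Triplet
{
    std::pair<double, double> A;
    std::pair<double, double> U;
    std::pair<double, double> B;
};

// Sketch of golden-ratio bracketing. `phi` is an assumed callable returning
// the loss at a given learning rate along a fixed training direction.
Triplet bracket_minimum(const std::function<double(double)>& phi,
                        double initial_rate,
                        int max_iterations = 100)
{
    assert(initial_rate > 0.0);
    const double golden_ratio = 1.618;  // same rounded value as the class member

    Triplet t;
    t.A = {0.0, phi(0.0)};

    // Shrink the first trial step until its loss drops below phi(0)
    // (or the iteration budget runs out).
    double rate = initial_rate;
    t.U = {rate, phi(rate)};
    for (int i = 0; i < max_iterations && t.U.second >= t.A.second; ++i)
    {
        rate /= golden_ratio;
        t.U = {rate, phi(rate)};
    }

    // Expand past U by the golden ratio until the loss rises again.
    rate = t.U.first * golden_ratio;
    t.B = {rate, phi(rate)};
    for (int i = 0; i < max_iterations && t.B.second <= t.U.second; ++i)
    {
        t.A = t.U;
        t.U = t.B;
        rate *= golden_ratio;
        t.B = {rate, phi(rate)};
    }
    return t;
}
```

If the shrinking loop ends with U no better than A (for example, when d is not a descent direction), the returned triplet does not bracket a minimum, so a caller should check `U.second < A.second` before refining it.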

◆ calculate_Brent_method_learning_rate()

type calculate_Brent_method_learning_rate ( const Triplet & triplet) const

Returns the learning rate at the minimum of the parabola defined by three directional points.

Parameters
triplet: Triplet containing a minimum.

Definition at line 671 of file learning_rate_algorithm.cpp.
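The parabolic step can be illustrated with the standard vertex formula for a parabola through three points. This sketch covers only the interpolation part. Full Brent's method also falls back to golden-section steps when the parabola is unreliable. The function name and the `std::optional` return type are assumptions of the sketch, not OpenNN's signature.

```cpp
#include <cassert>
#include <optional>

// Parabolic-interpolation step used inside Brent's method.
// Given three (learning rate, loss) points with a < u < b, returns the
// learning rate at the vertex of the parabola through them. Returns
// std::nullopt when the points are collinear or the vertex lies outside
// [a, b]; full Brent's method would then take a golden-section step.
std::optional<double> parabola_vertex(double a, double fa,
                                      double u, double fu,
                                      double b, double fb)
{
    assert(a < u && u < b);

    const double p = (u - a) * (u - a) * (fu - fb)
                   - (u - b) * (u - b) * (fu - fa);
    const double q = (u - a) * (fu - fb)
                   - (u - b) * (fu - fa);

    if (q == 0.0)
        return std::nullopt;  // collinear points: no unique vertex

    const double vertex = u - 0.5 * p / q;

    // The comparison is false for NaN and infinity as well as out-of-range values.
    if (!(vertex >= a && vertex <= b))
        return std::nullopt;

    return vertex;
}
```

With a valid bracketing triplet (fu below both fa and fb) the parabola opens upward, so the vertex is a minimum rather than a maximum.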

◆ calculate_directional_point()

pair< type, type > calculate_directional_point ( const DataSetBatch & batch,
NeuralNetworkForwardPropagation & forward_propagation,
LossIndexBackPropagation & back_propagation,
OptimizationAlgorithmData & optimization_data
) const

Returns a pair with two elements: (i) the learning rate calculated by the selected learning rate method, and (ii) the loss for that learning rate.

Definition at line 268 of file learning_rate_algorithm.cpp.
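A minimal sketch of how a directional point can be obtained once a bracket is known. It assumes golden-section refinement only and uses a hypothetical loss callable `phi` in place of OpenNN's batch and propagation arguments; `golden_section_line_minimum` and `max_iterations` are invented names. The interval is shrunk until it is narrower than a tolerance (the role played by learning_rate_tolerance), and the best (learning rate, loss) pair is returned, matching the two elements described above.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Sketch of interval reduction by golden-section steps.
// `phi` is an assumed callable returning the loss at a given learning rate;
// (a, u, b) must bracket a minimum: a < u < b, phi(u) < phi(a), phi(u) < phi(b).
// Returns the best (learning rate, loss) pair found.
std::pair<double, double> golden_section_line_minimum(
    const std::function<double(double)>& phi,
    double a, double u, double b,
    double tolerance,
    int max_iterations = 200)
{
    assert(a < u && u < b && tolerance > 0.0);
    const double section = 0.381966;  // 2 - golden ratio
    double fu = phi(u);

    for (int i = 0; i < max_iterations && b - a > tolerance; ++i)
    {
        // Probe inside the larger of the two subintervals.
        const bool right_is_larger = (b - u) > (u - a);
        const double v = right_is_larger ? u + section * (b - u)
                                         : u - section * (u - a);
        const double fv = phi(v);

        if (fv < fu)
        {
            // v becomes the interior point; u becomes the endpoint opposite v.
            if (right_is_larger) a = u; else b = u;
            u = v;
            fu = fv;
        }
        else
        {
            // u stays interior; v becomes the endpoint on its own side.
            if (right_is_larger) b = v; else a = v;
        }
    }
    return {u, fu};
}
```

The iteration cap guards against a tolerance smaller than floating-point resolution, where the interval could otherwise stop shrinking.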

◆ calculate_golden_section_learning_rate()

type calculate_golden_section_learning_rate ( const Triplet & triplet) const

Calculates the golden section point within a minimum interval defined by three points.

Parameters
triplet: Triplet containing a minimum.

Definition at line 621 of file learning_rate_algorithm.cpp.
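One common way to place the golden-section point is sketched below, assuming the triplet stores (learning rate, loss) pairs ordered A < U < B: step from U into the larger subinterval by the fraction 1 − 1/φ of its length. `Triplet` here is a hypothetical mirror of the nested struct, and `golden_section_point` is an invented name.

```cpp
#include <cassert>
#include <utility>

// Hypothetical mirror of LearningRateAlgorithm::Triplet: three
// (learning rate, loss) points with A.first < U.first < B.first.
struct Triplet
{
    std::pair<double, double> A;
    std::pair<double, double> U;
    std::pair<double, double> B;
};

// Golden-section trial point: step from the interior point U into the
// larger of the two subintervals by 1 - 1/1.618 (about 0.382) of its length.
double golden_section_point(const Triplet& t)
{
    assert(t.A.first < t.U.first && t.U.first < t.B.first);
    const double golden_ratio = 1.618;                // same rounded value as the class member
    const double section = 1.0 - 1.0 / golden_ratio;  // about 0.382

    const double left = t.U.first - t.A.first;
    const double right = t.B.first - t.U.first;
    return (right > left) ? t.U.first + section * right
                          : t.U.first - section * left;
}
```

Probing the larger subinterval keeps the three points close to golden proportions, which is what gives golden-section search its fixed shrink rate per evaluation.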

◆ from_XML()

void from_XML ( const tinyxml2::XMLDocument & document)

Loads the members of a learning rate algorithm object from an XML document. The document format is specified in the OpenNN manual.

Parameters
document: TinyXML document with the learning rate algorithm members.

Definition at line 730 of file learning_rate_algorithm.cpp.

◆ get_display()

const bool & get_display ( ) const

Returns true if messages from this class are displayed on the screen, and false otherwise.

Definition at line 121 of file learning_rate_algorithm.cpp.

◆ get_learning_rate_method()

const LearningRateAlgorithm::LearningRateMethod & get_learning_rate_method ( ) const

Returns the learning rate method used for training.

Definition at line 89 of file learning_rate_algorithm.cpp.

◆ get_learning_rate_tolerance()

const type & get_learning_rate_tolerance ( ) const

Definition at line 112 of file learning_rate_algorithm.cpp.

◆ get_loss_index_pointer()

LossIndex * get_loss_index_pointer ( ) const

Returns a pointer to the loss index object to which the learning rate algorithm is associated. If the loss index pointer is nullptr, this method throws an exception.

Definition at line 50 of file learning_rate_algorithm.cpp.

◆ has_loss_index()

bool has_loss_index ( ) const

Returns true if this learning rate algorithm has an associated loss index, and false otherwise.

Definition at line 74 of file learning_rate_algorithm.cpp.
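The documented contract of get_loss_index_pointer and has_loss_index can be sketched as follows. `LossIndexHolder` is a hypothetical class and `LossIndex` is an empty placeholder. The choice of std::logic_error is an assumption: the documentation only says that an exception is thrown.

```cpp
#include <cassert>
#include <stdexcept>

struct LossIndex {};  // hypothetical placeholder for OpenNN's LossIndex

// Sketch of a non-owning association with a loss index object.
class LossIndexHolder
{
public:
    explicit LossIndexHolder(LossIndex* new_loss_index_pointer = nullptr)
        : loss_index_pointer(new_loss_index_pointer) {}

    // True only when a loss index has been associated.
    bool has_loss_index() const { return loss_index_pointer != nullptr; }

    // Returns the associated loss index, throwing when none is set.
    LossIndex* get_loss_index_pointer() const
    {
        if (loss_index_pointer == nullptr)
            throw std::logic_error("Loss index pointer is nullptr.");
        return loss_index_pointer;
    }

private:
    LossIndex* loss_index_pointer = nullptr;
};
```

Calling has_loss_index first lets code avoid the exception when the association is optional.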

◆ set() [1/2]

void set ( )

Sets the loss index pointer to nullptr. It also sets the rest of members to their default values.

Definition at line 130 of file learning_rate_algorithm.cpp.

◆ set() [2/2]

void set ( LossIndex * new_loss_index_pointer)

Sets a new loss index pointer. It also sets the rest of members to their default values.

Parameters
new_loss_index_pointer: Pointer to a loss index object.

Definition at line 142 of file learning_rate_algorithm.cpp.

◆ set_default()

void set_default ( )

Sets the members of the learning rate algorithm to their default values.

Definition at line 152 of file learning_rate_algorithm.cpp.

◆ set_display()

void set_display ( const bool &  new_display)

Sets a new display value. If true, messages from this class are displayed on the screen; if false, they are not.

Parameters
new_display: Display value.

Definition at line 258 of file learning_rate_algorithm.cpp.

◆ set_learning_rate_method() [1/2]

void set_learning_rate_method ( const LearningRateMethod & new_learning_rate_method)

Sets a new learning rate method to be used for training.

Parameters
new_learning_rate_method: Learning rate method.

Definition at line 194 of file learning_rate_algorithm.cpp.

◆ set_learning_rate_method() [2/2]

void set_learning_rate_method ( const string &  new_learning_rate_method)

Sets the method for obtaining the learning rate from a string with the name of the method.

Parameters
new_learning_rate_method: Name of the learning rate method ("GoldenSection" or "BrentMethod").

Definition at line 204 of file learning_rate_algorithm.cpp.
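A sketch of name-based selection, assuming the string overload maps names onto the enumerators and rejects anything else. Because the enum declares only GoldenSection and BrentMethod, the sketch accepts just those two names. The function name `learning_rate_method_from_string` and the exception type are illustrative choices, not OpenNN's code.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

enum class LearningRateMethod { GoldenSection, BrentMethod };

// Maps a method name onto the enum; unknown names are rejected.
LearningRateMethod learning_rate_method_from_string(const std::string& name)
{
    if (name == "GoldenSection") return LearningRateMethod::GoldenSection;
    if (name == "BrentMethod") return LearningRateMethod::BrentMethod;
    throw std::invalid_argument("Unknown learning rate method: " + name);
}
```

Throwing on an unrecognized name surfaces typos in configuration files instead of silently keeping the previous method.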

◆ set_learning_rate_tolerance()

void set_learning_rate_tolerance ( const type &  new_learning_rate_tolerance)

Sets a new tolerance value to be used in line minimization.

Parameters
new_learning_rate_tolerance: Tolerance value in line minimization.

Definition at line 230 of file learning_rate_algorithm.cpp.

◆ set_loss_index_pointer()

void set_loss_index_pointer ( LossIndex * new_loss_index_pointer)

Sets a pointer to a loss index object to be associated with the learning rate algorithm.

Parameters
new_loss_index_pointer: Pointer to a loss index object.

Definition at line 175 of file learning_rate_algorithm.cpp.

◆ set_threads_number()

void set_threads_number ( const int &  new_threads_number)

Definition at line 181 of file learning_rate_algorithm.cpp.

◆ write_learning_rate_method()

string write_learning_rate_method ( ) const

Returns a string with the name of the learning rate method to be used.

Definition at line 97 of file learning_rate_algorithm.cpp.
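The enum-to-name mapping can be written as a switch over the enumerators. This is a hypothetical free function rather than the member itself, and it assumes the returned names use the same spelling as those accepted by set_learning_rate_method(const string&).

```cpp
#include <cassert>
#include <string>

enum class LearningRateMethod { GoldenSection, BrentMethod };

// Returns the name of the given method.
std::string write_learning_rate_method(LearningRateMethod method)
{
    switch (method)
    {
    case LearningRateMethod::GoldenSection: return "GoldenSection";
    case LearningRateMethod::BrentMethod:   return "BrentMethod";
    }
    return {};  // unreachable for valid enumerators
}
```

Omitting a `default` label lets compilers warn when a new enumerator is added but not handled here.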

◆ write_XML()

void write_XML ( tinyxml2::XMLPrinter & file_stream) const

Serializes the learning rate algorithm object into an XML document of the TinyXML library without keeping the DOM tree in memory. See the OpenNN manual for more information about the format of this document.

Definition at line 693 of file learning_rate_algorithm.cpp.
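tinyxml2::XMLPrinter writes elements to its output as they are produced, instead of first building an XMLDocument tree. The sketch below mimics that streaming style with std::ostream so that it compiles without tinyxml2. The function name, its parameters and the element names are assumptions for illustration, not the format described in the OpenNN manual.

```cpp
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>

// Hypothetical streaming serialization: each element is written directly
// to the output, so no document tree is held in memory.
void write_learning_rate_algorithm_xml(std::ostream& out,
                                       const std::string& method_name,
                                       double learning_rate_tolerance)
{
    out << "<LearningRateAlgorithm>\n"
        << "  <LearningRateMethod>" << method_name << "</LearningRateMethod>\n"
        << "  <LearningRateTolerance>" << learning_rate_tolerance
        << "</LearningRateTolerance>\n"
        << "</LearningRateAlgorithm>\n";
}
```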

Member Data Documentation

◆ display

bool display = true
protected

Display messages to screen.

Definition at line 286 of file learning_rate_algorithm.h.

◆ golden_ratio

const type golden_ratio = static_cast<type>(1.618)
protected

Definition at line 288 of file learning_rate_algorithm.h.
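The constant is a rounded value of the golden ratio φ. In a standard golden-section search (general background, not a statement about OpenNN's internals), each new point is placed a fraction 2 − φ into the larger subinterval, so the bracket shrinks by roughly 1/φ per loss evaluation. Reducing an initial bracket of width Δ below a tolerance ε therefore takes about n evaluations:

```latex
\varphi = \frac{1 + \sqrt{5}}{2} \approx 1.618, \qquad
\frac{1}{\varphi} = \varphi - 1 \approx 0.618, \qquad
1 - \frac{1}{\varphi} = 2 - \varphi \approx 0.382, \qquad
n \approx \frac{\ln(\Delta / \varepsilon)}{\ln \varphi}
```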

◆ learning_rate_method

LearningRateMethod learning_rate_method
protected

Variable containing the method actually used to obtain a suitable learning rate.

Definition at line 274 of file learning_rate_algorithm.h.

◆ learning_rate_tolerance

type learning_rate_tolerance
protected

Maximum interval length for the learning rate.

Definition at line 278 of file learning_rate_algorithm.h.

◆ loss_index_pointer

LossIndex* loss_index_pointer = nullptr
protected

Pointer to an external loss index object.

Definition at line 268 of file learning_rate_algorithm.h.

◆ loss_tolerance

type loss_tolerance
protected

Definition at line 280 of file learning_rate_algorithm.h.

◆ thread_pool

ThreadPool* thread_pool = nullptr
protected

Definition at line 290 of file learning_rate_algorithm.h.

◆ thread_pool_device

ThreadPoolDevice* thread_pool_device = nullptr
protected

Definition at line 291 of file learning_rate_algorithm.h.


The documentation for this class was generated from the following files: