9#ifndef TRAININGSTRATEGY_H
10#define TRAININGSTRATEGY_H
25#include "loss_index.h"
26#include "sum_squared_error.h"
27#include "mean_squared_error.h"
28#include "normalized_squared_error.h"
29#include "minkowski_error.h"
30#include "cross_entropy_error.h"
31#include "weighted_squared_error.h"
33#include "optimization_algorithm.h"
35#include "gradient_descent.h"
36#include "conjugate_gradient.h"
37#include "quasi_newton_method.h"
38#include "levenberg_marquardt_algorithm.h"
39#include "stochastic_gradient_descent.h"
40#include "adaptive_moment_estimation.h"
78 NORMALIZED_SQUARED_ERROR,
80 WEIGHTED_SQUARED_ERROR,
91 LEVENBERG_MARQUARDT_ALGORITHM,
92 STOCHASTIC_GRADIENT_DESCENT,
93 ADAPTIVE_MOMENT_ESTIMATION
105 bool has_neural_network()
const;
106 bool has_data_set()
const;
139 void set_threads_number(
const int&);
141 void set_data_set_pointer(
DataSet*);
144 void set_loss_index_threads_number(
const int&);
145 void set_optimization_algorithm_threads_number(
const int&);
148 void set_loss_index_data_set_pointer(
DataSet*);
159 void set_loss_goal(
const type&);
160 void set_maximum_selection_failures(
const Index&);
161 void set_maximum_epochs_number(
const int&);
162 void set_display_period(
const int&);
164 void set_maximum_time(
const type&);
183 void save(
const string&)
const;
184 void load(
const string&);
188 DataSet* data_set_pointer =
nullptr;
257#include "../../opennn-cuda/opennn_cuda/training_strategy_cuda.h"
This concrete class represents the adaptive moment estimation (Adam) optimization algorithm,...
This concrete class represents a conjugate gradient optimization algorithm, based on solving sparse s...
This class represents the cross entropy error term, used for predicting probabilities.
This class represents the concept of a data set for data modelling problems, such as approximation,...
This concrete class represents the gradient descent optimization algorithm, used to minimize the loss...
This concrete class represents the Levenberg-Marquardt optimization algorithm, used to minimize loss ...
This abstract class represents the concept of loss index composed of an error term and a regularizati...
This class represents the mean squared error term.
This class represents the Minkowski error term.
This class represents the concept of neural network in the OpenNN library.
This class represents the normalized squared error term.
This abstract class represents the concept of optimization algorithm for a neural network in OpenNN l...
This concrete class represents a quasi-Newton optimization algorithm, used to minimize the loss funct...
This concrete class represents the stochastic gradient descent optimization algorithm for the loss in...
This class represents the sum squared error term.
This class represents the concept of training strategy for a neural network in OpenNN.
AdaptiveMomentEstimation * get_adaptive_moment_estimation_pointer()
TrainingResults perform_training()
QuasiNewtonMethod quasi_Newton_method
Quasi-Newton method object to be used as a main optimization algorithm.
LevenbergMarquardtAlgorithm * get_Levenberg_Marquardt_algorithm_pointer()
NormalizedSquaredError * get_normalized_squared_error_pointer()
LossMethod loss_method
Type of loss method.
GradientDescent * get_gradient_descent_pointer()
virtual ~TrainingStrategy()
CrossEntropyError * get_cross_entropy_error_pointer()
AdaptiveMomentEstimation adaptive_moment_estimation
Adaptive moment estimation algorithm object to be used as a main optimization algorithm.
void set_loss_index_pointer(LossIndex *)
const bool & get_display() const
LossIndex * get_loss_index_pointer()
Returns a pointer to the LossIndex class.
void from_XML(const tinyxml2::XMLDocument &)
string write_loss_method() const
Returns a string with the type of the main loss algorithm composing this training strategy object.
void load(const string &)
bool display
Display messages to screen.
WeightedSquaredError weighted_squared_error
Weighted squared error object which can be used as the error term.
StochasticGradientDescent stochastic_gradient_descent
Stochastic gradient descent algorithm object to be used as a main optimization algorithm.
ConjugateGradient * get_conjugate_gradient_pointer()
void set_loss_method(const LossMethod &)
WeightedSquaredError * get_weighted_squared_error_pointer()
OptimizationAlgorithm * get_optimization_algorithm_pointer()
Returns a pointer to the OptimizationAlgorithm class.
LevenbergMarquardtAlgorithm Levenberg_Marquardt_algorithm
Levenberg-Marquardt algorithm object to be used as a main optimization algorithm.
ConjugateGradient conjugate_gradient
Conjugate gradient object to be used as a main optimization algorithm.
MinkowskiError Minkowski_error
Minkowski error object which can be used as the error term.
MinkowskiError * get_Minkowski_error_pointer()
CrossEntropyError cross_entropy_error
Cross entropy error object which can be used as the error term.
GradientDescent gradient_descent
Gradient descent object to be used as a main optimization algorithm.
StochasticGradientDescent * get_stochastic_gradient_descent_pointer()
void save(const string &) const
string write_optimization_method() const
NeuralNetwork * get_neural_network_pointer() const
Returns a pointer to the NeuralNetwork class.
void set_optimization_method(const OptimizationMethod &)
MeanSquaredError * get_mean_squared_error_pointer()
QuasiNewtonMethod * get_quasi_Newton_method_pointer()
NormalizedSquaredError normalized_squared_error
Normalized squared error object which can be used as the error term.
LossMethod
Enumeration of available error terms in OpenNN.
void print() const
Prints to the screen the string representation of the training strategy object.
void set_display(const bool &)
SumSquaredError sum_squared_error
Sum squared error object which can be used as the error term.
const LossMethod & get_loss_method() const
Returns the type of the main loss algorithm composing this training strategy object.
void write_XML(tinyxml2::XMLPrinter &) const
SumSquaredError * get_sum_squared_error_pointer()
OptimizationMethod
Enumeration of all the available types of optimization algorithms.
string write_optimization_method_text() const
MeanSquaredError mean_squared_error
Mean squared error object which can be used as the error term.
DataSet * get_data_set_pointer()
Returns a pointer to the DataSet class.
const OptimizationMethod & get_optimization_method() const
Returns the type of the main optimization algorithm composing this training strategy object.
string write_loss_method_text() const
Returns a string with the main loss method type in text format.
OptimizationMethod optimization_method
Type of main optimization algorithm.
This class represents the weighted squared error term.
This structure contains the optimization algorithm results.