training_strategy.h
// OpenNN: Open Neural Networks Library
// www.opennn.net
//
// T R A I N I N G   S T R A T E G Y   C L A S S   H E A D E R
//
// Artificial Intelligence Techniques SL
// artelnics@artelnics.com

#ifndef TRAININGSTRATEGY_H
#define TRAININGSTRATEGY_H

// System includes

#include <iostream>
#include <fstream>
#include <algorithm>
#include <functional>
#include <limits>
#include <cmath>
#include <ctime>

// OpenNN includes

#include "config.h"
#include "loss_index.h"
#include "sum_squared_error.h"
#include "mean_squared_error.h"
#include "normalized_squared_error.h"
#include "minkowski_error.h"
#include "cross_entropy_error.h"
#include "weighted_squared_error.h"

#include "optimization_algorithm.h"

#include "gradient_descent.h"
#include "conjugate_gradient.h"
#include "quasi_newton_method.h"
#include "levenberg_marquardt_algorithm.h"
#include "stochastic_gradient_descent.h"
#include "adaptive_moment_estimation.h"


43namespace OpenNN
44{
45
47
54
56{
57
58public:
59
60 // Constructors
61
62 explicit TrainingStrategy();
63
65
66 // Destructor
67
68 virtual ~TrainingStrategy();
69
70 // Enumerations
71
73
74 enum class LossMethod
75 {
76 SUM_SQUARED_ERROR,
77 MEAN_SQUARED_ERROR,
78 NORMALIZED_SQUARED_ERROR,
79 MINKOWSKI_ERROR,
80 WEIGHTED_SQUARED_ERROR,
81 CROSS_ENTROPY_ERROR
82 };
83
85
87 {
88 GRADIENT_DESCENT,
89 CONJUGATE_GRADIENT,
90 QUASI_NEWTON_METHOD,
91 LEVENBERG_MARQUARDT_ALGORITHM,
92 STOCHASTIC_GRADIENT_DESCENT,
93 ADAPTIVE_MOMENT_ESTIMATION
94 };
95
96 // Get methods
97
99
101
104
105 bool has_neural_network() const;
106 bool has_data_set() const;
107
114
121
122 const LossMethod& get_loss_method() const;
124
125 string write_loss_method() const;
126 string write_optimization_method() const;
127
128 string write_optimization_method_text() const;
129 string write_loss_method_text() const;
130
131 const bool& get_display() const;
132
133 // Set methods
134
135 void set();
136 void set(NeuralNetwork*, DataSet*);
137 void set_default();
138
139 void set_threads_number(const int&);
140
141 void set_data_set_pointer(DataSet*);
142 void set_neural_network_pointer(NeuralNetwork*);
143
144 void set_loss_index_threads_number(const int&);
145 void set_optimization_algorithm_threads_number(const int&);
146
148 void set_loss_index_data_set_pointer(DataSet*);
149 void set_loss_index_neural_network_pointer(NeuralNetwork*);
150
151 void set_loss_method(const LossMethod&);
153
154 void set_loss_method(const string&);
155 void set_optimization_method(const string&);
156
157 void set_display(const bool&);
158
159 void set_loss_goal(const type&);
160 void set_maximum_selection_failures(const Index&);
161 void set_maximum_epochs_number(const int&);
162 void set_display_period(const int&);
163
164 void set_maximum_time(const type&);
165
166 // Training methods
167
169
170
171 // Check methods
172
173 void fix_forecasting();
174
175 // Serialization methods
176
177 void print() const;
178
179 void from_XML(const tinyxml2::XMLDocument&);
180
181 void write_XML(tinyxml2::XMLPrinter&) const;
182
183 void save(const string&) const;
184 void load(const string&);
185
186private:
187
188 DataSet* data_set_pointer = nullptr;
189
190 NeuralNetwork* neural_network_pointer = nullptr;
191
192 // Loss index
193
195
197
199
201
203
205
207
209
211
213
215
217
219
221
222 // Optimization algorithm
223
225
227
229
231
233
235
237
239
241
243
245
247
249
251
253
254 bool display = true;
255
256#ifdef OPENNN_CUDA
257#include "../../opennn-cuda/opennn_cuda/training_strategy_cuda.h"
258#endif
259
260};
261
262}
263
#endif


// OpenNN: Open Neural Networks Library.
// Copyright(C) 2005-2021 Artificial Intelligence Techniques, SL.
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 2.1 of the License, or any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// Lesser General Public License for more details.

// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA

This concrete class represents the adaptive moment estimation (Adam) optimization algorithm,...
This concrete class represents a conjugate gradient optimization algorithm, based on solving sparse s...
This class represents the cross entropy error term, used for predicting probabilities.
This class represents the concept of a data set for data modelling problems, such as approximation,...
Definition: data_set.h:56
This concrete class represents the gradient descent optimization algorithm, used to minimize the loss...
This concrete class represents the Levenberg-Marquardt optimization algorithm, used to minimize loss ...
This abstract class represents the concept of loss index composed of an error term and a regularizati...
Definition: loss_index.h:49
This class represents the mean squared error term.
This class represents the Minkowski error term.
This class represents the concept of neural network in the OpenNN library.
This class represents the normalized squared error term.
This abstract class represents the concept of optimization algorithm for a neural network in OpenNN l...
This concrete class represents a quasi-Newton optimization algorithm, used to minimize the loss funct...
This concrete class represents the stochastic gradient descent optimization algorithm for the loss in...
This class represents the sum squared performance term functional.
This class represents the concept of training strategy for a neural network in OpenNN.
AdaptiveMomentEstimation * get_adaptive_moment_estimation_pointer()
TrainingResults perform_training()
QuasiNewtonMethod quasi_Newton_method
Quasi-Newton method object to be used as a main optimization algorithm.
LevenbergMarquardtAlgorithm * get_Levenberg_Marquardt_algorithm_pointer()
NormalizedSquaredError * get_normalized_squared_error_pointer()
LossMethod loss_method
Type of loss method.
GradientDescent * get_gradient_descent_pointer()
CrossEntropyError * get_cross_entropy_error_pointer()
AdaptiveMomentEstimation adaptive_moment_estimation
Adaptive moment estimation algorithm object to be used as a main optimization algorithm.
void set_loss_index_pointer(LossIndex *)
const bool & get_display() const
LossIndex * get_loss_index_pointer()
Returns a pointer to the LossIndex class.
void from_XML(const tinyxml2::XMLDocument &)
string write_loss_method() const
Returns a string with the type of the main loss algorithm composing this training strategy object.
bool display
Display messages to screen.
WeightedSquaredError weighted_squared_error
Pointer to the weighted squared error object which can be used as the error term.
StochasticGradientDescent stochastic_gradient_descent
Stochastic gradient descent algorithm object to be used as a main optimization algorithm.
ConjugateGradient * get_conjugate_gradient_pointer()
void set_loss_method(const LossMethod &)
WeightedSquaredError * get_weighted_squared_error_pointer()
OptimizationAlgorithm * get_optimization_algorithm_pointer()
Returns a pointer to the OptimizationAlgorithm class.
LevenbergMarquardtAlgorithm Levenberg_Marquardt_algorithm
Levenberg-Marquardt algorithm object to be used as a main optimization algorithm.
ConjugateGradient conjugate_gradient
Conjugate gradient object to be used as a main optimization algorithm.
MinkowskiError Minkowski_error
Pointer to the Minkowski error object which can be used as the error term.
MinkowskiError * get_Minkowski_error_pointer()
CrossEntropyError cross_entropy_error
Pointer to the cross entropy error object which can be used as the error term.
GradientDescent gradient_descent
Gradient descent object to be used as a main optimization algorithm.
StochasticGradientDescent * get_stochastic_gradient_descent_pointer()
void save(const string &) const
string write_optimization_method() const
NeuralNetwork * get_neural_network_pointer() const
Returns a pointer to the NeuralNetwork class.
void set_optimization_method(const OptimizationMethod &)
MeanSquaredError * get_mean_squared_error_pointer()
QuasiNewtonMethod * get_quasi_Newton_method_pointer()
NormalizedSquaredError normalized_squared_error
Pointer to the normalized squared error object which can be used as the error term.
LossMethod
Enumeration of available error terms in OpenNN.
void print() const
Prints to the screen the string representation of the training strategy object.
void set_display(const bool &)
SumSquaredError sum_squared_error
Pointer to the sum squared error object which can be used as the error term.
const LossMethod & get_loss_method() const
Returns the type of the main loss algorithm composing this training strategy object.
void write_XML(tinyxml2::XMLPrinter &) const
SumSquaredError * get_sum_squared_error_pointer()
OptimizationMethod
Enumeration of all the available types of optimization algorithms.
string write_optimization_method_text() const
MeanSquaredError mean_squared_error
Pointer to the mean squared error object which can be used as the error term.
DataSet * get_data_set_pointer()
Returns a pointer to the DataSet class.
const OptimizationMethod & get_optimization_method() const
Returns the type of the main optimization algorithm composing this training strategy object.
string write_loss_method_text() const
Returns a string with the main loss method type in text format.
OptimizationMethod optimization_method
Type of main optimization algorithm.
This class represents the weighted squared error term.
This structure contains the optimization algorithm results.