9#include "sum_squared_error.h"
30 :
LossIndex(new_neural_network_pointer, new_data_set_pointer)
42void SumSquaredError::calculate_error(
const DataSetBatch&,
46 Tensor<type, 0> sum_squared_error;
48 sum_squared_error.device(*thread_pool_device) = back_propagation.errors.contract(back_propagation.errors, SSE);
50 back_propagation.error = sum_squared_error(0);
54void SumSquaredError::calculate_error_lm(
const DataSetBatch&,
55 const NeuralNetworkForwardPropagation&,
56 LossIndexBackPropagationLM& back_propagation)
const
58 Tensor<type, 0> sum_squared_error;
60 sum_squared_error.device(*thread_pool_device) = (back_propagation.squared_errors*back_propagation.squared_errors).sum();
62 back_propagation.error = sum_squared_error(0);
66void SumSquaredError::calculate_output_delta(
const DataSetBatch&,
67 NeuralNetworkForwardPropagation&,
68 LossIndexBackPropagation& back_propagation)
const
78 LayerBackPropagation* output_layer_back_propagation = back_propagation.neural_network.layers(trainable_layers_number-1);
80 Layer* output_layer_pointer = output_layer_back_propagation->layer_pointer;
82 const type coefficient =
static_cast<type
>(2.0);
84 switch(output_layer_pointer->get_type())
86 case Layer::Type::Perceptron:
88 PerceptronLayerBackPropagation* perceptron_layer_back_propagation
89 =
static_cast<PerceptronLayerBackPropagation*
>(output_layer_back_propagation);
91 perceptron_layer_back_propagation->delta.device(*thread_pool_device) = coefficient*back_propagation.errors;
95 case Layer::Type::Probabilistic:
97 ProbabilisticLayerBackPropagation* probabilistic_layer_back_propagation
98 =
static_cast<ProbabilisticLayerBackPropagation*
>(output_layer_back_propagation);
100 probabilistic_layer_back_propagation->delta.device(*thread_pool_device) = coefficient*back_propagation.errors;
104 case Layer::Type::Recurrent:
106 RecurrentLayerBackPropagation* recurrent_layer_back_propagation
107 =
static_cast<RecurrentLayerBackPropagation*
>(output_layer_back_propagation);
109 recurrent_layer_back_propagation->delta.device(*thread_pool_device) = coefficient*back_propagation.errors;
113 case Layer::Type::LongShortTermMemory:
115 LongShortTermMemoryLayerBackPropagation* long_short_term_memory_layer_back_propagation
116 =
static_cast<LongShortTermMemoryLayerBackPropagation*
>(output_layer_back_propagation);
118 long_short_term_memory_layer_back_propagation->delta.device(*thread_pool_device) = coefficient*back_propagation.errors;
127void SumSquaredError::calculate_output_delta_lm(
const DataSetBatch&,
128 NeuralNetworkForwardPropagation&,
129 LossIndexBackPropagationLM& loss_index_back_propagation)
const
139 LayerBackPropagationLM* output_layer_back_propagation = loss_index_back_propagation.neural_network.layers(trainable_layers_number-1);
141 Layer* output_layer_pointer = output_layer_back_propagation->layer_pointer;
143 switch(output_layer_pointer->get_type())
145 case Layer::Type::Perceptron:
147 PerceptronLayerBackPropagationLM* perceptron_layer_back_propagation
148 =
static_cast<PerceptronLayerBackPropagationLM*
>(output_layer_back_propagation);
150 memcpy(perceptron_layer_back_propagation->delta.data(),
151 loss_index_back_propagation.errors.data(),
152 static_cast<size_t>(loss_index_back_propagation.errors.size())*
sizeof(type));
154 divide_columns(perceptron_layer_back_propagation->delta, loss_index_back_propagation.squared_errors);
158 case Layer::Type::Probabilistic:
160 ProbabilisticLayerBackPropagationLM* probabilistic_layer_back_propagation
161 =
static_cast<ProbabilisticLayerBackPropagationLM*
>(output_layer_back_propagation);
163 memcpy(probabilistic_layer_back_propagation->delta.data(),
164 loss_index_back_propagation.errors.data(),
165 static_cast<size_t>(loss_index_back_propagation.errors.size())*
sizeof(type));
167 divide_columns(probabilistic_layer_back_propagation->delta, loss_index_back_propagation.squared_errors);
173 ostringstream buffer;
175 buffer <<
"OpenNN Exception: MeanSquaredError class.\n"
176 <<
"Levenberg-Marquardt can only be used with Perceptron and Probabilistic layers.\n";
178 throw logic_error(buffer.str());
184void SumSquaredError::calculate_error_gradient_lm(
const DataSetBatch& ,
185 LossIndexBackPropagationLM& loss_index_back_propagation_lm)
const
193 const type coefficient = (
static_cast<type
>(2.0));
195 loss_index_back_propagation_lm.gradient.device(*thread_pool_device)
196 = loss_index_back_propagation_lm.squared_errors_jacobian.contract(loss_index_back_propagation_lm.squared_errors, AT_B);
198 loss_index_back_propagation_lm.gradient.device(*thread_pool_device)
199 = coefficient*loss_index_back_propagation_lm.gradient;
203void SumSquaredError::calculate_error_hessian_lm(
const DataSetBatch&,
204 LossIndexBackPropagationLM& loss_index_back_propagation_lm)
const
212 const type coefficient =
static_cast<type
>(2.0);
214 loss_index_back_propagation_lm.hessian.device(*thread_pool_device)
215 = loss_index_back_propagation_lm.squared_errors_jacobian.contract(loss_index_back_propagation_lm.squared_errors_jacobian, AT_B);
217 loss_index_back_propagation_lm.hessian.device(*thread_pool_device)
218 = coefficient*loss_index_back_propagation_lm.hessian;
226 return "SUM_SQUARED_ERROR";
234 return "Sum squared error";
245 file_stream.OpenElement(
"SumSquaredError");
260 ostringstream buffer;
262 buffer <<
"OpenNN Exception: SumSquaredError class.\n"
263 <<
"void from_XML(const tinyxml2::XMLDocument&) method.\n"
264 <<
"Sum squared element is nullptr.\n";
266 throw logic_error(buffer.str());
This class represents the concept of data set for data modelling problems, such as approximation, classification and forecasting.
This abstract class represents the concept of loss index composed of an error term and a regularization term.
NeuralNetwork * neural_network_pointer
Pointer to a neural network object.
virtual ~SumSquaredError()
Destructor.
void from_XML(const tinyxml2::XMLDocument &)
string get_error_type() const
Returns a string with the name of the sum squared error loss type, "SUM_SQUARED_ERROR".
void write_XML(tinyxml2::XMLPrinter &) const
string get_error_type_text() const
Returns a string with the name of the sum squared error loss type in text format.
virtual void CloseElement(bool compactMode=false)
If streaming, close the Element.