9#include "mean_squared_error.h"
32 :
LossIndex(new_neural_network_pointer, new_data_set_pointer)
53 Tensor<type, 0> sum_squared_error;
55 const Index batch_samples_number = batch.inputs_2d.dimension(0);
57 const type coefficient =
static_cast<type
>(batch_samples_number);
59 sum_squared_error.device(*thread_pool_device) = back_propagation.errors.contract(back_propagation.errors, SSE);
61 back_propagation.error = sum_squared_error(0)/coefficient;
65void MeanSquaredError::calculate_error_lm(
const DataSetBatch& batch,
69 Tensor<type, 0> sum_squared_error;
71 const Index batch_samples_number = batch.inputs_2d.dimension(0);
73 sum_squared_error.device(*thread_pool_device) = (back_propagation.squared_errors*back_propagation.squared_errors).sum();
75 const type coefficient =
static_cast<type
>(batch_samples_number);
77 back_propagation.error = sum_squared_error(0)/coefficient;
81void MeanSquaredError::calculate_output_delta(
const DataSetBatch& batch,
82 NeuralNetworkForwardPropagation&,
83 LossIndexBackPropagation& back_propagation)
const
91 LayerBackPropagation* output_layer_back_propagation = back_propagation.neural_network.layers(trainable_layers_number-1);
93 const Index batch_samples_number = batch.inputs_2d.dimension(0);
95 const type coefficient =
static_cast<type
>(2.0)/
static_cast<type
>(batch_samples_number);
97 switch(output_layer_back_propagation->layer_pointer->get_type())
99 case Layer::Type::Perceptron:
101 PerceptronLayerBackPropagation* perceptron_layer_back_propagation
102 =
static_cast<PerceptronLayerBackPropagation*
>(output_layer_back_propagation);
104 perceptron_layer_back_propagation->delta.device(*thread_pool_device) = coefficient*back_propagation.errors;
108 case Layer::Type::Probabilistic:
110 ProbabilisticLayerBackPropagation* probabilistic_layer_back_propagation
111 =
static_cast<ProbabilisticLayerBackPropagation*
>(output_layer_back_propagation);
113 probabilistic_layer_back_propagation->delta.device(*thread_pool_device) = coefficient*back_propagation.errors;
117 case Layer::Type::Recurrent:
119 RecurrentLayerBackPropagation* recurrent_layer_back_propagation
120 =
static_cast<RecurrentLayerBackPropagation*
>(output_layer_back_propagation);
122 recurrent_layer_back_propagation->delta.device(*thread_pool_device) = coefficient*back_propagation.errors;
126 case Layer::Type::LongShortTermMemory:
128 LongShortTermMemoryLayerBackPropagation* long_short_term_memory_layer_back_propagation
129 =
static_cast<LongShortTermMemoryLayerBackPropagation*
>(output_layer_back_propagation);
131 long_short_term_memory_layer_back_propagation->delta.device(*thread_pool_device) = coefficient*back_propagation.errors;
140void MeanSquaredError::calculate_output_delta_lm(
const DataSetBatch&,
141 NeuralNetworkForwardPropagation&,
142 LossIndexBackPropagationLM& loss_index_back_propagation)
const
150 LayerBackPropagationLM* output_layer_back_propagation = loss_index_back_propagation.neural_network.layers(trainable_layers_number-1);
152 Layer* output_layer_pointer = output_layer_back_propagation->layer_pointer;
154 switch(output_layer_pointer->get_type())
156 case Layer::Type::Perceptron:
158 PerceptronLayerBackPropagationLM* perceptron_layer_back_propagation
159 =
static_cast<PerceptronLayerBackPropagationLM*
>(output_layer_back_propagation);
161 memcpy(perceptron_layer_back_propagation->delta.data(),
162 loss_index_back_propagation.errors.data(),
163 static_cast<size_t>(loss_index_back_propagation.errors.size())*
sizeof(type));
165 divide_columns(perceptron_layer_back_propagation->delta, loss_index_back_propagation.squared_errors);
169 case Layer::Type::Probabilistic:
171 ProbabilisticLayerBackPropagationLM* probabilistic_layer_back_propagation
172 =
static_cast<ProbabilisticLayerBackPropagationLM*
>(output_layer_back_propagation);
174 memcpy(probabilistic_layer_back_propagation->delta.data(),
175 loss_index_back_propagation.errors.data(),
176 static_cast<size_t>(loss_index_back_propagation.errors.size())*
sizeof(type));
178 divide_columns(probabilistic_layer_back_propagation->delta, loss_index_back_propagation.squared_errors);
184 ostringstream buffer;
186 buffer <<
"OpenNN Exception: MeanSquaredError class.\n"
187 <<
"Levenberg-Marquardt can only be used with Perceptron and Probabilistic layers.\n";
189 throw logic_error(buffer.str());
195void MeanSquaredError::calculate_error_gradient_lm(
const DataSetBatch& batch,
196 LossIndexBackPropagationLM& loss_index_back_propagation_lm)
const
204 const Index batch_samples_number = batch.get_samples_number();
206 const type coefficient = type(2)/
static_cast<type
>(batch_samples_number);
208 loss_index_back_propagation_lm.gradient.device(*thread_pool_device)
209 = loss_index_back_propagation_lm.squared_errors_jacobian.contract(loss_index_back_propagation_lm.squared_errors, AT_B);
211 loss_index_back_propagation_lm.gradient.device(*thread_pool_device)
212 = coefficient * loss_index_back_propagation_lm.gradient;
216void MeanSquaredError::calculate_error_hessian_lm(
const DataSetBatch& batch,
217 LossIndexBackPropagationLM& loss_index_back_propagation_lm)
const
223 const Index batch_samples_number = batch.inputs_2d.dimension(0);
225 const type coefficient = (
static_cast<type
>(2.0)/
static_cast<type
>(batch_samples_number));
227 loss_index_back_propagation_lm.hessian.device(*thread_pool_device)
228 = loss_index_back_propagation_lm.squared_errors_jacobian.contract(loss_index_back_propagation_lm.squared_errors_jacobian, AT_B);
230 loss_index_back_propagation_lm.hessian.device(*thread_pool_device)
231 = coefficient*loss_index_back_propagation_lm.hessian;
239 return "MEAN_SQUARED_ERROR";
247 return "Mean squared error";
258 file_stream.OpenElement(
"MeanSquaredError");
This class represents the concept of data set for data modelling problems, such as approximation,...
This abstract class represents the concept of loss index composed of an error term and a regularizati...
NeuralNetwork * neural_network_pointer
Pointer to a neural network object.
virtual ~MeanSquaredError()
Destructor.
void calculate_error(const DataSetBatch &, const NeuralNetworkForwardPropagation &, LossIndexBackPropagation &) const
MeanSquaredError::calculate_error.
string get_error_type() const
Returns a string with the name of the mean squared error loss type, "MEAN_SQUARED_ERROR".
void write_XML(tinyxml2::XMLPrinter &) const
string get_error_type_text() const
Returns a string with the name of the mean squared error loss type in text format.
virtual void CloseElement(bool compactMode=false)
If streaming, close the Element.
A loss index composed of several terms, this structure represent the First Order for this function.