weighted_squared_error.h
1// OpenNN: Open Neural Networks Library
2// www.opennn.net
3//
4// W E I G H T E D S Q U A R E D E R R O R C L A S S H E A D E R
5//
6// Artificial Intelligence Techniques SL
7// artelnics@artelnics.com
8
9#ifndef WEIGHTEDSQUAREDERROR_H
10#define WEIGHTEDSQUAREDERROR_H
11
12// System includes
13
14#include <string>
15#include <sstream>
16#include <iostream>
17#include <fstream>
18#include <limits>
19#include <math.h>
20
21// OpenNN includes
22
23#include "config.h"
24#include "loss_index.h"
25#include "data_set.h"
26
27namespace OpenNN
28{
29
31
36
// NOTE(review): This text is a Doxygen source-listing scrape, not compilable C++.
// Every line carries its fused listing line number, and these listing lines were
// dropped: 30, 32-35, 37, 46, 70, 72, 83-84, 87-88, 93-94, 97-98, 100-101, 104,
// 114, 118, 122, 124. Restore the header from the real source file rather than
// editing this text.
// NOTE(review): The class head (listing line 37) is missing. It presumably declares
// `class WeightedSquaredError` deriving from the loss-index base class that
// loss_index.h provides (the page describes that as an abstract class composed of
// an error term and a regularization term) -- TODO confirm against the real header.
38{
39
40public:
41
42 // Constructors
43
44 explicit WeightedSquaredError();
45
// NOTE(review): Listing line 46 is missing. It presumably declared a further
// constructor -- TODO confirm.
47
48 // Destructor
49
50 virtual ~WeightedSquaredError();
51
52 // Get methods
53
54 type get_positives_weight() const;
55 type get_negatives_weight() const;
56
// NOTE(review): "normalizaton" is misspelled. Do not rename it in place, because
// callers use this spelling. Consider adding a correctly spelled alias instead.
57 type get_normalizaton_coefficient() const;
58
59 // Set methods
60
61 void set_default();
62
63 void set_positives_weight(const type&);
64 void set_negatives_weight(const type&);
65
66 void set_weights(const type&, const type&);
67
/// Calculates the weights for the positive and negative values from the data in
/// the data set (per this page's Doxygen brief for set_weights()).
68 void set_weights();
69
// NOTE(review): Listing lines 70 and 72 are missing. This page's Doxygen index
// lists set_normalization_coefficient() and set_data_set_pointer(DataSet*) as
// members, presumably declared on those lines -- TODO confirm.
71
73
// NOTE(review): Presumably a sum of squared differences between x and y, weighted
// by positives_weight / negatives_weight -- TODO confirm in the .cpp.
74 type weighted_sum_squared_error(const Tensor<type, 2>& x, const Tensor<type, 2>& y) const;
75
/// Returns "WEIGHTED_SQUARED_ERROR" (per this page's Doxygen brief).
76 string get_error_type() const;
77
78 string get_error_type_text() const;
79
// NOTE(review): Each declaration below is truncated. Its remaining parameters sat
// on the dropped listing lines 83-84, 87-88, 93-94, 97-98 and 104.
80 // Back propagation
81
82 void calculate_error(const DataSetBatch&,
85
86 void calculate_output_delta(const DataSetBatch&,
89
90 // Back propagation LM
91
92 void calculate_squared_errors_lm(const DataSetBatch&,
95
96 void calculate_error_lm(const DataSetBatch&,
99
// NOTE(review): Listing lines 100-101 are missing. This page's Doxygen index lists
// `void calculate_error_gradient_lm(const DataSetBatch&, LossIndexBackPropagationLM&) const`,
// presumably declared there -- TODO confirm.
102
103 void calculate_error_hessian_lm(const DataSetBatch&,
105
106 // Serialization methods
107
108 void from_XML(const tinyxml2::XMLDocument&);
109
110 void write_XML(tinyxml2::XMLPrinter&) const;
111
112private:
113
/// Weight applied to the positives when calculating the error.
/// Defaults to NaN until a setter assigns it.
115
116 type positives_weight = type(NAN);
117
/// Weight applied to the negatives when calculating the error.
/// Defaults to NaN until a setter assigns it.
119
120 type negatives_weight = type(NAN);
121
// NOTE(review): Listing lines 122 and 124 are missing. This page's Doxygen index
// lists a `type normalization_coefficient` member ("Coefficient of normalization."),
// presumably declared there -- TODO confirm.
123
125
// When OPENNN_CUDA is defined, the CUDA header is included textually inside the
// class body, so whatever it declares becomes a member of this class.
126#ifdef OPENNN_CUDA
127 #include "../../opennn-cuda/opennn-cuda/weighted_squared_error_cuda.h"
128#endif
129
130};
131
132}
133
134#endif
135
136
137// OpenNN: Open Neural Networks Library.
138// Copyright(C) 2005-2021 Artificial Intelligence Techniques, SL.
139//
140// This library is free software; you can redistribute it and/or
141// modify it under the terms of the GNU Lesser General Public
142// License as published by the Free Software Foundation; either
143// version 2.1 of the License, or any later version.
144//
145// This library is distributed in the hope that it will be useful,
146// but WITHOUT ANY WARRANTY; without even the implied warranty of
147// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
148// Lesser General Public License for more details.
149
150// You should have received a copy of the GNU Lesser General Public
151// License along with this library; if not, write to the Free Software
152// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
This class represents the concept of a data set for data modelling problems, such as approximation,...
Definition: data_set.h:57
This abstract class represents the concept of a loss index composed of an error term and a regularizati...
Definition: loss_index.h:48
This class represents the weighted squared error term.
type positives_weight
Weight applied to the positives when calculating the error.
void set_weights()
Calculates the weights for the positive and negative values from the data in the data set.
type normalization_coefficient
Coefficient of normalization.
void set_data_set_pointer(DataSet *)
Sets the data set pointer.
void set_normalization_coefficient()
Calculates the normalization coefficient from the data in the data set.
void from_XML(const tinyxml2::XMLDocument &)
void set_default()
Sets the default values for the object.
type get_positives_weight() const
Returns the weight of the positives.
type negatives_weight
Weight applied to the negatives when calculating the error.
string get_error_type() const
Returns a string with the name of the weighted squared error loss type, "WEIGHTED_SQUARED_ERROR".
void calculate_error_gradient_lm(const DataSetBatch &, LossIndexBackPropagationLM &) const
void write_XML(tinyxml2::XMLPrinter &) const
string get_error_type_text() const
Returns a string with the name of the weighted squared error loss type in text format.
type get_negatives_weight() const
Returns the weight of the negatives.
virtual ~WeightedSquaredError()
Destructor.
A loss index composed of several terms; this structure represents the first-order information for this function.
Definition: loss_index.h:383