model_selection.cpp
1// OpenNN: Open Neural Networks Library
2// www.opennn.net
3//
4// M O D E L S E L E C T I O N C L A S S
5//
6// Artificial Intelligence Techniques SL
7// artelnics@artelnics.com
8
9#include "model_selection.h"
10
11namespace OpenNN
12{
13
// NOTE(review): this listing appears to be extracted from a Doxygen-rendered source page:
// each line begins with its original line number, and gaps in that numbering show where
// lines (often signatures) were dropped. Restore them from the real model_selection.cpp
// before compiling.

// Presumably the default constructor (tooltip: "ModelSelection() Default constructor.").
// Original lines 14, 16 and 18 are missing here, so no signature or body is visible.
15
17{
19}
20
21
24
// Constructor taking a training strategy: passes it to set(), which stores it and
// forwards it to every neurons/inputs selection algorithm. Original line 29 is missing
// from this body — TODO confirm its contents.
25ModelSelection::ModelSelection(TrainingStrategy* new_training_strategy_pointer)
26{
27 set(new_training_strategy_pointer);
28
30}
31
32
// Presumably the destructor (tooltip: "virtual ~ModelSelection() Destructor."). Its
// signature (line 35) is missing; the visible body is empty.
34
36{
37}
38
39
// Presumably get_training_strategy_pointer() const (tooltip: "Returns a pointer to the
// training strategy object."). The following lines are missing from this listing:
// the signature (line 42), the condition guarding the throw (line 46) and the return
// statement (line 59). In OPENNN_DEBUG builds, the visible code throws logic_error
// reporting that the training strategy pointer is nullptr.
41
43{
44#ifdef OPENNN_DEBUG
45
47 {
48 ostringstream buffer;
49
50 buffer << "OpenNN Exception: ModelSelection class.\n"
51 << "TrainingStrategy* get_training_strategy_pointer() const method.\n"
52 << "Training strategy pointer is nullptr.\n";
53
54 throw logic_error(buffer.str());
55 }
56
57#endif
58
60}
61
62
// Presumably has_training_strategy() const. It returns true or false depending on a
// condition on line 68, which is missing, as is the signature on line 66.
// NOTE(review): the condition is probably `training_strategy_pointer != nullptr`. If so,
// the if/else collapses to `return training_strategy_pointer != nullptr;`.
65
67{
69 {
70 return true;
71 }
72 else
73 {
74 return false;
75 }
76}
77
78
// Presumably get_neurons_selection_method() const (tooltip: "Returns the type of
// algorithm for the neurons selection."). Its signature (line 81) and body (line 83)
// are missing from this listing.
80
82{
84}
85
86
// Presumably get_inputs_selection_method() const (tooltip: "Returns the type of
// algorithm for the inputs selection."). Its signature (line 89) and body (line 91)
// are missing from this listing.
88
90{
92}
93
94
// Returns the address of the growing_neurons member (tooltip: "Returns a pointer to the
// growing neurons selection algorithm."). Signature line 97 is missing.
96
98{
99 return &growing_neurons;
100}
101
102
// Returns the address of the growing_inputs member (tooltip: "Returns a pointer to the
// growing inputs selection algorithm."). Signature line 105 is missing.
104
106{
107 return &growing_inputs;
108}
109
110
// Returns the address of the pruning_inputs member (tooltip: "Returns a pointer to the
// pruning inputs selection algorithm."). Signature line 113 is missing.
112
114{
115 return &pruning_inputs;
116}
117
118
// Returns the address of the genetic_algorithm member (tooltip: "Returns a pointer to
// the genetic inputs selection algorithm."). Signature line 121 is missing.
120
122{
123 return &genetic_algorithm;
124}
125
126
// Presumably set_default() (tooltip: "Sets the members of the model selection object to
// their default values."). Its signature (line 129) is missing. The defaults are:
// growing neurons for neurons selection, growing inputs for inputs selection, and
// display on.
// NOTE(review): `display` is assigned directly rather than through set_display(), so
// the selection algorithms' own display flags are not updated here — confirm intended.
128
130{
131 set_neurons_selection_method(NeuronsSelectionMethod::GROWING_NEURONS);
132
133 set_inputs_selection_method(InputsSelectionMethod::GROWING_INPUTS);
134
135 display = true;
136}
137
138
143
144void ModelSelection::set_display(const bool& new_display)
145{
146 display = new_display;
147
148 // Neurons selection
149
150 growing_neurons.set_display(new_display);
151
152 // Inputs selection
153
154 growing_inputs.set_display(new_display);
155
156 pruning_inputs.set_display(new_display);
157
158 genetic_algorithm.set_display(new_display);
159
160}
161
162
// Presumably the NeuronsSelectionMethod overload of set_neurons_selection_method(). It
// stores the new method in neurons_selection_method. Its signature (line 166) is
// missing from this listing.
165
167{
168 neurons_selection_method = new_neurons_selection_method;
169}
170
171
174
175void ModelSelection::set_neurons_selection_method(const string& new_neurons_selection_method)
176{
177 if(new_neurons_selection_method == "GROWING_NEURONS")
178 {
179 set_neurons_selection_method(NeuronsSelectionMethod::GROWING_NEURONS);
180 }
181 else
182 {
183 ostringstream buffer;
184
185 buffer << "OpenNN Exception: ModelSelection class.\n"
186 << "void set_neurons_selection_method(const string&) method.\n"
187 << "Unknown neurons selection type: " << new_neurons_selection_method << ".\n";
188
189 throw logic_error(buffer.str());
190 }
191}
192
193
// Presumably the InputsSelectionMethod overload of set_inputs_selection_method(). It
// stores the new method in inputs_selection_method. Its signature (line 197) is missing
// from this listing.
196
198{
199 inputs_selection_method = new_inputs_selection_method;
200}
201
202
205
206void ModelSelection::set_inputs_selection_method(const string& new_inputs_selection_method)
207{
208 if(new_inputs_selection_method == "GROWING_INPUTS")
209 {
210 set_inputs_selection_method(InputsSelectionMethod::GROWING_INPUTS);
211 }
212 else if(new_inputs_selection_method == "PRUNING_INPUTS")
213 {
214 set_inputs_selection_method(InputsSelectionMethod::PRUNING_INPUTS);
215 }
216 else if(new_inputs_selection_method == "GENETIC_ALGORITHM")
217 {
218 set_inputs_selection_method(InputsSelectionMethod::GENETIC_ALGORITHM);
219 }
220 else
221 {
222 ostringstream buffer;
223
224 buffer << "OpenNN Exception: ModelSelection class.\n"
225 << "void set_inputs_selection_method(const string&) method.\n"
226 << "Unknown inputs selection type: " << new_inputs_selection_method << ".\n";
227
228 throw logic_error(buffer.str());
229 }
230}
231
232
235
236void ModelSelection::set(TrainingStrategy* new_training_strategy_pointer)
237{
238 training_strategy_pointer = new_training_strategy_pointer;
239
240 // Neurons selection
241
242 growing_neurons.set_training_strategy_pointer(new_training_strategy_pointer);
243
244 // Inputs selection
245
246 growing_inputs.set(new_training_strategy_pointer);
247 pruning_inputs.set(new_training_strategy_pointer);
248 genetic_algorithm.set(new_training_strategy_pointer);
249}
250
251
// Presumably check() const (tooltip: "Checks that the different pointers needed for
// performing the model selection are not nullptr."). Its signature (line 254) and the
// condition of the first test (line 261) are missing from this listing.
// It throws logic_error when any of the following holds:
//   - the loss index pointer is null;
//   - the neural network pointer is null, or the network is empty;
//   - the data set pointer is null;
//   - the data set has zero selection samples.
// A single ostringstream is reused because the function throws at the first failure.
253
255{
256
257 // Optimization algorithm
// NOTE(review): the test below reports a nullptr *training strategy* pointer; the
// section label above does not match what is checked.
258
259 ostringstream buffer;
260
262 {
263 buffer << "OpenNN Exception: ModelSelection class.\n"
264 << "void check() const method.\n"
265 << "Pointer to training strategy is nullptr.\n";
266
267 throw logic_error(buffer.str());
268 }
269
270 // Loss index
271
272 const LossIndex* loss_index_pointer = training_strategy_pointer->get_loss_index_pointer();
273
274 if(!loss_index_pointer)
275 {
276 buffer << "OpenNN Exception: ModelSelection class.\n"
277 << "void check() const method.\n"
278 << "Pointer to loss index is nullptr.\n";
279
280 throw logic_error(buffer.str());
281 }
282
283 // Neural network
284
285 const NeuralNetwork* neural_network_pointer = loss_index_pointer->get_neural_network_pointer();
286
287 if(!neural_network_pointer)
288 {
289 buffer << "OpenNN Exception: ModelSelection class.\n"
290 << "void check() const method.\n"
291 << "Pointer to neural network is nullptr.\n";
292
293 throw logic_error(buffer.str());
294 }
295
// NOTE(review): the message below says "Multilayer Perceptron", but the test is
// NeuralNetwork::is_empty().
296 if(neural_network_pointer->is_empty())
297 {
298 buffer << "OpenNN Exception: ModelSelection class.\n"
299 << "void check() const method.\n"
300 << "Multilayer Perceptron is empty.\n";
301
302 throw logic_error(buffer.str());
303 }
304
305 // Data set
306
307 const DataSet* data_set_pointer = loss_index_pointer->get_data_set_pointer();
308
309 if(!data_set_pointer)
310 {
311 buffer << "OpenNN Exception: ModelSelection class.\n"
312 << "void check() const method.\n"
313 << "Pointer to data set is nullptr.\n";
314
315 throw logic_error(buffer.str());
316 }
317
318 const Index selection_samples_number = data_set_pointer->get_selection_samples_number();
319
320 if(selection_samples_number == 0)
321 {
322 buffer << "OpenNN Exception: ModelSelection class.\n"
323 << "void check() const method.\n"
324 << "Number of selection samples is zero.\n";
325
326 throw logic_error(buffer.str());
327 }
328}
329
330
// Presumably perform_neurons_selection() (tooltip: "NeuronsSelectionResults
// perform_neurons_selection()"). The following lines are missing from this listing:
// the signature (335), the switch header (337), the GROWING_NEURONS case body (340)
// and line 342.
334
336{
338 {
339 case NeuronsSelectionMethod::GROWING_NEURONS:
341 }
343
344}
345
346
// Presumably perform_inputs_selection(): a switch over the three InputsSelectionMethod
// values. Its signature (350), switch header (352) and case bodies (355, 358, 361) are
// missing from this listing. If no case returns, a default-constructed
// InputsSelectionResults is returned.
349
351{
353 {
354 case InputsSelectionMethod::GROWING_INPUTS:
356
357 case InputsSelectionMethod::PRUNING_INPUTS:
359
360 case InputsSelectionMethod::GENETIC_ALGORITHM:
362 }
363
364 return InputsSelectionResults();
365}
366
367
// Presumably write_XML(tinyxml2::XMLPrinter&) const (per the trailing tooltip). Its
// signature (line 371) is missing from this listing.
// It serializes this object as:
//   <ModelSelection>
//     <NeuronsSelection>
//       <NeuronsSelectionMethod>name</NeuronsSelectionMethod>, then growing_neurons' XML
//     </NeuronsSelection>
//     <InputsSelection>
//       <InputsSelectionMethod>name</InputsSelectionMethod>, then the XML of
//       growing_inputs, pruning_inputs and genetic_algorithm
//     </InputsSelection>
//   </ModelSelection>
// Each OpenElement() is paired with a CloseElement().
370
372{
373 // Model selection
374
375 file_stream.OpenElement("ModelSelection");
376
377 // Neurons selection
378
379 file_stream.OpenElement("NeuronsSelection");
380
381 file_stream.OpenElement("NeuronsSelectionMethod");
382 file_stream.PushText(write_neurons_selection_method().c_str());
383 file_stream.CloseElement();
384
385 growing_neurons.write_XML(file_stream);
386
387 file_stream.CloseElement();
388
389 // Inputs selection
390
391 file_stream.OpenElement("InputsSelection");
392
393 file_stream.OpenElement("InputsSelectionMethod");
394 file_stream.PushText(write_inputs_selection_method().c_str());
395 file_stream.CloseElement();
396
397 growing_inputs.write_XML(file_stream);
398 pruning_inputs.write_XML(file_stream);
399 genetic_algorithm.write_XML(file_stream);
400
401 file_stream.CloseElement();
402
403 // Model selection (end tag)
404
405 file_stream.CloseElement();
406}
407
408
411
413{
414 const tinyxml2::XMLElement* root_element = document.FirstChildElement("ModelSelection");
415
416 if(!root_element)
417 {
418 ostringstream buffer;
419
420 buffer << "OpenNN Exception: ModelSelection class.\n"
421 << "void from_XML(const tinyxml2::XMLDocument&) method.\n"
422 << "Model Selection element is nullptr.\n";
423
424 throw logic_error(buffer.str());
425 }
426
427 // Neurons Selection
428
429 const tinyxml2::XMLElement* neurons_selection_element = root_element->FirstChildElement("NeuronsSelection");
430
431 if(neurons_selection_element)
432 {
433 // Neurons selection method
434
435 const tinyxml2::XMLElement* neurons_selection_method_element = neurons_selection_element->FirstChildElement("NeuronsSelectionMethod");
436
437 set_neurons_selection_method(neurons_selection_method_element->GetText());
438
439 // Growing neurons
440
441 const tinyxml2::XMLElement* growing_neurons_element = neurons_selection_element->FirstChildElement("GrowingNeurons");
442
443 if(growing_neurons_element)
444 {
445 tinyxml2::XMLDocument growing_neurons_document;
446
447 tinyxml2::XMLElement* growing_neurons_element_copy = growing_neurons_document.NewElement("GrowingNeurons");
448
449 for(const tinyxml2::XMLNode* nodeFor=growing_neurons_element->FirstChild(); nodeFor; nodeFor=nodeFor->NextSibling())
450 {
451 tinyxml2::XMLNode* copy = nodeFor->DeepClone(&growing_neurons_document );
452 growing_neurons_element_copy->InsertEndChild(copy );
453 }
454
455 growing_neurons_document.InsertEndChild(growing_neurons_element_copy);
456
457 growing_neurons.from_XML(growing_neurons_document);
458 }
459 }
460
461 // Inputs Selection
462 {
463 const tinyxml2::XMLElement* inputs_selection_element = root_element->FirstChildElement("InputsSelection");
464
465 if(inputs_selection_element)
466 {
467 const tinyxml2::XMLElement* inputs_selection_method_element = inputs_selection_element->FirstChildElement("InputsSelectionMethod");
468
469 set_inputs_selection_method(inputs_selection_method_element->GetText());
470
471 // Growing inputs
472
473 const tinyxml2::XMLElement* growing_inputs_element = inputs_selection_element->FirstChildElement("GrowingInputs");
474
475 if(growing_inputs_element)
476 {
477 tinyxml2::XMLDocument growing_inputs_document;
478
479 tinyxml2::XMLElement* growing_inputs_element_copy = growing_inputs_document.NewElement("GrowingInputs");
480
481 for(const tinyxml2::XMLNode* nodeFor=growing_inputs_element->FirstChild(); nodeFor; nodeFor=nodeFor->NextSibling())
482 {
483 tinyxml2::XMLNode* copy = nodeFor->DeepClone(&growing_inputs_document );
484 growing_inputs_element_copy->InsertEndChild(copy );
485 }
486
487 growing_inputs_document.InsertEndChild(growing_inputs_element_copy);
488
489 growing_inputs.from_XML(growing_inputs_document);
490 }
491
492
493 // Pruning inputs
494
495 const tinyxml2::XMLElement* pruning_inputs_element = inputs_selection_element->FirstChildElement("PruningInputs");
496
497 if(pruning_inputs_element)
498 {
499 tinyxml2::XMLDocument pruning_inputs_document;
500
501 tinyxml2::XMLElement* pruning_inputs_element_copy = pruning_inputs_document.NewElement("PruningInputs");
502
503 for(const tinyxml2::XMLNode* nodeFor=pruning_inputs_element->FirstChild(); nodeFor; nodeFor=nodeFor->NextSibling())
504 {
505 tinyxml2::XMLNode* copy = nodeFor->DeepClone(&pruning_inputs_document );
506 pruning_inputs_element_copy->InsertEndChild(copy );
507 }
508
509 pruning_inputs_document.InsertEndChild(pruning_inputs_element_copy);
510
511 pruning_inputs.from_XML(pruning_inputs_document);
512 }
513
514 // Genetic algorithm
515
516 const tinyxml2::XMLElement* genetic_algorithm_element = inputs_selection_element->FirstChildElement("GeneticAlgorithm");
517
518 if(genetic_algorithm_element)
519 {
520 tinyxml2::XMLDocument genetic_algorithm_document;
521
522 tinyxml2::XMLElement* genetic_algorithm_element_copy = genetic_algorithm_document.NewElement("GeneticAlgorithm");
523
524 for(const tinyxml2::XMLNode* nodeFor=genetic_algorithm_element->FirstChild(); nodeFor; nodeFor=nodeFor->NextSibling())
525 {
526 tinyxml2::XMLNode* copy = nodeFor->DeepClone(&genetic_algorithm_document );
527 genetic_algorithm_element_copy->InsertEndChild(copy );
528 }
529
530 genetic_algorithm_document.InsertEndChild(genetic_algorithm_element_copy);
531
532 genetic_algorithm.from_XML(genetic_algorithm_document);
533 }
534 }
535 }
536}
537
538
// Returns the name of the current neurons selection method, in the form pushed into
// <NeuronsSelectionMethod> by write_XML(): "GROWING_NEURONS". Any other value yields an
// empty string. Line 541 is missing from this listing; it is presumably
// `switch(neurons_selection_method)`.
539string ModelSelection::write_neurons_selection_method() const
540{
542 {
543 case NeuronsSelectionMethod::GROWING_NEURONS:
544 return "GROWING_NEURONS";
545 }
546
547 return string();
548}
549
550
// Returns the name of the current inputs selection method, in the form pushed into
// <InputsSelectionMethod> by write_XML(): "GROWING_INPUTS", "PRUNING_INPUTS" or
// "GENETIC_ALGORITHM". Any other value yields an empty string. Line 553 is missing from
// this listing; it is presumably `switch(inputs_selection_method)`.
551string ModelSelection::write_inputs_selection_method() const
552{
554 {
555 case InputsSelectionMethod::GROWING_INPUTS:
556 return "GROWING_INPUTS";
557
558 case InputsSelectionMethod::PRUNING_INPUTS:
559 return "PRUNING_INPUTS";
560
561 case InputsSelectionMethod::GENETIC_ALGORITHM:
562 return "GENETIC_ALGORITHM";
563 }
564
565 return string();
566}
567
568
// Presumably print() const (tooltip: "Prints to the screen the XML representation of this
// model selection object."). Its signature (line 571) is missing from this listing.
// NOTE(review): the only statement is commented out, so the visible body prints nothing.
570
572{
573// cout << to_string();
574}
575
576
579
580void ModelSelection::save(const string& file_name) const
581{
582 FILE * file = fopen(file_name.c_str(), "w");
583
584 tinyxml2::XMLPrinter printer(file);
585
586 write_XML(printer);
587
588 fclose(file);
589}
590
591
594
595void ModelSelection::load(const string& file_name)
596{
597 tinyxml2::XMLDocument document;
598
599 if(document.LoadFile(file_name.c_str()))
600 {
601 ostringstream buffer;
602
603 buffer << "OpenNN Exception: ModelSelection class.\n"
604 << "void load(const string&) method.\n"
605 << "Cannot load XML file " << file_name << ".\n";
606
607 throw logic_error(buffer.str());
608 }
609
610 from_XML(document);
611}
612
613}
614
615// OpenNN: Open Neural Networks Library.
616// Copyright(C) 2005-2021 Artificial Intelligence Techniques, SL.
617//
618// This library is free software; you can redistribute it and/or
619// modify it under the terms of the GNU Lesser General Public
620// License as published by the Free Software Foundation; either
621// version 2.1 of the License, or any later version.
622//
623// This library is distributed in the hope that it will be useful,
624// but WITHOUT ANY WARRANTY; without even the implied warranty of
625// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
626// Lesser General Public License for more details.
627
628// You should have received a copy of the GNU Lesser General Public
629// License along with this library; if not, write to the Free Software
630// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
This class represents the concept of data set for data modelling problems, such as approximation,...
Definition: data_set.h:57
Index get_selection_samples_number() const
Returns the number of samples in the data set which will be used for selection.
Definition: data_set.cpp:1402
void from_XML(const tinyxml2::XMLDocument &)
InputsSelectionResults perform_inputs_selection()
Select the inputs with best generalization properties using the genetic algorithm.
void write_XML(tinyxml2::XMLPrinter &) const
This concrete class represents a growing inputs algorithm for the InputsSelection as part of the Mode...
void from_XML(const tinyxml2::XMLDocument &)
InputsSelectionResults perform_inputs_selection()
Perform inputs selection with the growing inputs method.
void write_XML(tinyxml2::XMLPrinter &) const
This concrete class represents a growing neurons algorithm for the NeuronsSelection as part of the M...
void from_XML(const tinyxml2::XMLDocument &)
NeuronsSelectionResults perform_neurons_selection()
Perform neurons selection with the growing neurons method.
void write_XML(tinyxml2::XMLPrinter &) const
void set(TrainingStrategy *)
void set_display(const bool &)
This abstract class represents the concept of loss index composed of an error term and a regularizati...
Definition: loss_index.h:48
NeuralNetwork * get_neural_network_pointer() const
Returns a pointer to the neural network object associated to the error term.
Definition: loss_index.h:70
DataSet * get_data_set_pointer() const
Returns a pointer to the data set object associated to the error term.
Definition: loss_index.h:92
const InputsSelectionMethod & get_inputs_selection_method() const
Returns the type of algorithm for the inputs selection.
GrowingInputs * get_growing_inputs_pointer()
Returns a pointer to the growing inputs selection algorithm.
PruningInputs pruning_inputs
Pruning inputs object to be used for inputs selection.
TrainingStrategy * training_strategy_pointer
Pointer to a training strategy object.
NeuronsSelectionMethod
Enumeration of all the available neurons selection algorithms.
void from_XML(const tinyxml2::XMLDocument &)
void set_default()
Sets the members of the model selection object to their default values.
GrowingInputs growing_inputs
Growing inputs object to be used for inputs selection.
ModelSelection()
Default constructor.
void check() const
Checks that the different pointers needed for performing the model selection are not nullptr.
void load(const string &)
NeuronsSelectionResults perform_neurons_selection()
void set_inputs_selection_method(const InputsSelectionMethod &)
GrowingNeurons * get_growing_neurons_pointer()
Returns a pointer to the growing neurons selection algorithm.
bool display
Display messages to screen.
GeneticAlgorithm genetic_algorithm
Genetic algorithm object to be used for inputs selection.
InputsSelectionResults perform_inputs_selection()
const NeuronsSelectionMethod & get_neurons_selection_method() const
Returns the type of algorithm for the neurons selection.
void save(const string &) const
virtual ~ModelSelection()
Destructor.
GrowingNeurons growing_neurons
Growing neurons object to be used for neurons selection.
TrainingStrategy * get_training_strategy_pointer() const
Returns a pointer to the training strategy object.
NeuronsSelectionMethod neurons_selection_method
Type of neurons selection algorithm.
void set(TrainingStrategy *)
InputsSelectionMethod
Enumeration of all the available inputs selection algorithms.
void print() const
Prints to the screen the XML representation of this model selection object.
void set_display(const bool &)
PruningInputs * get_pruning_inputs_pointer()
Returns a pointer to the pruning inputs selection algorithm.
void write_XML(tinyxml2::XMLPrinter &) const
bool has_training_strategy() const
GeneticAlgorithm * get_genetic_algorithm_pointer()
Returns a pointer to the genetic inputs selection algorithm.
InputsSelectionMethod inputs_selection_method
Type of inputs selection algorithm.
void set_neurons_selection_method(const NeuronsSelectionMethod &)
void set_training_strategy_pointer(TrainingStrategy *)
void set_display(const bool &)
This concrete class represents a pruning inputs algorithm for the InputsSelection as part of the Mode...
void from_XML(const tinyxml2::XMLDocument &)
InputsSelectionResults perform_inputs_selection()
Perform the inputs selection with the pruning inputs method.
void write_XML(tinyxml2::XMLPrinter &) const
This class represents the concept of training strategy for a neural network in OpenNN.
LossIndex * get_loss_index_pointer()
Returns a pointer to the LossIndex class.
const XMLNode * NextSibling() const
Get the next(right) sibling node of this node.
Definition: tinyxml2.h:809
const XMLNode * FirstChild() const
Get the first child node, or null if none exists.
Definition: tinyxml2.h:757
void PushText(const char *text, bool cdata=false)
Add a text node.
Definition: tinyxml2.cpp:2878
virtual void CloseElement(bool compactMode=false)
If streaming, close the Element.
Definition: tinyxml2.cpp:2834
This structure contains the results from the inputs selection.
This structure contains the results from the neurons selection.