// OpenNN: Open Neural Networks Library
// www.opennn.net
//
// T R A I N I N G   S T R A T E G Y   C L A S S
//
// Artificial Intelligence Techniques SL
// artelnics@artelnics.com

#include "training_strategy.h"
#include "optimization_algorithm.h"

namespace OpenNN
{

/// Default constructor.
/// It creates a training strategy object not associated with any neural network or data set.

TrainingStrategy::TrainingStrategy()
{
    data_set_pointer = nullptr;

    neural_network_pointer = nullptr;

    set_loss_method(LossMethod::NORMALIZED_SQUARED_ERROR);

    set_optimization_method(OptimizationMethod::QUASI_NEWTON_METHOD);

    LossIndex* loss_index_pointer = get_loss_index_pointer();

    set_loss_index_pointer(loss_index_pointer);

    set_default();
}


/// Pointer constructor.
/// It creates a training strategy object associated with a given neural network and data set.
/// @param new_neural_network_pointer Pointer to a NeuralNetwork object.
/// @param new_data_set_pointer Pointer to a DataSet object.

TrainingStrategy::TrainingStrategy(NeuralNetwork* new_neural_network_pointer, DataSet* new_data_set_pointer)
{
    data_set_pointer = new_data_set_pointer;

    neural_network_pointer = new_neural_network_pointer;

    set_optimization_method(OptimizationMethod::QUASI_NEWTON_METHOD);
    set_loss_method(LossMethod::NORMALIZED_SQUARED_ERROR);

    set_loss_index_neural_network_pointer(neural_network_pointer);
    set_loss_index_data_set_pointer(data_set_pointer);

    set_default();
}

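// A minimal usage sketch for this class (illustrative only: the data file,
// separator and network architecture below are assumptions, not part of this
// source file):
//
//     DataSet data_set("data.csv", ';', true);
//     NeuralNetwork neural_network(NeuralNetwork::ProjectType::Approximation, {4, 10, 1});
//     TrainingStrategy training_strategy(&neural_network, &data_set);
//     training_strategy.perform_training();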

/// Destructor.

TrainingStrategy::~TrainingStrategy()
{
}


/// Returns a pointer to the DataSet class.

DataSet* TrainingStrategy::get_data_set_pointer()
{
    return data_set_pointer;
}


/// Returns a pointer to the NeuralNetwork class.

NeuralNetwork* TrainingStrategy::get_neural_network_pointer() const
{
    return neural_network_pointer;
}


/// Returns a pointer to the LossIndex class.

LossIndex* TrainingStrategy::get_loss_index_pointer()
{
    switch(loss_method)
    {
    case LossMethod::SUM_SQUARED_ERROR: return &sum_squared_error;

    case LossMethod::MEAN_SQUARED_ERROR: return &mean_squared_error;

    case LossMethod::NORMALIZED_SQUARED_ERROR: return &normalized_squared_error;

    case LossMethod::MINKOWSKI_ERROR: return &Minkowski_error;

    case LossMethod::WEIGHTED_SQUARED_ERROR: return &weighted_squared_error;

    case LossMethod::CROSS_ENTROPY_ERROR: return &cross_entropy_error;
    }

    return nullptr;
}


/// Returns a pointer to the OptimizationAlgorithm class.

OptimizationAlgorithm* TrainingStrategy::get_optimization_algorithm_pointer()
{
    switch(optimization_method)
    {
    case OptimizationMethod::GRADIENT_DESCENT: return &gradient_descent;

    case OptimizationMethod::CONJUGATE_GRADIENT: return &conjugate_gradient;

    case OptimizationMethod::STOCHASTIC_GRADIENT_DESCENT: return &stochastic_gradient_descent;

    case OptimizationMethod::ADAPTIVE_MOMENT_ESTIMATION: return &adaptive_moment_estimation;

    case OptimizationMethod::QUASI_NEWTON_METHOD: return &quasi_Newton_method;

    case OptimizationMethod::LEVENBERG_MARQUARDT_ALGORITHM: return &Levenberg_Marquardt_algorithm;
    }

    return nullptr;
}


bool TrainingStrategy::has_neural_network() const
{
    if(neural_network_pointer == nullptr) return false;

    return true;
}


bool TrainingStrategy::has_data_set() const
{
    if(data_set_pointer == nullptr) return false;

    return true;
}


/// Returns a pointer to the gradient descent object inside this training strategy.

GradientDescent* TrainingStrategy::get_gradient_descent_pointer()
{
    return &gradient_descent;
}


/// Returns a pointer to the conjugate gradient object inside this training strategy.

ConjugateGradient* TrainingStrategy::get_conjugate_gradient_pointer()
{
    return &conjugate_gradient;
}


/// Returns a pointer to the quasi-Newton method object inside this training strategy.

QuasiNewtonMethod* TrainingStrategy::get_quasi_Newton_method_pointer()
{
    return &quasi_Newton_method;
}


/// Returns a pointer to the Levenberg-Marquardt algorithm object inside this training strategy.

LevenbergMarquardtAlgorithm* TrainingStrategy::get_Levenberg_Marquardt_algorithm_pointer()
{
    return &Levenberg_Marquardt_algorithm;
}


/// Returns a pointer to the stochastic gradient descent object inside this training strategy.

StochasticGradientDescent* TrainingStrategy::get_stochastic_gradient_descent_pointer()
{
    return &stochastic_gradient_descent;
}


/// Returns a pointer to the adaptive moment estimation object inside this training strategy.

AdaptiveMomentEstimation* TrainingStrategy::get_adaptive_moment_estimation_pointer()
{
    return &adaptive_moment_estimation;
}


/// Returns a pointer to the sum squared error object inside this training strategy.

SumSquaredError* TrainingStrategy::get_sum_squared_error_pointer()
{
    return &sum_squared_error;
}


/// Returns a pointer to the mean squared error object inside this training strategy.

MeanSquaredError* TrainingStrategy::get_mean_squared_error_pointer()
{
    return &mean_squared_error;
}


/// Returns a pointer to the normalized squared error object inside this training strategy.

NormalizedSquaredError* TrainingStrategy::get_normalized_squared_error_pointer()
{
    return &normalized_squared_error;
}


/// Returns a pointer to the Minkowski error object inside this training strategy.

MinkowskiError* TrainingStrategy::get_Minkowski_error_pointer()
{
    return &Minkowski_error;
}


/// Returns a pointer to the cross entropy error object inside this training strategy.

CrossEntropyError* TrainingStrategy::get_cross_entropy_error_pointer()
{
    return &cross_entropy_error;
}


/// Returns a pointer to the weighted squared error object inside this training strategy.

WeightedSquaredError* TrainingStrategy::get_weighted_squared_error_pointer()
{
    return &weighted_squared_error;
}


/// Returns the type of the main loss algorithm composing this training strategy object.

const LossMethod& TrainingStrategy::get_loss_method() const
{
    return loss_method;
}


/// Returns the type of the main optimization algorithm composing this training strategy object.

const OptimizationMethod& TrainingStrategy::get_optimization_method() const
{
    return optimization_method;
}


/// Returns a string with the type of the main loss algorithm composing this training strategy object.

string TrainingStrategy::write_loss_method() const
{
    switch(loss_method)
    {
    case LossMethod::SUM_SQUARED_ERROR:
        return "SUM_SQUARED_ERROR";

    case LossMethod::MEAN_SQUARED_ERROR:
        return "MEAN_SQUARED_ERROR";

    case LossMethod::NORMALIZED_SQUARED_ERROR:
        return "NORMALIZED_SQUARED_ERROR";

    case LossMethod::MINKOWSKI_ERROR:
        return "MINKOWSKI_ERROR";

    case LossMethod::WEIGHTED_SQUARED_ERROR:
        return "WEIGHTED_SQUARED_ERROR";

    case LossMethod::CROSS_ENTROPY_ERROR:
        return "CROSS_ENTROPY_ERROR";
    }

    return string();
}


/// Returns a string with the type of the main optimization algorithm composing this training strategy object.

string TrainingStrategy::write_optimization_method() const
{
    if(optimization_method == OptimizationMethod::GRADIENT_DESCENT)
    {
        return "GRADIENT_DESCENT";
    }
    else if(optimization_method == OptimizationMethod::CONJUGATE_GRADIENT)
    {
        return "CONJUGATE_GRADIENT";
    }
    else if(optimization_method == OptimizationMethod::QUASI_NEWTON_METHOD)
    {
        return "QUASI_NEWTON_METHOD";
    }
    else if(optimization_method == OptimizationMethod::LEVENBERG_MARQUARDT_ALGORITHM)
    {
        return "LEVENBERG_MARQUARDT_ALGORITHM";
    }
    else if(optimization_method == OptimizationMethod::STOCHASTIC_GRADIENT_DESCENT)
    {
        return "STOCHASTIC_GRADIENT_DESCENT";
    }
    else if(optimization_method == OptimizationMethod::ADAPTIVE_MOMENT_ESTIMATION)
    {
        return "ADAPTIVE_MOMENT_ESTIMATION";
    }
    else
    {
        ostringstream buffer;

        buffer << "OpenNN Exception: TrainingStrategy class.\n"
               << "string write_optimization_method() const method.\n"
               << "Unknown main type.\n";

        throw logic_error(buffer.str());
    }
}


/// Returns a string with the main optimization method type in text format.

string TrainingStrategy::write_optimization_method_text() const
{
    if(optimization_method == OptimizationMethod::GRADIENT_DESCENT)
    {
        return "gradient descent";
    }
    else if(optimization_method == OptimizationMethod::CONJUGATE_GRADIENT)
    {
        return "conjugate gradient";
    }
    else if(optimization_method == OptimizationMethod::QUASI_NEWTON_METHOD)
    {
        return "quasi-Newton method";
    }
    else if(optimization_method == OptimizationMethod::LEVENBERG_MARQUARDT_ALGORITHM)
    {
        return "Levenberg-Marquardt algorithm";
    }
    else if(optimization_method == OptimizationMethod::STOCHASTIC_GRADIENT_DESCENT)
    {
        return "stochastic gradient descent";
    }
    else if(optimization_method == OptimizationMethod::ADAPTIVE_MOMENT_ESTIMATION)
    {
        return "adaptive moment estimation";
    }
    else
    {
        ostringstream buffer;

        buffer << "OpenNN Exception: TrainingStrategy class.\n"
               << "string write_optimization_method_text() const method.\n"
               << "Unknown main type.\n";

        throw logic_error(buffer.str());
    }
}


/// Returns a string with the main loss method type in text format.

string TrainingStrategy::write_loss_method_text() const
{
    switch(loss_method)
    {
    case LossMethod::SUM_SQUARED_ERROR:
        return "Sum squared error";

    case LossMethod::MEAN_SQUARED_ERROR:
        return "Mean squared error";

    case LossMethod::NORMALIZED_SQUARED_ERROR:
        return "Normalized squared error";

    case LossMethod::MINKOWSKI_ERROR:
        return "Minkowski error";

    case LossMethod::WEIGHTED_SQUARED_ERROR:
        return "Weighted squared error";

    case LossMethod::CROSS_ENTROPY_ERROR:
        return "Cross entropy error";
    }

    return string();
}


/// Returns true if messages from this class can be displayed on the screen,
/// or false if messages from this class cannot be displayed on the screen.

const bool& TrainingStrategy::get_display() const
{
    return display;
}


/// Sets the optimization method and the rest of the members to their defaults.

void TrainingStrategy::set()
{
    set_optimization_method(OptimizationMethod::QUASI_NEWTON_METHOD);

    set_default();
}


void TrainingStrategy::set(NeuralNetwork* new_neural_network_pointer, DataSet* new_data_set_pointer)
{
    set_neural_network_pointer(new_neural_network_pointer);

    set_data_set_pointer(new_data_set_pointer);
}


/// Sets a new loss method from a string.
/// @param new_loss_method String with the name of the new loss method.

void TrainingStrategy::set_loss_method(const string& new_loss_method)
{
    if(new_loss_method == "SUM_SQUARED_ERROR")
    {
        set_loss_method(LossMethod::SUM_SQUARED_ERROR);
    }
    else if(new_loss_method == "MEAN_SQUARED_ERROR")
    {
        set_loss_method(LossMethod::MEAN_SQUARED_ERROR);
    }
    else if(new_loss_method == "NORMALIZED_SQUARED_ERROR")
    {
        set_loss_method(LossMethod::NORMALIZED_SQUARED_ERROR);
    }
    else if(new_loss_method == "MINKOWSKI_ERROR")
    {
        set_loss_method(LossMethod::MINKOWSKI_ERROR);
    }
    else if(new_loss_method == "WEIGHTED_SQUARED_ERROR")
    {
        set_loss_method(LossMethod::WEIGHTED_SQUARED_ERROR);
    }
    else if(new_loss_method == "CROSS_ENTROPY_ERROR")
    {
        set_loss_method(LossMethod::CROSS_ENTROPY_ERROR);
    }
    else
    {
        ostringstream buffer;

        buffer << "OpenNN Exception: TrainingStrategy class.\n"
               << "void set_loss_method(const string&) method.\n"
               << "Unknown loss method: " << new_loss_method << ".\n";

        throw logic_error(buffer.str());
    }
}


/// Sets a new loss method.
/// @param new_loss_method New loss method.

void TrainingStrategy::set_loss_method(const LossMethod& new_loss_method)
{
    loss_method = new_loss_method;

    set_loss_index_pointer(get_loss_index_pointer());
}


/// Sets a new main optimization algorithm.
/// @param new_optimization_method New optimization method.

void TrainingStrategy::set_optimization_method(const OptimizationMethod& new_optimization_method)
{
    optimization_method = new_optimization_method;
}


/// Sets a new main optimization algorithm from a string.
/// @param new_optimization_method String with the name of the new optimization method.

void TrainingStrategy::set_optimization_method(const string& new_optimization_method)
{
    if(new_optimization_method == "GRADIENT_DESCENT")
    {
        set_optimization_method(OptimizationMethod::GRADIENT_DESCENT);
    }
    else if(new_optimization_method == "CONJUGATE_GRADIENT")
    {
        set_optimization_method(OptimizationMethod::CONJUGATE_GRADIENT);
    }
    else if(new_optimization_method == "QUASI_NEWTON_METHOD")
    {
        set_optimization_method(OptimizationMethod::QUASI_NEWTON_METHOD);
    }
    else if(new_optimization_method == "LEVENBERG_MARQUARDT_ALGORITHM")
    {
        set_optimization_method(OptimizationMethod::LEVENBERG_MARQUARDT_ALGORITHM);
    }
    else if(new_optimization_method == "STOCHASTIC_GRADIENT_DESCENT")
    {
        set_optimization_method(OptimizationMethod::STOCHASTIC_GRADIENT_DESCENT);
    }
    else if(new_optimization_method == "ADAPTIVE_MOMENT_ESTIMATION")
    {
        set_optimization_method(OptimizationMethod::ADAPTIVE_MOMENT_ESTIMATION);
    }
    else
    {
        ostringstream buffer;

        buffer << "OpenNN Exception: TrainingStrategy class.\n"
               << "void set_optimization_method(const string&) method.\n"
               << "Unknown main type: " << new_optimization_method << ".\n";

        throw logic_error(buffer.str());
    }
}

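// The string overloads above mirror the enum setters, which is convenient when
// the method names come from a configuration file. A small sketch (the chosen
// values are illustrative):
//
//     training_strategy.set_loss_method("CROSS_ENTROPY_ERROR");
//     training_strategy.set_optimization_method("ADAPTIVE_MOMENT_ESTIMATION");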

void TrainingStrategy::set_threads_number(const int& new_threads_number)
{
    set_loss_index_threads_number(new_threads_number);

    set_optimization_algorithm_threads_number(new_threads_number);
}


void TrainingStrategy::set_data_set_pointer(DataSet* new_data_set_pointer)
{
    data_set_pointer = new_data_set_pointer;

    set_loss_index_data_set_pointer(data_set_pointer);
}


void TrainingStrategy::set_neural_network_pointer(NeuralNetwork* new_neural_network_pointer)
{
    neural_network_pointer = new_neural_network_pointer;

    set_loss_index_neural_network_pointer(neural_network_pointer);
}


void TrainingStrategy::set_loss_index_threads_number(const int& new_threads_number)
{
    sum_squared_error.set_threads_number(new_threads_number);
    mean_squared_error.set_threads_number(new_threads_number);
    normalized_squared_error.set_threads_number(new_threads_number);
    Minkowski_error.set_threads_number(new_threads_number);
    weighted_squared_error.set_threads_number(new_threads_number);
    cross_entropy_error.set_threads_number(new_threads_number);
}


void TrainingStrategy::set_optimization_algorithm_threads_number(const int& new_threads_number)
{
    gradient_descent.set_threads_number(new_threads_number);
    conjugate_gradient.set_threads_number(new_threads_number);
    quasi_Newton_method.set_threads_number(new_threads_number);
    Levenberg_Marquardt_algorithm.set_threads_number(new_threads_number);
    stochastic_gradient_descent.set_threads_number(new_threads_number);
    adaptive_moment_estimation.set_threads_number(new_threads_number);
}


/// Sets a pointer to a loss index object to be associated with the training strategy.
/// @param new_loss_index_pointer Pointer to a loss index object.

void TrainingStrategy::set_loss_index_pointer(LossIndex* new_loss_index_pointer)
{
    gradient_descent.set_loss_index_pointer(new_loss_index_pointer);
    conjugate_gradient.set_loss_index_pointer(new_loss_index_pointer);
    stochastic_gradient_descent.set_loss_index_pointer(new_loss_index_pointer);
    adaptive_moment_estimation.set_loss_index_pointer(new_loss_index_pointer);
    quasi_Newton_method.set_loss_index_pointer(new_loss_index_pointer);
    Levenberg_Marquardt_algorithm.set_loss_index_pointer(new_loss_index_pointer);
}


void TrainingStrategy::set_loss_index_data_set_pointer(DataSet* new_data_set_pointer)
{
    sum_squared_error.set_data_set_pointer(new_data_set_pointer);
    mean_squared_error.set_data_set_pointer(new_data_set_pointer);
    normalized_squared_error.set_data_set_pointer(new_data_set_pointer);
    cross_entropy_error.set_data_set_pointer(new_data_set_pointer);
    weighted_squared_error.set_data_set_pointer(new_data_set_pointer);
    Minkowski_error.set_data_set_pointer(new_data_set_pointer);
}


void TrainingStrategy::set_loss_index_neural_network_pointer(NeuralNetwork* new_neural_network_pointer)
{
    sum_squared_error.set_neural_network_pointer(new_neural_network_pointer);
    mean_squared_error.set_neural_network_pointer(new_neural_network_pointer);
    normalized_squared_error.set_neural_network_pointer(new_neural_network_pointer);
    cross_entropy_error.set_neural_network_pointer(new_neural_network_pointer);
    weighted_squared_error.set_neural_network_pointer(new_neural_network_pointer);
    Minkowski_error.set_neural_network_pointer(new_neural_network_pointer);
}


/// Sets a new display value.
/// If it is set to true messages from this class are displayed on the screen;
/// if it is set to false messages from this class are not displayed on the screen.
/// @param new_display Display value.

void TrainingStrategy::set_display(const bool& new_display)
{
    display = new_display;

    // Loss index

    sum_squared_error.set_display(display);
    mean_squared_error.set_display(display);
    normalized_squared_error.set_display(display);
    cross_entropy_error.set_display(display);
    weighted_squared_error.set_display(display);
    Minkowski_error.set_display(display);

    // Optimization algorithm

    gradient_descent.set_display(display);
    conjugate_gradient.set_display(display);
    stochastic_gradient_descent.set_display(display);
    adaptive_moment_estimation.set_display(display);
    quasi_Newton_method.set_display(display);
    Levenberg_Marquardt_algorithm.set_display(display);
}


void TrainingStrategy::set_loss_goal(const type& new_loss_goal)
{
    gradient_descent.set_loss_goal(new_loss_goal);
    conjugate_gradient.set_loss_goal(new_loss_goal);
    quasi_Newton_method.set_loss_goal(new_loss_goal);
    Levenberg_Marquardt_algorithm.set_loss_goal(new_loss_goal);
}


void TrainingStrategy::set_maximum_selection_failures(const Index& maximum_selection_failures)
{
    gradient_descent.set_maximum_selection_failures(maximum_selection_failures);
    conjugate_gradient.set_maximum_selection_failures(maximum_selection_failures);
    quasi_Newton_method.set_maximum_selection_failures(maximum_selection_failures);
    Levenberg_Marquardt_algorithm.set_maximum_selection_failures(maximum_selection_failures);
}


void TrainingStrategy::set_maximum_epochs_number(const int& maximum_epochs_number)
{
    gradient_descent.set_maximum_epochs_number(maximum_epochs_number);
    conjugate_gradient.set_maximum_epochs_number(maximum_epochs_number);
    stochastic_gradient_descent.set_maximum_epochs_number(maximum_epochs_number);
    adaptive_moment_estimation.set_maximum_epochs_number(maximum_epochs_number);
    quasi_Newton_method.set_maximum_epochs_number(maximum_epochs_number);
    Levenberg_Marquardt_algorithm.set_maximum_epochs_number(maximum_epochs_number);
}


void TrainingStrategy::set_display_period(const int& display_period)
{
    get_optimization_algorithm_pointer()->set_display_period(display_period);
}


void TrainingStrategy::set_maximum_time(const type& maximum_time)
{
    gradient_descent.set_maximum_time(maximum_time);
    conjugate_gradient.set_maximum_time(maximum_time);
    stochastic_gradient_descent.set_maximum_time(maximum_time);
    adaptive_moment_estimation.set_maximum_time(maximum_time);
    quasi_Newton_method.set_maximum_time(maximum_time);
    Levenberg_Marquardt_algorithm.set_maximum_time(maximum_time);
}


/// Sets the members of the training strategy object to their default values.

void TrainingStrategy::set_default()
{
}


/// This is the most important method of this class.
/// It trains the neural network with the selected loss method and optimization algorithm.
/// It returns a structure with the results of the training.

TrainingResults TrainingStrategy::perform_training()
{
    if(neural_network_pointer->has_long_short_term_memory_layer() || neural_network_pointer->has_recurrent_layer())
    {
        fix_forecasting();
    }

    if(neural_network_pointer->has_convolutional_layer())
    {
        ostringstream buffer;

        buffer << "OpenNN Exception: TrainingStrategy class.\n"
               << "TrainingResults perform_training() method.\n"
               << "Convolutional Layer is not available yet. It will be included in future versions.\n";

        throw logic_error(buffer.str());
    }

    switch(optimization_method)
    {
    case OptimizationMethod::GRADIENT_DESCENT:
    {
        gradient_descent.set_display(display);

        return gradient_descent.perform_training();
    }

    case OptimizationMethod::CONJUGATE_GRADIENT:
    {
        conjugate_gradient.set_display(display);

        return conjugate_gradient.perform_training();
    }

    case OptimizationMethod::QUASI_NEWTON_METHOD:
    {
        quasi_Newton_method.set_display(display);

        return quasi_Newton_method.perform_training();
    }

    case OptimizationMethod::LEVENBERG_MARQUARDT_ALGORITHM:
    {
        Levenberg_Marquardt_algorithm.set_display(display);

        return Levenberg_Marquardt_algorithm.perform_training();
    }

    case OptimizationMethod::STOCHASTIC_GRADIENT_DESCENT:
    {
        stochastic_gradient_descent.set_display(display);

        return stochastic_gradient_descent.perform_training();
    }

    case OptimizationMethod::ADAPTIVE_MOMENT_ESTIMATION:
    {
        adaptive_moment_estimation.set_display(display);

        return adaptive_moment_estimation.perform_training();
    }
    }

    return TrainingResults(0);
}

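// Typical call site, assuming the strategy is already associated with a neural
// network and a data set (the variable names below are illustrative):
//
//     training_strategy.set_optimization_method(TrainingStrategy::OptimizationMethod::ADAPTIVE_MOMENT_ESTIMATION);
//     const TrainingResults results = training_strategy.perform_training();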

/// In forecasting problems, checks that the batch size is compatible with the
/// number of timesteps of the recurrent or long short term memory layer,
/// and rounds it down to a multiple of the timesteps when it is not.

void TrainingStrategy::fix_forecasting()
{
    Index timesteps = 0;

    if(neural_network_pointer->has_recurrent_layer())
    {
        timesteps = neural_network_pointer->get_recurrent_layer_pointer()->get_timesteps();
    }
    else if(neural_network_pointer->has_long_short_term_memory_layer())
    {
        timesteps = neural_network_pointer->get_long_short_term_memory_layer_pointer()->get_timesteps();
    }
    else
    {
        return;
    }

    Index batch_samples_number = 0;

    if(optimization_method == OptimizationMethod::ADAPTIVE_MOMENT_ESTIMATION)
    {
        batch_samples_number = adaptive_moment_estimation.get_batch_samples_number();
    }
    else if(optimization_method == OptimizationMethod::STOCHASTIC_GRADIENT_DESCENT)
    {
        batch_samples_number = stochastic_gradient_descent.get_batch_samples_number();
    }
    else
    {
        return;
    }

    if(batch_samples_number%timesteps == 0)
    {
        return;
    }
    else
    {
        const Index constant = static_cast<Index>(batch_samples_number/timesteps);
        if(optimization_method == OptimizationMethod::ADAPTIVE_MOMENT_ESTIMATION)
        {
            adaptive_moment_estimation.set_batch_samples_number(constant*timesteps);
        }
        else if(optimization_method == OptimizationMethod::STOCHASTIC_GRADIENT_DESCENT)
        {
            stochastic_gradient_descent.set_batch_samples_number(constant*timesteps);
        }
    }
}

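// Worked example of the adjustment above, with illustrative numbers: for
// timesteps = 3 and batch_samples_number = 1000, 1000 % 3 != 0, so the batch
// size is rounded down to (1000/3)*3 = 999, the largest multiple of the
// timesteps that fits in the original batch size.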

/// Prints to the screen the string representation of the training strategy object.

void TrainingStrategy::print() const
{
    cout << "Training strategy object" << endl;
    cout << "Loss index: " << write_loss_method() << endl;
    cout << "Optimization algorithm: " << write_optimization_method() << endl;
}


/// Serializes the training strategy object into an XML document of the TinyXML
/// library without keeping the DOM tree in memory.
/// @param file_stream TinyXML printer where the document is built.

void TrainingStrategy::write_XML(tinyxml2::XMLPrinter& file_stream) const
{
    file_stream.OpenElement("TrainingStrategy");

    // Loss index

    file_stream.OpenElement("LossIndex");

    // Loss method

    file_stream.OpenElement("LossMethod");
    file_stream.PushText(write_loss_method().c_str());
    file_stream.CloseElement();

    mean_squared_error.write_XML(file_stream);
    normalized_squared_error.write_XML(file_stream);
    Minkowski_error.write_XML(file_stream);
    cross_entropy_error.write_XML(file_stream);
    weighted_squared_error.write_XML(file_stream);

    switch(loss_method)
    {
    case LossMethod::MEAN_SQUARED_ERROR : mean_squared_error.write_regularization_XML(file_stream); break;
    case LossMethod::NORMALIZED_SQUARED_ERROR : normalized_squared_error.write_regularization_XML(file_stream); break;
    case LossMethod::MINKOWSKI_ERROR : Minkowski_error.write_regularization_XML(file_stream); break;
    case LossMethod::CROSS_ENTROPY_ERROR : cross_entropy_error.write_regularization_XML(file_stream); break;
    case LossMethod::WEIGHTED_SQUARED_ERROR : weighted_squared_error.write_regularization_XML(file_stream); break;
    case LossMethod::SUM_SQUARED_ERROR : sum_squared_error.write_regularization_XML(file_stream); break;
    }

    file_stream.CloseElement();

    // Optimization algorithm

    file_stream.OpenElement("OptimizationAlgorithm");

    file_stream.OpenElement("OptimizationMethod");
    file_stream.PushText(write_optimization_method().c_str());
    file_stream.CloseElement();

    gradient_descent.write_XML(file_stream);
    conjugate_gradient.write_XML(file_stream);
    stochastic_gradient_descent.write_XML(file_stream);
    adaptive_moment_estimation.write_XML(file_stream);
    quasi_Newton_method.write_XML(file_stream);
    Levenberg_Marquardt_algorithm.write_XML(file_stream);

    file_stream.CloseElement();

    // Close TrainingStrategy

    file_stream.CloseElement();
}


/// Loads the members of this training strategy object from an XML document.
/// @param document XML document of the TinyXML library.

void TrainingStrategy::from_XML(const tinyxml2::XMLDocument& document)
{
    const tinyxml2::XMLElement* root_element = document.FirstChildElement("TrainingStrategy");

    if(!root_element)
    {
        ostringstream buffer;

        buffer << "OpenNN Exception: TrainingStrategy class.\n"
               << "void from_XML(const tinyxml2::XMLDocument&) method.\n"
               << "Training strategy element is nullptr.\n";

        throw logic_error(buffer.str());
    }

    // Loss index

    const tinyxml2::XMLElement* loss_index_element = root_element->FirstChildElement("LossIndex");

    if(loss_index_element)
    {
        const tinyxml2::XMLElement* loss_method_element = loss_index_element->FirstChildElement("LossMethod");

        set_loss_method(loss_method_element->GetText());

        // Minkowski error

        const tinyxml2::XMLElement* Minkowski_error_element = loss_index_element->FirstChildElement("MinkowskiError");

        if(Minkowski_error_element)
        {
            tinyxml2::XMLDocument new_document;

            tinyxml2::XMLElement* Minkowski_error_element_copy = new_document.NewElement("MinkowskiError");

            for(const tinyxml2::XMLNode* nodeFor = Minkowski_error_element->FirstChild(); nodeFor; nodeFor = nodeFor->NextSibling())
            {
                tinyxml2::XMLNode* copy = nodeFor->DeepClone(&new_document);
                Minkowski_error_element_copy->InsertEndChild(copy);
            }

            new_document.InsertEndChild(Minkowski_error_element_copy);

            Minkowski_error.from_XML(new_document);
        }
        else
        {
            Minkowski_error.set_Minkowski_parameter(static_cast<type>(1.5));
        }

        // Cross entropy error

        const tinyxml2::XMLElement* cross_entropy_element = loss_index_element->FirstChildElement("CrossEntropyError");

        if(cross_entropy_element)
        {
            tinyxml2::XMLDocument new_document;

            tinyxml2::XMLElement* cross_entropy_error_element_copy = new_document.NewElement("CrossEntropyError");

            for(const tinyxml2::XMLNode* nodeFor = cross_entropy_element->FirstChild(); nodeFor; nodeFor = nodeFor->NextSibling())
            {
                tinyxml2::XMLNode* copy = nodeFor->DeepClone(&new_document);
                cross_entropy_error_element_copy->InsertEndChild(copy);
            }

            new_document.InsertEndChild(cross_entropy_error_element_copy);

            cross_entropy_error.from_XML(new_document);
        }

        // Weighted squared error

        const tinyxml2::XMLElement* weighted_squared_error_element = loss_index_element->FirstChildElement("WeightedSquaredError");

        if(weighted_squared_error_element)
        {
            tinyxml2::XMLDocument new_document;

            tinyxml2::XMLElement* weighted_squared_error_element_copy = new_document.NewElement("WeightedSquaredError");

            for(const tinyxml2::XMLNode* nodeFor = weighted_squared_error_element->FirstChild(); nodeFor; nodeFor = nodeFor->NextSibling())
            {
                tinyxml2::XMLNode* copy = nodeFor->DeepClone(&new_document);
                weighted_squared_error_element_copy->InsertEndChild(copy);
            }

            new_document.InsertEndChild(weighted_squared_error_element_copy);

            weighted_squared_error.from_XML(new_document);
        }
        else
        {
        }

        // Regularization

        const tinyxml2::XMLElement* regularization_element = loss_index_element->FirstChildElement("Regularization");

        if(regularization_element)
        {
            tinyxml2::XMLDocument regularization_document;
            tinyxml2::XMLNode* element_clone;

            element_clone = regularization_element->DeepClone(&regularization_document);

            regularization_document.InsertFirstChild(element_clone);

            get_loss_index_pointer()->regularization_from_XML(regularization_document);
        }
    }

    // Optimization algorithm

    const tinyxml2::XMLElement* optimization_algorithm_element = root_element->FirstChildElement("OptimizationAlgorithm");

    if(optimization_algorithm_element)
    {
        const tinyxml2::XMLElement* optimization_method_element = optimization_algorithm_element->FirstChildElement("OptimizationMethod");

        set_optimization_method(optimization_method_element->GetText());

        // Gradient descent

        const tinyxml2::XMLElement* gradient_descent_element = optimization_algorithm_element->FirstChildElement("GradientDescent");

        if(gradient_descent_element)
        {
            tinyxml2::XMLDocument gradient_descent_document;

            tinyxml2::XMLElement* gradient_descent_element_copy = gradient_descent_document.NewElement("GradientDescent");

            for(const tinyxml2::XMLNode* nodeFor = gradient_descent_element->FirstChild(); nodeFor; nodeFor = nodeFor->NextSibling())
            {
                tinyxml2::XMLNode* copy = nodeFor->DeepClone(&gradient_descent_document);
                gradient_descent_element_copy->InsertEndChild(copy);
            }

            gradient_descent_document.InsertEndChild(gradient_descent_element_copy);

            gradient_descent.from_XML(gradient_descent_document);
        }

        // Conjugate gradient

        const tinyxml2::XMLElement* conjugate_gradient_element = optimization_algorithm_element->FirstChildElement("ConjugateGradient");

        if(conjugate_gradient_element)
        {
            tinyxml2::XMLDocument conjugate_gradient_document;

            tinyxml2::XMLElement* conjugate_gradient_element_copy = conjugate_gradient_document.NewElement("ConjugateGradient");

            for(const tinyxml2::XMLNode* nodeFor = conjugate_gradient_element->FirstChild(); nodeFor; nodeFor = nodeFor->NextSibling())
            {
                tinyxml2::XMLNode* copy = nodeFor->DeepClone(&conjugate_gradient_document);
                conjugate_gradient_element_copy->InsertEndChild(copy);
            }

            conjugate_gradient_document.InsertEndChild(conjugate_gradient_element_copy);

            conjugate_gradient.from_XML(conjugate_gradient_document);
        }

        // Stochastic gradient descent

        const tinyxml2::XMLElement* stochastic_gradient_descent_element = optimization_algorithm_element->FirstChildElement("StochasticGradientDescent");

        if(stochastic_gradient_descent_element)
        {
            tinyxml2::XMLDocument stochastic_gradient_descent_document;

            tinyxml2::XMLElement* stochastic_gradient_descent_element_copy = stochastic_gradient_descent_document.NewElement("StochasticGradientDescent");

            for(const tinyxml2::XMLNode* nodeFor = stochastic_gradient_descent_element->FirstChild(); nodeFor; nodeFor = nodeFor->NextSibling())
            {
                tinyxml2::XMLNode* copy = nodeFor->DeepClone(&stochastic_gradient_descent_document);
                stochastic_gradient_descent_element_copy->InsertEndChild(copy);
            }

            stochastic_gradient_descent_document.InsertEndChild(stochastic_gradient_descent_element_copy);

            stochastic_gradient_descent.from_XML(stochastic_gradient_descent_document);
        }

        // Adaptive moment estimation

        const tinyxml2::XMLElement* adaptive_moment_estimation_element = optimization_algorithm_element->FirstChildElement("AdaptiveMomentEstimation");

        if(adaptive_moment_estimation_element)
        {
            tinyxml2::XMLDocument adaptive_moment_estimation_document;

            tinyxml2::XMLElement* adaptive_moment_estimation_element_copy = adaptive_moment_estimation_document.NewElement("AdaptiveMomentEstimation");

            for(const tinyxml2::XMLNode* nodeFor = adaptive_moment_estimation_element->FirstChild(); nodeFor; nodeFor = nodeFor->NextSibling())
            {
                tinyxml2::XMLNode* copy = nodeFor->DeepClone(&adaptive_moment_estimation_document);
                adaptive_moment_estimation_element_copy->InsertEndChild(copy);
            }

            adaptive_moment_estimation_document.InsertEndChild(adaptive_moment_estimation_element_copy);

            adaptive_moment_estimation.from_XML(adaptive_moment_estimation_document);
        }

        // Quasi-Newton method

        const tinyxml2::XMLElement* quasi_Newton_method_element = optimization_algorithm_element->FirstChildElement("QuasiNewtonMethod");

        if(quasi_Newton_method_element)
        {
            tinyxml2::XMLDocument quasi_Newton_document;

            tinyxml2::XMLElement* quasi_newton_method_element_copy = quasi_Newton_document.NewElement("QuasiNewtonMethod");

            for(const tinyxml2::XMLNode* nodeFor = quasi_Newton_method_element->FirstChild(); nodeFor; nodeFor = nodeFor->NextSibling())
            {
                tinyxml2::XMLNode* copy = nodeFor->DeepClone(&quasi_Newton_document);
                quasi_newton_method_element_copy->InsertEndChild(copy);
            }

            quasi_Newton_document.InsertEndChild(quasi_newton_method_element_copy);

            quasi_Newton_method.from_XML(quasi_Newton_document);
        }

        // Levenberg-Marquardt

        const tinyxml2::XMLElement* Levenberg_Marquardt_element = optimization_algorithm_element->FirstChildElement("LevenbergMarquardt");

        if(Levenberg_Marquardt_element)
        {
            tinyxml2::XMLDocument Levenberg_Marquardt_document;

            tinyxml2::XMLElement* levenberg_marquardt_algorithm_element_copy = Levenberg_Marquardt_document.NewElement("LevenbergMarquardt");

            for(const tinyxml2::XMLNode* nodeFor = Levenberg_Marquardt_element->FirstChild(); nodeFor; nodeFor = nodeFor->NextSibling())
            {
                tinyxml2::XMLNode* copy = nodeFor->DeepClone(&Levenberg_Marquardt_document);
                levenberg_marquardt_algorithm_element_copy->InsertEndChild(copy);
            }

            Levenberg_Marquardt_document.InsertEndChild(levenberg_marquardt_algorithm_element_copy);

            Levenberg_Marquardt_algorithm.from_XML(Levenberg_Marquardt_document);
        }
    }

    // Display
    {
        const tinyxml2::XMLElement* element = root_element->FirstChildElement("Display");

        if(element)
        {
            const string new_display = element->GetText();

            try
            {
                set_display(new_display != "0");
            }
            catch(const logic_error& e)
            {
                cerr << e.what() << endl;
            }
        }
    }
}


/// Saves the training strategy to an XML-type file.
/// @param file_name Name of the training strategy XML-type file.

void TrainingStrategy::save(const string& file_name) const
{
    FILE* file = fopen(file_name.c_str(), "w");

    if(file)
    {
        tinyxml2::XMLPrinter printer(file);

        write_XML(printer);

        fclose(file);
    }
}


/// Loads the members of this training strategy object from an XML-type file.
/// @param file_name Name of the training strategy XML-type file.

void TrainingStrategy::load(const string& file_name)
{
    set_default();

    tinyxml2::XMLDocument document;

    if(document.LoadFile(file_name.c_str()))
    {
        ostringstream buffer;

        buffer << "OpenNN Exception: TrainingStrategy class.\n"
               << "void load(const string&) method.\n"
               << "Cannot load XML file " << file_name << ".\n";

        throw logic_error(buffer.str());
    }

    from_XML(document);
}
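
// Round-trip sketch for the two methods above (the file name is illustrative):
//
//     training_strategy.save("training_strategy.xml");
//
//     TrainingStrategy new_training_strategy;
//     new_training_strategy.load("training_strategy.xml");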

}

// OpenNN: Open Neural Networks Library.
// Copyright(C) 2005-2021 Artificial Intelligence Techniques, SL.
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 2.1 of the License, or any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA