OpenNN
Open-source neural networks library
Loading...
Searching...
No Matches
training_strategy.h
Go to the documentation of this file.
1// OpenNN: Open Neural Networks Library
2// www.opennn.net
3//
4// T R A I N I N G S T R A T E G Y C L A S S H E A D E R
5//
6// Artificial Intelligence Techniques SL
7// artelnics@artelnics.com
8
9#pragma once
10
11#include "loss.h"
12#include "optimizer.h"
13
14namespace opennn
15{
16
17class Loss;
18class Optimizer;
19
20struct TrainingResults;
21
24{
25
26public:
27
29 TrainingStrategy(NeuralNetwork* = nullptr, Dataset* = nullptr);
30
31 const Dataset* get_dataset() const { return dataset; }
32 Dataset* get_dataset() { return dataset; }
33
34 const NeuralNetwork* get_neural_network() const { return neural_network; }
35 NeuralNetwork* get_neural_network() { return neural_network; }
36
37 const Loss* get_loss() const { return loss.get(); }
38 Loss* get_loss() { return loss.get(); }
39
40 const Optimizer* get_optimization_algorithm() const { return optimizer.get(); }
41 Optimizer* get_optimization_algorithm() { return optimizer.get(); }
43 void set(NeuralNetwork* = nullptr, Dataset* = nullptr);
46
47 void set_dataset(Dataset* new_dataset) { dataset = new_dataset; }
48 void set_neural_network(NeuralNetwork* new_neural_network) { neural_network = new_neural_network; }
49
51 void set_loss(const string&);
53 void set_optimization_algorithm(const string&);
54
58 void from_JSON(const JsonDocument&);
60 void to_JSON(JsonWriter&) const;
61
63 void save(const filesystem::path&) const;
65 void load(const filesystem::path&);
66
67private:
68
69 void fix_forecasting();
70
71 Dataset* dataset = nullptr;
72
73 NeuralNetwork* neural_network = nullptr;
74
75 unique_ptr<Loss> loss;
76
77 unique_ptr<Optimizer> optimizer;
78};
79
80}
81
82// OpenNN: Open Neural Networks Library.
83// Copyright(C) 2005-2026 Artificial Intelligence Techniques, SL.
84// Licensed under the GNU Lesser General Public License v2.1 or later.
Abstract base class for OpenNN datasets, owning samples, variables, and metadata.
Definition dataset.h:61
Definition json.h:72
Definition json.h:85
Unified loss container supporting MSE, cross-entropy, Minkowski, weighted, and regularized variants.
Definition loss.h:24
Container of layers forming a feed-forward neural network, with parameter storage and I/O.
Definition neural_network.h:20
Abstract base class for training optimizers (Adam, SGD, Quasi-Newton, Levenberg-Marquardt).
Definition optimizer.h:31
void set_default()
Resets the loss and optimizer to their default types and hyperparameters.
const NeuralNetwork * get_neural_network() const
Definition training_strategy.h:34
const Optimizer * get_optimization_algorithm() const
Definition training_strategy.h:40
Optimizer * get_optimization_algorithm()
Definition training_strategy.h:41
void load(const filesystem::path &)
Loads the strategy configuration from a JSON file at the given path.
void from_JSON(const JsonDocument &)
Restores the full strategy (loss + optimizer configurations) from a JSON document.
void set_loss(const string &)
Replaces the current loss with one selected by name (e.g. "MeanSquaredError", "CrossEntropy").
void save(const filesystem::path &) const
Writes the strategy configuration to a JSON file at the given path.
TrainingResults train()
Runs the configured optimizer against the configured loss and returns the training history.
Loss * get_loss()
Definition training_strategy.h:38
const Dataset * get_dataset() const
Definition training_strategy.h:31
void to_JSON(JsonWriter &) const
Serializes the full strategy (loss + optimizer configurations) to JSON.
void set_dataset(Dataset *new_dataset)
Definition training_strategy.h:47
TrainingStrategy(NeuralNetwork *=nullptr, Dataset *=nullptr)
Constructs the strategy with default loss (MSE) and optimizer (Adam) bound to the given network and dataset.
void set(NeuralNetwork *=nullptr, Dataset *=nullptr)
Rebinds the strategy to a new network/dataset, resetting loss and optimizer to defaults.
Dataset * get_dataset()
Definition training_strategy.h:32
const Loss * get_loss() const
Definition training_strategy.h:37
NeuralNetwork * get_neural_network()
Definition training_strategy.h:35
void set_optimization_algorithm(const string &)
Replaces the current optimizer with one selected by name (e.g. "Adam", "SGD", "QuasiNewton").
void set_neural_network(NeuralNetwork *new_neural_network)
Definition training_strategy.h:48
Definition adaptive_moment_estimation.h:14
History and final metrics produced by a training run.
Definition optimizer.h:204