OpenNN
Open-source neural networks library
adaptive_moment_estimation.h
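```cpp
// OpenNN: Open Neural Networks Library
// www.opennn.net
//
// A D A P T I V E   M O M E N T   E S T I M A T I O N
//
// Artificial Intelligence Techniques SL
// artelnics@artelnics.com

#pragma once

#include "optimizer.h"

namespace opennn
{

struct BackPropagation;

class AdaptiveMomentEstimation : public Optimizer
{

public:

    /// Slot indices into the optimizer scratch buffer (m_t and v_t moments).
    enum DataSlot { GradientMoment, SquareGradientMoment };

    /// Constructs Adam optionally bound to a Loss instance.
    AdaptiveMomentEstimation(Loss* = nullptr);

    /// Returns the number of training samples seen by the bound dataset.
    Index get_samples_number() const;

    /// Sets the minibatch size used by train().
    void set_batch_size(const Index);

    /// Resets all hyperparameters (learning rate, betas, stopping criteria) to library defaults.
    void set_default();

    /// Sets the base learning rate alpha.
    void set_learning_rate(const float);
    /// Sets the first-moment decay rate beta_1.
    void set_beta_1(const float);
    /// Sets the second-moment decay rate beta_2.
    void set_beta_2(const float);

    /// Runs the Adam training loop and returns the recorded error history.
    TrainingResults train() override;

    /// Applies one Adam update to the network parameters using the gradient in back_propagation.
    void update_parameters(BackPropagation&, OptimizerData&) const;

    /// Restores hyperparameters from a JSON document.
    void from_JSON(const JsonDocument&) override;

    /// Serializes hyperparameters to JSON.
    void to_JSON(JsonWriter&) const override;

private:

    float learning_rate = 0.001f;

    float beta_1 = 0.9f;

    float beta_2 = 0.999f;

    Index batch_size = 0;
};

}
```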
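The private defaults above (learning_rate = 0.001, beta_1 = 0.9, beta_2 = 0.999) are the hyperparameters of the standard Adam update of Kingma and Ba, restated below for reference. This is the textbook formulation, not necessarily the exact arithmetic inside update_parameters: the header does not show where the small constant ε enters or how bias correction is applied, and it declares no ε at all. The moments m_t and v_t presumably correspond to the GradientMoment and SquareGradientMoment slots of DataSlot.

$$
\begin{aligned}
g_t &= \nabla_\theta L(\theta_{t-1}) \\
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 \\
\hat m_t &= \frac{m_t}{1-\beta_1^t}, \qquad \hat v_t = \frac{v_t}{1-\beta_2^t} \\
\theta_t &= \theta_{t-1} - \alpha\, \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}
\end{aligned}
$$

As a worked first step with the defaults, starting from m_0 = v_0 = 0 at t = 1:

$$
\hat m_1 = \frac{0.1\, g_1}{0.1} = g_1, \qquad
\hat v_1 = \frac{0.001\, g_1^2}{0.001} = g_1^2, \qquad
\theta_1 - \theta_0 = -\alpha\, \frac{g_1}{|g_1| + \epsilon} \approx -0.001\,\operatorname{sign}(g_1)
$$

So bias correction makes the first update move each parameter by roughly α, independent of the gradient's scale.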
TrainingResults train() override
Runs the Adam training loop and returns the recorded error history.
void update_parameters(BackPropagation &, OptimizerData &) const
Applies one Adam update to the network parameters, using the gradient held in the BackPropagation workspace and the moment estimates held in OptimizerData.
void set_beta_2(const float)
Sets the second-moment decay rate beta_2.
enum DataSlot
Slot indices into the optimizer scratch buffer (m_t and v_t moments).
Enumerators: GradientMoment, SquareGradientMoment.
Definition adaptive_moment_estimation.h:25
AdaptiveMomentEstimation(Loss* = nullptr)
Constructs Adam optionally bound to a Loss instance.
Index get_samples_number() const
Returns the number of training samples seen by the bound dataset.
void set_beta_1(const float)
Sets the first-moment decay rate beta_1.
void set_learning_rate(const float)
Sets the base learning rate alpha.
void set_batch_size(const Index)
Sets the minibatch size used by train().
void set_default()
Resets all hyperparameters (learning rate, betas, stopping criteria) to library defaults.
void to_JSON(JsonWriter &) const override
Serializes hyperparameters to JSON.
void from_JSON(const JsonDocument &) override
Restores hyperparameters from a JSON document.
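The per-parameter arithmetic that update_parameters performs can be sketched on a plain parameter vector. This is a minimal sketch assuming the standard Adam rule, not OpenNN's implementation: AdamState and adam_step are hypothetical names, the two moment vectors stand in for the GradientMoment and SquareGradientMoment slots of OptimizerData, and epsilon = 1e-7 is an assumed value because the header declares none.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the optimizer scratch state; the real
// OptimizerData layout is not shown in this header.
struct AdamState {
    std::vector<float> gradient_moment;         // m_t (cf. DataSlot::GradientMoment)
    std::vector<float> square_gradient_moment;  // v_t (cf. DataSlot::SquareGradientMoment)
    int iteration = 0;                          // t, number of updates applied so far
};

// One Adam step on a flat parameter vector using the header's default
// hyperparameters. epsilon is an assumption, not a value from OpenNN.
void adam_step(std::vector<float>& parameters,
               const std::vector<float>& gradient,
               AdamState& state,
               float learning_rate = 0.001f,
               float beta_1 = 0.9f,
               float beta_2 = 0.999f,
               float epsilon = 1e-7f)
{
    // Lazily allocate zero-initialized moments on the first call.
    if (state.gradient_moment.empty()) {
        state.gradient_moment.assign(parameters.size(), 0.0f);
        state.square_gradient_moment.assign(parameters.size(), 0.0f);
    }

    ++state.iteration;
    const float t = static_cast<float>(state.iteration);
    const float bias_correction_1 = 1.0f - std::pow(beta_1, t);
    const float bias_correction_2 = 1.0f - std::pow(beta_2, t);

    for (std::size_t i = 0; i < parameters.size(); ++i) {
        const float g = gradient[i];
        float& m = state.gradient_moment[i];
        float& v = state.square_gradient_moment[i];

        // Exponential moving averages of the gradient and its square.
        m = beta_1 * m + (1.0f - beta_1) * g;
        v = beta_2 * v + (1.0f - beta_2) * g * g;

        // Bias-corrected estimates, then the scaled step.
        const float m_hat = m / bias_correction_1;
        const float v_hat = v / bias_correction_2;
        parameters[i] -= learning_rate * m_hat / (std::sqrt(v_hat) + epsilon);
    }
}
```

For example, a first call with parameters {1.0, -2.0} and gradient {0.5, -4.0} moves both parameters by about 0.001 downhill, to roughly {0.999, -1.999}, even though the two gradients differ by a factor of eight.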
JsonWriter, JsonDocument (parameter types of to_JSON() and from_JSON())
Definitions json.h:72 and json.h:85
Loss
Unified loss container supporting MSE, cross-entropy, Minkowski, weighted, and regularized variants.
Definition loss.h:24
Optimizer(Loss* = nullptr)
Constructs an optimizer optionally bound to a Loss instance.
namespace opennn
Definition adaptive_moment_estimation.h:14
BackPropagation
Workspace holding parameter gradients and per-layer deltas during a backward pass.
Definition back_propagation.h:21
OptimizerData
Per-optimizer scratch state (moments, directions, iteration counter) backing the update step.
Definition optimizer.h:182
TrainingResults
History and final metrics produced by a training run.
Definition optimizer.h:204
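Because train() is only declared here, the following is a sketch under stated assumptions of how a minibatch Adam loop can produce a per-epoch error history like the one TrainingResults records. It fits a toy line y = w*x + b with mean squared error. ToyTrainingResults, mean_squared_error, and train_linear_adam are hypothetical names; batches are taken in order without shuffling; treating batch_size = 0 as "whole dataset" is this sketch's own choice (the header does not say what 0 means); and epsilon = 1e-7 is an assumed value.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for TrainingResults: only a per-epoch error history.
struct ToyTrainingResults {
    std::vector<float> training_error_history;
};

// Mean squared error of the model y = w*x + b over the whole dataset.
float mean_squared_error(const std::vector<float>& x, const std::vector<float>& y,
                         float w, float b)
{
    float sum = 0.0f;
    for (std::size_t i = 0; i < x.size(); ++i) {
        const float r = w * x[i] + b - y[i];
        sum += r * r;
    }
    return sum / static_cast<float>(x.size());
}

// Sketch of a minibatch Adam loop; not OpenNN's train().
ToyTrainingResults train_linear_adam(const std::vector<float>& x,
                                     const std::vector<float>& y,
                                     std::size_t batch_size,
                                     int epochs,
                                     float learning_rate = 0.001f,
                                     float beta_1 = 0.9f,
                                     float beta_2 = 0.999f,
                                     float epsilon = 1e-7f)
{
    // Sketch's own convention: 0 means one batch holding every sample.
    if (batch_size == 0) batch_size = x.size();

    float params[2] = {0.0f, 0.0f};  // w, b
    float m[2] = {0.0f, 0.0f};       // first-moment estimates
    float v[2] = {0.0f, 0.0f};       // second-moment estimates
    int t = 0;                       // update counter used for bias correction
    ToyTrainingResults results;

    for (int epoch = 0; epoch < epochs; ++epoch) {
        // Record the full-dataset error before this epoch's updates.
        results.training_error_history.push_back(
            mean_squared_error(x, y, params[0], params[1]));

        for (std::size_t start = 0; start < x.size(); start += batch_size) {
            const std::size_t end = std::min(start + batch_size, x.size());
            const float count = static_cast<float>(end - start);

            // Gradient of the minibatch MSE with respect to w and b.
            float grad[2] = {0.0f, 0.0f};
            for (std::size_t i = start; i < end; ++i) {
                const float r = params[0] * x[i] + params[1] - y[i];
                grad[0] += 2.0f * r * x[i] / count;
                grad[1] += 2.0f * r / count;
            }

            // One bias-corrected Adam update per minibatch.
            ++t;
            const float c1 = 1.0f - std::pow(beta_1, static_cast<float>(t));
            const float c2 = 1.0f - std::pow(beta_2, static_cast<float>(t));
            for (int k = 0; k < 2; ++k) {
                m[k] = beta_1 * m[k] + (1.0f - beta_1) * grad[k];
                v[k] = beta_2 * v[k] + (1.0f - beta_2) * grad[k] * grad[k];
                params[k] -= learning_rate * (m[k] / c1) / (std::sqrt(v[k] / c2) + epsilon);
            }
        }
    }
    return results;
}
```

For instance, train_linear_adam({0.0f, 1.0f, 2.0f, 3.0f}, {1.0f, 3.0f, 5.0f, 7.0f}, 2, 500, 0.05f) performs two updates per epoch and records 500 error values, the first being 21 (the error at w = b = 0). A real implementation would normally shuffle samples between epochs; this sketch keeps a fixed order so its output is deterministic.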