void set(const Index = 0, Loss* = nullptr);
void setup_delta_pool(const vector<vector<TensorSpec>>& backward_specs);
Unified loss container supporting MSE, cross-entropy, Minkowski, weighted, and regularized variants.
Definition loss.h:24
Container of layers forming a feed-forward neural network, with parameter storage and I/O.
Definition neural_network.h:20
Definition adaptive_moment_estimation.h:14
vector< vector< TensorView > > gradient_views
Definition back_propagation.h:41
Buffer gradient
Definition back_propagation.h:40
vector< vector< TensorView > > delta_views
Definition back_propagation.h:44
virtual ~BackPropagation()=default
const NeuralNetwork * neural_network
Definition back_propagation.h:38
BackPropagation(const Index=0, Loss *=nullptr)
Constructs a workspace for the given batch size and loss.
Index active_tokens_count
Definition back_propagation.h:64
Buffer delta_pool
Definition back_propagation.h:43
void print() const
Prints a human-readable summary of the workspace contents.
void accumulate_output_deltas(size_t layer_index)
Accumulates deltas from all consumer edges into the given layer's output delta.
const TensorView & get_output_delta() const
Returns the output delta of the network (gradient w.r.t. the final outputs).
float accuracy
Definition back_propagation.h:62
Index batch_size
Definition back_propagation.h:57
float loss_value
Definition back_propagation.h:63
vector< vector< pair< size_t, size_t > > > consumer_edges
Definition back_propagation.h:46
float error
Definition back_propagation.h:61
void set(const Index=0, Loss *=nullptr)
Reconfigures the workspace for a new batch size or loss; reuses allocations when possible.
TensorView & get_output_delta()
Returns the output delta of the network (gradient w.r.t. the final outputs).
Loss * loss
Definition back_propagation.h:59
Owning raw byte buffer that lives on CPU or CUDA memory, with aligned (re)allocation.
Definition tensor_utilities.h:166
Non-owning view over a tensor: pointer, shape, and data type with rich reshape helpers.
Definition tensor_utilities.h:293
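The description of accumulate_output_deltas() (summing the deltas arriving from every consumer edge into a layer's output delta) can be illustrated with a minimal, self-contained sketch. This is not the library's implementation. It makes several assumptions: each entry of consumer_edges is read as a (consumer layer index, input slot index) pair, deltas are stored as plain vector<float> rather than TensorView over a Buffer, and the names DeltaWorkspaceSketch, input_deltas, output_deltas, and make_example are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

using std::pair;
using std::size_t;
using std::vector;

// Hypothetical stand-in for the back-propagation workspace.
// Deltas are plain float vectors here; the real workspace uses TensorView.
struct DeltaWorkspaceSketch
{
    // input_deltas[consumer][slot]: delta a consumer layer produced for one of its inputs.
    vector<vector<vector<float>>> input_deltas;

    // output_deltas[layer]: delta with respect to that layer's output.
    vector<vector<float>> output_deltas;

    // consumer_edges[layer]: assumed to hold (consumer layer, input slot) pairs
    // identifying every place this layer's output is consumed.
    vector<vector<pair<size_t, size_t>>> consumer_edges;

    // Sums the deltas from all consumer edges into the layer's output delta.
    void accumulate_output_deltas(size_t layer_index)
    {
        assert(layer_index < consumer_edges.size());

        vector<float>& output_delta = output_deltas[layer_index];
        std::fill(output_delta.begin(), output_delta.end(), 0.0f);

        for (const pair<size_t, size_t>& edge : consumer_edges[layer_index])
        {
            const vector<float>& source = input_deltas[edge.first][edge.second];
            assert(source.size() == output_delta.size());

            for (size_t i = 0; i < output_delta.size(); ++i)
                output_delta[i] += source[i];
        }
    }
};

// Builds a three-layer example in which layer 0 feeds both layer 1 and layer 2.
DeltaWorkspaceSketch make_example()
{
    DeltaWorkspaceSketch workspace;

    workspace.input_deltas = {
        {},                 // layer 0 has no upstream layer in this example
        {{1.0f, 2.0f}},     // layer 1, input slot 0
        {{10.0f, 20.0f}}    // layer 2, input slot 0
    };

    workspace.output_deltas = {{0.0f, 0.0f}, {}, {}};

    workspace.consumer_edges = {
        {{1, 0}, {2, 0}},   // layer 0 is consumed by layers 1 and 2
        {},
        {}
    };

    return workspace;
}
```

Calling make_example() and then accumulate_output_deltas(0) leaves output_deltas[0] holding the element-wise sum of the two consumer deltas. Zeroing the destination before summing keeps the call idempotent, which is one reasonable choice for a workspace that is reused across batches; whether the actual class does the same is not stated on this page.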