61 void fill(
const vector<Index>& sample_indices,
62 const vector<Index>& input_feature_indices,
63 const vector<Index>& decoder_feature_indices,
64 const vector<Index>& target_feature_indices,
65 bool augment =
false);
static Configuration & instance()
Definition configuration.h:113
Base data container with samples, variables and per-variable metadata.
Definition dataset.h:108
Declares the Dataset class and the SampleRole enum.
Definition adaptive_moment_estimation.h:19
void * cudaStream_t
Definition neural_network.h:71
@ CUDA
Definition configuration.h:16
const vector< TensorView > & get_inputs() const
Returns the input TensorViews for the active device.
Definition batch.h:71
Index num_target_features
Total number of target features.
Definition batch.h:150
Index num_decoder_features
Total number of decoder features.
Definition batch.h:148
const TensorView & get_targets() const
Returns the target TensorView for the active device.
Definition batch.h:83
float * targets_host
Pinned host pointer used to stage targets into the device buffer.
Definition batch.h:157
Index inputs_host_allocated_size
Allocation capacity of inputs_host, in floats.
Definition batch.h:160
int target_contiguous
Stride hint passed to Dataset::fill_targets(); -1 to ignore.
Definition batch.h:126
TensorView target_view_host_cache
Host-side target TensorView.
Definition batch.h:138
Shape input_shape
Shape of input (per-sample dimensions).
Definition batch.h:109
Batch(const Index samples_number=0, const Dataset *dataset=nullptr)
Constructs a batch sized for a given dataset.
void fill(const vector< Index > &sample_indices, const vector< Index > &input_feature_indices, const vector< Index > &decoder_feature_indices, const vector< Index > &target_feature_indices, bool augment=false)
Fills the host-side buffers from the bound dataset.
vector< TensorView > input_views_host_cache
Host-side input TensorViews (one per input feature group).
Definition batch.h:136
Index get_samples_number() const
Number of samples currently held in the batch.
Buffer decoder
Owning storage for the decoder tensor (encoder-decoder models).
Definition batch.h:112
Buffer inputs_fp32_staging
Device-resident FP32 staging buffer used during mixed-precision uploads.
Definition batch.h:167
int decoder_contiguous
Stride hint passed to Dataset::fill_inputs() for the decoder; -1 to ignore.
Definition batch.h:124
int input_contiguous
Stride hint passed to Dataset::fill_inputs(); -1 to ignore.
Definition batch.h:122
vector< TensorView > input_views_cache
Device-side input TensorViews; populated only on CUDA mode.
Definition batch.h:141
Shape decoder_shape
Shape of decoder.
Definition batch.h:114
void copy_device_async(const Index sample_count, cudaStream_t stream)
Asynchronously copies the host buffers to the device on stream.
TensorView target_view_cache
Device-side target TensorView; populated only on CUDA mode.
Definition batch.h:143
Buffer input
Owning storage for the input tensor (host or device).
Definition batch.h:107
Index samples_number
Number of samples currently held in the batch.
Definition batch.h:101
const Dataset * dataset
Dataset whose shapes determined the batch buffer sizes; not owned.
Definition batch.h:104
Index decoder_host_allocated_size
Allocation capacity of decoder_host, in floats.
Definition batch.h:162
Index num_input_features
Total number of input features across all input feature groups.
Definition batch.h:146
Index targets_host_allocated_size
Allocation capacity of targets_host, in floats.
Definition batch.h:164
Shape target_shape
Shape of target.
Definition batch.h:119
float * inputs_host
Pinned host pointer used to stage inputs into the device buffer.
Definition batch.h:153
void print() const
Prints a human-readable summary of the batch buffers to stdout.
float * decoder_host
Pinned host pointer used to stage decoder inputs into the device buffer.
Definition batch.h:155
Buffer target
Owning storage for the target tensor.
Definition batch.h:117
void set(const Index samples_number=0, const Dataset *dataset=nullptr)
(Re)allocates buffers for a given batch size and dataset.
bool is_empty() const
Whether the batch holds zero samples.
Definition tensor_utilities.h:144
Definition tensor_utilities.h:46
Definition tensor_utilities.h:236