29 const string& =
"sequence_pooling_layer");
70 Index sequence_length = 0;
71 Index input_features = 0;
77 enum Forward { Input, MaximalIndices, Output };
const string & get_label() const
Definition layer.h:112
void set_output_shape(const Shape &) override
Sets the output shape; subclasses override when the output is user-configurable.
Definition pooling_layer_3d.h:54
Shape get_input_shape() const override
Returns the input tensor shape (sequence_length, input_features).
Definition pooling_layer_3d.h:32
void set_pooling_method(PoolingMethod)
Sets the pooling method via enum.
void set_input_shape(const Shape &new_input_shape) override
Updates the layer for a new input shape, preserving pooling method and label.
Definition pooling_layer_3d.h:49
Pooling3d(const Shape &={0, 0}, const PoolingMethod &=PoolingMethod::MaxPooling, const string &="sequence_pooling_layer")
Constructs a sequence pooling layer.
Index get_input_features() const
Definition pooling_layer_3d.h:38
Index get_sequence_length() const
Definition pooling_layer_3d.h:37
vector< TensorSpec > get_forward_specs(Index batch_size) const override
Returns the tensor specifications used during forward propagation.
void set_pooling_method(const string &)
Sets the pooling method by name ("MaxPooling" or "AveragePooling").
void set(const Shape &, const PoolingMethod &, const string &)
Reconfigures the layer with a new input shape, pooling method and name.
void write_JSON_body(JsonWriter &) const override
Writes the layer configuration to a JSON writer.
void read_JSON_body(const Json *) override
Reads the layer configuration from a JSON node.
Shape get_output_shape() const override
Returns the output tensor shape after sequence-axis pooling.
PoolingMethod get_pooling_method() const
Definition pooling_layer_3d.h:40
Definition adaptive_moment_estimation.h:14
PoolingMethod
Pooling reduction method used by Pooling and Pooling3d layers.
Definition pooling_layer.h:19
@ MaxPooling
Definition pooling_layer.h:20
Sequence-wide 1D pooling over the embedding dimension (mean or max).
Definition operators.h:930
Fixed-capacity small-vector describing tensor dimensions (rank up to MaxRank).
Definition tensor_utilities.h:42