Definition adaptive_moment_estimation.h:14
void categorical_cross_entropy(const TensorView &input, const TensorView &target, float &error, float *workspace_device)
Computes the multi-class (categorical) cross-entropy between softmax probabilities and one-hot targets.
void minkowski_error(const TensorView &input, const TensorView &target, float power, float &error, float *workspace_device)
Computes the Minkowski error sum(|input - target|^power) for the given power exponent.
void minkowski_error_gradient(const TensorView &input, const TensorView &target, float power, const TensorView &input_delta)
Writes the Minkowski-error gradient with respect to the predictions into input_delta.
void mean_squared_error_gradient(const TensorView &input, const TensorView &target, const TensorView &input_delta)
Writes the MSE gradient with respect to the predictions into input_delta.
void mean_squared_error(const TensorView &input, const TensorView &target, float &error, float *workspace_device)
Computes the mean squared error between predictions and targets.
void normalized_squared_error_gradient(const TensorView &input, const TensorView &target, float coefficient, const TensorView &input_delta)
Writes the normalized-squared-error gradient with respect to the predictions into input_delta.
void cross_entropy_3d(const TensorView &input, const TensorView &target, float &error, Index &active_tokens_out, Index &correct_tokens_out, float *errors_device=nullptr)
Computes 3-D (sequence) cross-entropy used by transformer-style targets, ignoring padded positions.
void l2_regularization_gradient(const TensorView &parameters, float lambda, const TensorView &gradient)
Adds the L2 regularization gradient 2 * lambda * parameters into the gradient tensor.
void weighted_squared_error_gradient(const TensorView &input, const TensorView &target, float pos_w, float neg_w, float coefficient, const TensorView &input_delta)
Writes the gradient of the weighted squared error scaled by coefficient into input_delta.
void weighted_squared_error(const TensorView &input, const TensorView &target, float pos_w, float neg_w, float &error, float *workspace_device)
Computes the binary squared error weighted asymmetrically for positive and negative classes.
void cross_entropy_gradient(const TensorView &input, const TensorView &target, const TensorView &input_delta)
Writes the cross-entropy gradient with respect to the (pre-softmax/logit) predictions into input_delta.
void binary_cross_entropy(const TensorView &input, const TensorView &target, float &error, float *workspace_device)
Computes the binary cross-entropy between predicted probabilities and binary targets.
void cross_entropy_3d_gradient_device_count(const TensorView &input, const TensorView &target, const TensorView &input_delta, const float *active_tokens_count_device)
Variant of cross_entropy_3d_gradient that reads the active-token count from device memory.
void l2_regularization(const TensorView &parameters, float lambda, float &penalty)
Computes the L2 regularization penalty lambda * sum(parameters^2).
void l1_regularization_gradient(const TensorView &parameters, float lambda, const TensorView &gradient)
Adds the L1 regularization gradient lambda * sign(parameters) into the gradient tensor.
void normalized_squared_error(const TensorView &input, const TensorView &target, float coefficient, float &error, float *workspace_device)
Computes the squared error normalized by a dataset-level coefficient.
void cross_entropy_3d_gradient(const TensorView &input, const TensorView &target, const TensorView &input_delta, Index active_tokens_count)
Writes the 3-D cross-entropy gradient into input_delta, normalizing by the host-side active-token count.
void l1_regularization(const TensorView &parameters, float lambda, float &penalty)
Computes the L1 regularization penalty lambda * sum(|parameters|).
Non-owning view over a tensor: pointer, shape, and data type with rich reshape helpers.
Definition tensor_utilities.h:293