Definition adaptive_moment_estimation.h:19
void cross_entropy_gradient(const TensorView &input, const TensorView &target, TensorView &input_delta)
void categorical_cross_entropy(const TensorView &input, const TensorView &target, float &error, float *workspace_device)
void minkowski_error(const TensorView &input, const TensorView &target, float power, float &error, float *workspace_device)
void l2_regularization_gradient(const TensorView &parameters, float lambda, TensorView &gradient)
void mean_squared_error(const TensorView &input, const TensorView &target, float &error, float *workspace_device)
void cross_entropy_3d(const TensorView &input, const TensorView &target, float &error, Index &active_tokens_out, Index &correct_tokens_out, float *errors_device=nullptr)
void weighted_squared_error(const TensorView &input, const TensorView &target, float pos_w, float neg_w, float &error, float *workspace_device)
void l1_regularization_gradient(const TensorView &parameters, float lambda, TensorView &gradient)
void binary_cross_entropy(const TensorView &input, const TensorView &target, float &error, float *workspace_device)
void minkowski_error_gradient(const TensorView &input, const TensorView &target, float power, TensorView &input_delta)
void weighted_squared_error_gradient(const TensorView &input, const TensorView &target, float pos_w, float neg_w, float coefficient, TensorView &input_delta)
void l2_regularization(const TensorView &parameters, float lambda, float &penalty)
void mean_squared_error_gradient(const TensorView &input, const TensorView &target, TensorView &input_delta)
void cross_entropy_3d_gradient(const TensorView &input, const TensorView &target, TensorView &input_delta, Index active_tokens_count)
void normalized_squared_error(const TensorView &input, const TensorView &target, float coefficient, float &error, float *workspace_device)
void normalized_squared_error_gradient(const TensorView &input, const TensorView &target, float coefficient, TensorView &input_delta)
void l1_regularization(const TensorView &parameters, float lambda, float &penalty)
Definition tensor_utilities.h:236