TrainingStrategy class

This is the documentation for the methods of the Python TrainingStrategy class in the OpenNN Python module.

This class represents the concept of a training strategy for a neural network.

Initialization methods

  • TrainingStrategy()

    Default initialization method. It creates a training strategy object that is not associated with any loss index object. It also constructs the main optimization algorithm object.

  • TrainingStrategy(neural_network, data_set)

    Neural network and data set initialization method. It creates a training strategy object associated with the given NeuralNetwork and DataSet objects.

    • neural_network NeuralNetwork object.
    • data_set DataSet object.

  • TrainingStrategy(file_name)

    File initialization method. It creates a training strategy object associated with a loss index object. It also loads the members of this object from an XML file.

    • file_name Name of training strategy XML file.

General methods

  • set_loss_method(new_loss_method)

    Selects the loss function used to train the neural network.

    • new_loss_method New loss method to use.
      • SUM_SQUARED_ERROR
      • MEAN_SQUARED_ERROR
      • NORMALIZED_SQUARED_ERROR
      • MINKOWSKI_ERROR
      • WEIGHTED_SQUARED_ERROR
      • CROSS_ENTROPY_ERROR

  • set_training_method(new_training_method)

    Sets a new main optimization algorithm from a string containing the type.

    • new_training_method String with the type of main optimization algorithm.
      • GRADIENT_DESCENT
      • CONJUGATE_GRADIENT
      • QUASI_NEWTON_METHOD
      • LEVENBERG_MARQUARDT_ALGORITHM
      • STOCHASTIC_GRADIENT_DESCENT
      • ADAPTIVE_MOMENT_ESTIMATION
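
    Because the optimization algorithm is selected by a string, a misspelled name can be caught before it is passed on. The following is a minimal sketch of such a check, assuming the names listed above are accepted as plain strings. The helper check_training_method is hypothetical and not part of OpenNN.

```python
# Names copied from the list above. The helper itself is a hypothetical
# convenience function and is not provided by the OpenNN module.
TRAINING_METHODS = (
    "GRADIENT_DESCENT",
    "CONJUGATE_GRADIENT",
    "QUASI_NEWTON_METHOD",
    "LEVENBERG_MARQUARDT_ALGORITHM",
    "STOCHASTIC_GRADIENT_DESCENT",
    "ADAPTIVE_MOMENT_ESTIMATION",
)


def check_training_method(name):
    """Return name unchanged if it is a documented training method.

    Raises ValueError for any other value.
    """
    if name not in TRAINING_METHODS:
        raise ValueError(f"unknown training method: {name!r}")
    return name
```

    With a training strategy object in a variable named training_strategy (an assumed name), the call would read training_strategy.set_training_method(check_training_method("QUASI_NEWTON_METHOD")).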

  • train()

    Main method of this class. It optimizes the loss index of the neural network with the selected optimization algorithm and returns a structure with the training results.
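
    A typical sequence is to select the loss method, select the training method, and then call train(). The sketch below wraps those three documented calls in a hypothetical helper, configure_and_train. It assumes the training strategy object has already been constructed and that both setters accept the names listed on this page as arguments. How the module is imported is not covered here.

```python
def configure_and_train(training_strategy, loss_method, training_method):
    """Configure an existing training strategy object and run training.

    Only methods documented on this page are called. Importing the module
    and constructing the object are left to the caller.
    """
    training_strategy.set_loss_method(loss_method)
    training_strategy.set_training_method(training_method)
    # train() returns a structure with the training results. Its fields are
    # not described on this page, so it is returned unchanged.
    return training_strategy.train()
```

    The helper relies only on the method names, so it can be exercised with any object that exposes set_loss_method, set_training_method and train.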

  • get_gradient_descent()

    Returns a pointer to the gradient descent main algorithm. It also throws an exception if that pointer is nullptr.

  • get_conjugate_gradient()

    Returns a pointer to the conjugate gradient main algorithm. It also throws an exception if that pointer is nullptr.

  • get_quasi_newton_method()

    Returns a pointer to the quasi-Newton method main algorithm. It also throws an exception if that pointer is nullptr.

  • get_stochastic_gradient_descent()

    Returns a pointer to the stochastic gradient descent main algorithm. It also throws an exception if that pointer is nullptr.

  • get_levenberg_marquardt_algorithm_pointer()

    Returns a pointer to the Levenberg-Marquardt main algorithm. It also throws an exception if that pointer is nullptr.
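
    Since each getter above throws when its algorithm object does not exist, calling code may want to guard the call. The following is a minimal sketch with a hypothetical helper, get_optimizer_or_none. The exception type that reaches Python is not stated on this page, so the sketch assumes it derives from Exception.

```python
def get_optimizer_or_none(training_strategy, getter_name):
    """Call the named getter and return its result, or None if it raises.

    getter_name is one of the getter names documented above, for example
    "get_quasi_newton_method".
    """
    getter = getattr(training_strategy, getter_name)
    try:
        return getter()
    except Exception:
        # Assumption: the binding reports a missing algorithm object through
        # an exception derived from Exception.
        return None
```

    Returning None keeps the calling code simple, at the cost of hiding the original error message. Re-raising with extra context is an equally valid choice.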