Java Code Examples for org.nd4j.linalg.learning.config.IUpdater#hasLearningRate()

The following examples show how to use org.nd4j.linalg.learning.config.IUpdater#hasLearningRate(). You can vote up the examples you like or vote down the ones you don't, and you can go to the original project or source file by following the links above each example. Related API usage is listed in the sidebar.
Example 1
Source File: NetworkUtils.java, from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Applies a new fixed learning rate, or a learning-rate schedule, to the updater of one layer of a
 * {@link MultiLayerNetwork}.
 * <p>
 * Only layers whose configuration is a {@link BaseLayer} are handled. For any other configuration
 * type this method returns immediately: the updater is not changed and no refresh happens.
 * The updater is modified only when it is non-null and reports {@code hasLearningRate() == true}.
 * A non-null {@code newLrSchedule} takes precedence, and the fixed rate is then passed as
 * {@code Double.NaN}. Otherwise {@code newLr} is applied with a {@code null} schedule.
 *
 * @param net            network containing the layer
 * @param layerNumber    index of the layer to modify
 * @param newLr          fixed learning rate, used only when {@code newLrSchedule} is null
 * @param newLrSchedule  learning-rate schedule; overrides {@code newLr} when non-null
 * @param refreshUpdater whether to call {@code refreshUpdater(net)} afterwards (only reached for
 *                       {@link BaseLayer} configurations)
 */
private static void setLearningRate(MultiLayerNetwork net, int layerNumber, double newLr, ISchedule newLrSchedule, boolean refreshUpdater) {
    Layer layerConf = net.getLayer(layerNumber).conf().getLayer();
    if (!(layerConf instanceof BaseLayer)) {
        // Non-BaseLayer configurations are skipped entirely: no updater change, no refresh
        return;
    }

    IUpdater updater = ((BaseLayer) layerConf).getIUpdater();
    if (updater != null && updater.hasLearningRate()) {
        // A non-null schedule wins; in that case the fixed rate is passed as NaN
        double fixedLr = (newLrSchedule == null) ? newLr : Double.NaN;
        updater.setLrAndSchedule(fixedLr, newLrSchedule);
    }

    // Updater blocks group parameters that share the same configuration, so changing the LR
    // (or its schedule) may require those blocks to be rebuilt.
    if (refreshUpdater) {
        refreshUpdater(net);
    }
}
 
Example 2
Source File: NetworkUtils.java, from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Returns the current learning rate of one layer of a {@link MultiLayerNetwork}.
 * <p>
 * The rate is queried from the layer's {@link IUpdater} using the network's current iteration and
 * epoch counts. {@code null} is returned when:
 * <ul>
 *   <li>the layer configuration is not a {@link BaseLayer},</li>
 *   <li>the layer has no updater, or its updater reports no learning rate, or</li>
 *   <li>the updater yields {@code NaN} for the current iteration/epoch.</li>
 * </ul>
 *
 * @param net         network containing the layer
 * @param layerNumber index of the layer whose learning rate is requested
 * @return the learning rate for the specified layer, or {@code null} if none is available
 */
public static Double getLearningRate(MultiLayerNetwork net, int layerNumber) {
    Layer layerConf = net.getLayer(layerNumber).conf().getLayer();
    int iteration = net.getIterationCount();
    int epochCount = net.getEpochCount();

    if (!(layerConf instanceof BaseLayer)) {
        return null;
    }

    IUpdater updater = ((BaseLayer) layerConf).getIUpdater();
    if (updater == null || !updater.hasLearningRate()) {
        return null;
    }

    double lr = updater.getLearningRate(iteration, epochCount);
    if (Double.isNaN(lr)) {
        // A NaN rate is reported to callers as "no learning rate"
        return null;
    }
    return lr;
}
 
Example 3
Source File: NetworkUtils.java, from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Applies a new fixed learning rate, or a learning-rate schedule, to the updater of one named
 * layer of a {@link ComputationGraph}.
 * <p>
 * Only layers whose configuration is a {@link BaseLayer} are handled. For any other configuration
 * type this method returns immediately: the updater is not changed and no refresh happens.
 * The updater is modified only when it is non-null and reports {@code hasLearningRate() == true}.
 * A non-null {@code newLrSchedule} takes precedence, and the fixed rate is then passed as
 * {@code Double.NaN}. Otherwise {@code newLr} is applied with a {@code null} schedule.
 *
 * @param net            graph containing the layer
 * @param layerName      name of the layer to modify
 * @param newLr          fixed learning rate, used only when {@code newLrSchedule} is null
 * @param newLrSchedule  learning-rate schedule; overrides {@code newLr} when non-null
 * @param refreshUpdater whether to call {@code refreshUpdater(net)} afterwards (only reached for
 *                       {@link BaseLayer} configurations)
 */
private static void setLearningRate(ComputationGraph net, String layerName, double newLr, ISchedule newLrSchedule, boolean refreshUpdater) {
    Layer layerConf = net.getLayer(layerName).conf().getLayer();
    if (!(layerConf instanceof BaseLayer)) {
        // Non-BaseLayer configurations are skipped entirely: no updater change, no refresh
        return;
    }

    IUpdater updater = ((BaseLayer) layerConf).getIUpdater();
    if (updater != null && updater.hasLearningRate()) {
        // A non-null schedule wins; in that case the fixed rate is passed as NaN
        double fixedLr = (newLrSchedule == null) ? newLr : Double.NaN;
        updater.setLrAndSchedule(fixedLr, newLrSchedule);
    }

    // Updater blocks group parameters that share the same configuration, so changing the LR
    // (or its schedule) may require those blocks to be rebuilt.
    if (refreshUpdater) {
        refreshUpdater(net);
    }
}
 
Example 4
Source File: NetworkUtils.java, from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Returns the current learning rate of one named layer of a {@link ComputationGraph}.
 * <p>
 * The rate is queried from the layer's {@link IUpdater} using the iteration and epoch counts
 * stored in the graph's configuration. {@code null} is returned when:
 * <ul>
 *   <li>the layer configuration is not a {@link BaseLayer},</li>
 *   <li>the layer has no updater, or its updater reports no learning rate, or</li>
 *   <li>the updater yields {@code NaN} for the current iteration/epoch.</li>
 * </ul>
 *
 * @param net       graph containing the layer
 * @param layerName name of the layer whose learning rate is requested
 * @return the learning rate for the specified layer, or {@code null} if none is available
 */
public static Double getLearningRate(ComputationGraph net, String layerName) {
    Layer layerConf = net.getLayer(layerName).conf().getLayer();
    int iteration = net.getConfiguration().getIterationCount();
    int epochCount = net.getConfiguration().getEpochCount();

    if (!(layerConf instanceof BaseLayer)) {
        return null;
    }

    IUpdater updater = ((BaseLayer) layerConf).getIUpdater();
    if (updater == null || !updater.hasLearningRate()) {
        return null;
    }

    double lr = updater.getLearningRate(iteration, epochCount);
    if (Double.isNaN(lr)) {
        // A NaN rate is reported to callers as "no learning rate"
        return null;
    }
    return lr;
}