neon.callbacks.callbacks.Callbacks

class neon.callbacks.callbacks.Callbacks(model, train_set=None, output_file=None, eval_freq=None, progress_bar=True, save_path=None, serialize=0, history=1, model_file=None, eval_set=None, metric=None, log_token=None, multicost=False)

Bases: neon.NervanaObject

Container class for storing and iterating over callbacks.

callbacks

list – Ordered set of Callback objects to be run.

__init__(model, train_set=None, output_file=None, eval_freq=None, progress_bar=True, save_path=None, serialize=0, history=1, model_file=None, eval_set=None, metric=None, log_token=None, multicost=False)

Create a callbacks container with the default callbacks.

Parameters:
  • model (Model) – the model object
  • output_file (string, optional) – path to save callback data to
  • eval_freq (int, optional) – how often (in epochs) to run evaluation
  • progress_bar (bool) – control whether a progress bar callback is created. Defaults to True.
  • save_path (string) – file path to save model snapshots (default: None)
  • serialize (int) – serialize model every N epochs (default: 0)
  • history (int) – number of checkpoint files to retain (default: 1)
  • model_file (string, optional) – file to load weights (serialized model) from
  • eval_set (NervanaDataIterator, optional) – the dataset upon which to evaluate loss or metric
  • metric (Metric, optional) – metric to evaluate
  • multicost (bool, optional) – use the multicost callback. Defaults to False.
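The serialize and history parameters work together: one controls how often a snapshot is written, the other how many snapshots are kept. The toy function below (checkpoint_schedule is a hypothetical helper, not part of neon) assumes a snapshot is written after every serialize-th epoch, counting epochs from 0, and that only the newest history snapshots survive. neon's exact epoch indexing and file naming may differ.

```python
from collections import deque


def checkpoint_schedule(num_epochs, serialize, history):
    """Return (epochs_written, epochs_kept) for a toy checkpoint policy.

    Hypothetical helper. Assumes a snapshot is written after every
    `serialize`-th epoch (0-based indices) and that only the `history`
    most recent snapshots are retained. serialize=0 means "never".
    """
    if serialize <= 0:
        return [], []
    written = []
    kept = deque(maxlen=history)  # the oldest snapshot drops off when full
    for epoch in range(num_epochs):
        if (epoch + 1) % serialize == 0:
            written.append(epoch)
            kept.append(epoch)
    return written, list(kept)


written, kept = checkpoint_schedule(num_epochs=10, serialize=3, history=2)
print(written)  # [2, 5, 8]
print(kept)     # [5, 8]
```

With the default serialize=0 nothing is written, which matches the documented default of no periodic serialization.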

Methods

__init__(model[, train_set, output_file, …]) Create a callbacks container with the default callbacks.
add_callback(callback[, insert_pos]) Add a user-supplied callback.
add_deconv_callback(train_set, valid_set[, …]) Convenience function to create and add a deconvolution callback.
add_early_stop_callback(stop_func) Convenience function to create and add an early stopping callback.
add_hist_callback([plot_per_mini, filter_key]) Convenience function to create and add a histogram callback.
add_save_best_state_callback(path) Convenience function to create and add a save best state callback.
add_watch_ticker_callback(valid) Convenience function to create and add a watch ticker callback.
gen_class(pdict)
get_description() Serialize callback configuration.
load_callbacks(cdict, model[, data]) Load callbacks.
on_epoch_begin(epoch) Call all registered callbacks’ on_epoch_begin functions.
on_epoch_end(epoch) Call all registered callbacks’ on_epoch_end functions.
on_minibatch_begin(epoch, minibatch) Call all registered callbacks’ on_minibatch_begin functions.
on_minibatch_end(epoch, minibatch) Call all registered callbacks’ on_minibatch_end functions.
on_sigint_catch(epoch, minibatch) Callback to handle SIGINT events.
on_train_begin(epochs) Call all registered callbacks’ on_train_begin functions.
on_train_end() Call all registered callbacks’ on_train_end functions.
recursive_gen(pdict, key) Helper method to check whether the definition dictionary is defining a NervanaObject child.
serialize() Serialize callback configuration.
add_callback(callback, insert_pos=None)

Add a user-supplied callback. Since callbacks are run serially and share data, their order can matter. By default the callback is appended to the end of the list; if that is not sufficient, insert_pos controls where it is inserted.

Parameters:
  • callback (Callback) – callback object to be registered
  • insert_pos (int, optional) – position in the list to insert the callback. Defaults to None, meaning append.
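The effect of insert_pos can be pictured with a plain Python list. The sketch below uses a hypothetical ToyCallbacks class (not neon's code) and assumes insert_pos behaves like list.insert, while the default behaves like list.append. Strings stand in for real Callback objects.

```python
class ToyCallbacks:
    """Hypothetical stand-in for the container's ordering logic."""

    def __init__(self):
        self.callbacks = []  # callbacks run in this order

    def add_callback(self, callback, insert_pos=None):
        if insert_pos is None:
            self.callbacks.append(callback)  # default: run after all others
        else:
            self.callbacks.insert(insert_pos, callback)  # run at a chosen slot


cbs = ToyCallbacks()
cbs.add_callback("progress_bar")
cbs.add_callback("checkpoint")
cbs.add_callback("compute_metric", insert_pos=0)  # runs before the others
print(cbs.callbacks)  # ['compute_metric', 'progress_bar', 'checkpoint']
```

With the real container the call would look like callbacks.add_callback(my_callback, insert_pos=0), where my_callback is a hypothetical instance of your own Callback subclass.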
add_deconv_callback(train_set, valid_set, max_fm=16, dataset_pct=25)

Convenience function to create and add a deconvolution callback. The data can be used for visualization.

Parameters:
  • train_set (NervanaDataIterator) – the train dataset to use
  • valid_set (NervanaDataIterator) – the validation dataset to use
  • max_fm – (Default value = 16)
  • dataset_pct – (Default value = 25)
add_early_stop_callback(stop_func)

Convenience function to create and add an early stopping callback.

Parameters:
  • stop_func (function) – function to determine when to stop.
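A stop function carries some state between calls. The sketch below assumes the contract used by neon's EarlyStopCallback: stop_func(state, value) receives the previous state (None on the first call) and the latest validation cost, and returns (new_state, stop). make_patience_stop_func is a hypothetical helper that stops once the cost has failed to improve for patience consecutive evaluations.

```python
def make_patience_stop_func(patience):
    """Build a stop_func for add_early_stop_callback (sketch).

    Assumes stop_func(state, value) -> (new_state, stop), where `state`
    is None on the first call and lower `value` (validation cost) is better.
    """
    def stop_func(state, value):
        if state is None:
            return (value, 0), False      # first call: remember the cost
        best, bad_epochs = state
        if value < best:
            return (value, 0), False      # improvement: reset the counter
        bad_epochs += 1
        return (best, bad_epochs), bad_epochs >= patience

    return stop_func


stop = make_patience_stop_func(patience=2)
state = None
for cost in [0.9, 0.7, 0.75, 0.8]:
    state, should_stop = stop(state, cost)
print(should_stop)  # True
```

Under that assumption it would be registered with callbacks.add_early_stop_callback(make_patience_stop_func(patience=3)).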
add_hist_callback(plot_per_mini=False, filter_key=['W'])

Convenience function to create and add a histogram callback.

add_save_best_state_callback(path)

Convenience function to create and add a save best state callback.

Parameters:
  • path (string) – where to save the best model state.
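The idea behind a "save best state" callback can be shown without neon. The class below is a hypothetical sketch, not neon's implementation: it assumes "best" means the lowest validation cost seen so far, overwrites a single snapshot at path whenever that cost improves, and uses pickle purely for illustration.

```python
import os
import pickle
import tempfile


class SaveBestState:
    """Hypothetical 'save best state' callback; a sketch, not neon's code.

    Assumes "best" means the lowest validation cost seen so far.
    """

    def __init__(self, path):
        self.path = path
        self.best_cost = None  # no snapshot written yet

    def on_epoch_end(self, epoch, val_cost, state):
        """Overwrite the snapshot at self.path when val_cost improves."""
        if self.best_cost is None or val_cost < self.best_cost:
            self.best_cost = val_cost
            with open(self.path, "wb") as f:
                pickle.dump({"epoch": epoch, "state": state}, f)
            return True
        return False


path = os.path.join(tempfile.mkdtemp(), "best_state.pkl")
saver = SaveBestState(path)
saved = [saver.on_epoch_end(e, cost, {"w": e}) for e, cost in enumerate([0.5, 0.4, 0.6])]
with open(path, "rb") as f:
    best = pickle.load(f)
print(saved, best["epoch"])  # [True, True, False] 1
```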
add_watch_ticker_callback(valid)

Convenience function to create and add a watch ticker callback.

Parameters:
  • valid (dataset) – the validation set to use. For a ticker dataset, this can be the training set if desired.
be = None
classnm

Returns the class name.

gen_class(pdict)
get_description()

Serialize callback configuration.

classmethod load_callbacks(cdict, model, data=[])

Load callbacks.

modulenm

Returns the full module path.

on_epoch_begin(epoch)

Call all registered callbacks’ on_epoch_begin functions.

Parameters:
  • epoch (int) – index of epoch that is beginning
on_epoch_end(epoch)

Call all registered callbacks’ on_epoch_end functions.

Parameters:
  • epoch (int) – index of epoch that is ending
on_minibatch_begin(epoch, minibatch)

Call all registered callbacks’ on_minibatch_begin functions.

Parameters:
  • epoch (int) – index of current epoch
  • minibatch (int) – index of minibatch that is beginning
on_minibatch_end(epoch, minibatch)

Call all registered callbacks’ on_minibatch_end functions.

Parameters:
  • epoch (int) – index of current epoch
  • minibatch (int) – index of minibatch that is ending
on_sigint_catch(epoch, minibatch)

Callback to handle SIGINT events.

Parameters:
  • epoch (int) – index of current epoch
  • minibatch (int) – index of minibatch that is ending
on_train_begin(epochs)

Call all registered callbacks’ on_train_begin functions.

Parameters:
  • epochs (int) – total number of epochs
on_train_end()

Call all registered callbacks’ on_train_end functions.
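The on_* methods above are fan-out hooks: the container forwards each one to every registered callback in list order, and the training loop fires them at fixed points. The toy sketch below records the resulting call order; Recorder, ToyCallbacks, and fit are hypothetical names. It simplifies the signatures: individual neon Callback hooks are assumed to receive extra arguments (for example the model), which are omitted here.

```python
class Recorder:
    """Toy callback that records every on_* hook it receives (hypothetical)."""

    def __init__(self, name, log):
        self.name, self.log = name, log

    def __getattr__(self, hook):
        # Any missing on_* attribute resolves to a function that logs the call.
        if not hook.startswith("on_"):
            raise AttributeError(hook)
        return lambda *args: self.log.append((self.name, hook) + args)


class ToyCallbacks:
    """Hypothetical container: each hook is forwarded to all callbacks in order."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def _run(self, hook, *args):
        for cb in self.callbacks:
            getattr(cb, hook)(*args)

    def on_train_begin(self, epochs):
        self._run("on_train_begin", epochs)

    def on_epoch_begin(self, epoch):
        self._run("on_epoch_begin", epoch)

    def on_minibatch_begin(self, epoch, minibatch):
        self._run("on_minibatch_begin", epoch, minibatch)

    def on_minibatch_end(self, epoch, minibatch):
        self._run("on_minibatch_end", epoch, minibatch)

    def on_epoch_end(self, epoch):
        self._run("on_epoch_end", epoch)

    def on_train_end(self):
        self._run("on_train_end")


def fit(callbacks, epochs, minibatches):
    """Toy training loop showing where each hook fires (no real training)."""
    callbacks.on_train_begin(epochs)
    for epoch in range(epochs):
        callbacks.on_epoch_begin(epoch)
        for mb in range(minibatches):
            callbacks.on_minibatch_begin(epoch, mb)
            # forward pass, backward pass, and weight update would go here
            callbacks.on_minibatch_end(epoch, mb)
        callbacks.on_epoch_end(epoch)
    callbacks.on_train_end()


log = []
fit(ToyCallbacks([Recorder("a", log), Recorder("b", log)]), epochs=1, minibatches=1)
for entry in log:
    print(entry)
```

Every hook reaches callback "a" before "b", which is why the position chosen in add_callback matters when callbacks share data.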

recursive_gen(pdict, key)

Helper method that checks whether the definition dictionary defines a NervanaObject child. If so, it instantiates that object and replaces the dictionary element with the instance.
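A rough sketch of the behavior recursive_gen describes, assuming an object definition is a dictionary of the form {"type": ..., "config": {...}}. Linear, Constant, and the REGISTRY lookup are hypothetical stand-ins; neon presumably resolves the type from a module path rather than from a hard-coded table.

```python
class Linear:
    """Hypothetical stand-in for a NervanaObject subclass."""

    def __init__(self, nout, init=None):
        self.nout, self.init = nout, init


class Constant:
    """Hypothetical initializer class."""

    def __init__(self, val=0.0):
        self.val = val


# Hypothetical registry standing in for importing classes by module path.
REGISTRY = {"Linear": Linear, "Constant": Constant}


def is_object_def(value):
    # Assumed layout: an object is described as {"type": name, "config": {...}}.
    return isinstance(value, dict) and "type" in value and "config" in value


def recursive_gen(pdict, key):
    """Instantiate pdict[key] in place if it describes an object (sketch)."""
    value = pdict[key]
    if is_object_def(value):
        config = dict(value["config"])
        for sub_key in list(config):
            recursive_gen(config, sub_key)  # build nested objects first
        pdict[key] = REGISTRY[value["type"]](**config)


spec = {"layer": {"type": "Linear",
                  "config": {"nout": 10,
                             "init": {"type": "Constant", "config": {"val": 0.5}}}}}
recursive_gen(spec, "layer")
print(type(spec["layer"]).__name__, spec["layer"].init.val)  # Linear 0.5
```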

serialize()

Serialize callback configuration.