callbacks

A set of callbacks useful for training a model.

bpreveal.callbacks.getCallbacks(earlyStop, outputPrefix, plateauPatience, heads, trainBatchGen, valBatchGen)

Return a set of callbacks for your model.

Parameters:
  • earlyStop (int) – The early-stopping-patience from the config file.

  • outputPrefix (str) – The output-prefix for the model, including directory.

  • plateauPatience (int) – The learning-rate-plateau-patience from the config file.

  • heads (list[dict]) – The heads list for your model, to which adaptive loss \({\lambda}\) tensors have been added.

  • trainBatchGen (H5BatchGenerator) – The batch generator for training. Just used to see how many batches there will be.

  • valBatchGen (H5BatchGenerator) – The batch generator for validation. Just used to see how many batches there will be.

Returns:

A list of Keras callbacks that you should use to train your model.

Return type:

list[tensorflow.keras.callbacks.Callback]

The returned callbacks are:

EarlyStopping

Stop training if the validation loss hasn’t improved for a while.

ModelCheckpoint

Write a checkpoint file every time the validation loss improves.

ReduceLROnPlateau

If validation loss hasn’t improved for a while, decrease the learning rate.

ApplyAdaptiveCountsLoss

Implement the adaptive counts loss algorithm.

class bpreveal.callbacks.DisplayCallback(*args, **kwargs)

Replaces the TensorFlow progress-bar logger with detailed output printed to stderr.

Parameters:
  • trainBatchGen – The training batch generator.

  • valBatchGen – The validation batch generator.

  • plateauCallback – The plateau callback, used to access the LR schedule.

  • earlyStopCallback – The EarlyStopping callback, used to see how long we have left.

  • adaptiveLossCallback – The adaptive loss callback, used to read \({\lambda}\) values.

epochNumber = 0

What is the currently-running epoch number?

batchNumber = 0

What is the currently-running training batch number?

printLocationsEpoch = {}

For a given data type, what row should it be printed on in the epoch pane? For example, “val_loss” might go on row 5.

printLocationsBatch = {}

What row should each data type go on in the batch pane?

multipliers = {}

For a given data type, what constant should it be multiplied by for display? This is used to weight profile losses by profile-loss-weight.

prevEpochLogs = None

The logs from the previous epoch.

lastBatchEndTime = 0

When did the last batch finish?

lastValBatchEndTime = 0

When did the last validation batch finish?

lastEpochEndTime = None

When did the last epoch finish?

lastEpochStartTime = None

When did the current epoch start?

numEpochs: int

What is the maximum number of training epochs?

curEpochWaitTime = None

How long between the end of the last epoch and the start of this one?

maxLen = 0

Of all the data types, what is the length of the longest name? Used to calculate column positions.

trainBeginTime: float

When did the whole training process start?

numBatches: int

How many training batches per epoch?

numValBatches: int

How many validation batches per epoch?

on_train_begin(logs=None)

Just loads in the total number of epochs.

Parameters:

logs (dict | None) –

on_epoch_begin(epoch, logs=None)

Just sets the timers up, so you can check how long an epoch took at the end.

Parameters:
  • epoch (int) –

  • logs (dict | None) –

formatStr(val)

Formats an object to be 10 characters wide.

If val is a tuple of two integers, it is formatted as a ratio.

Parameters:

val (str | int | float | tuple[int, int]) – The thing to format

Returns:

An 11 character wide formatted string.

Return type:

str
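A minimal sketch of the fixed-width formatting that formatStr describes, written as a hypothetical standalone function rather than the actual method. The right alignment, the four-significant-digit float format, and the width parameter are assumptions for illustration; a two-integer tuple such as (3, 10) is rendered as 3/10.

```python
from __future__ import annotations


def formatStrSketch(val: str | int | float | tuple[int, int], width: int = 10) -> str:
    """Hypothetical stand-in for formatStr: right-align val in a fixed-width field."""
    if isinstance(val, tuple):
        # A (numerator, denominator) pair is shown as a ratio, e.g. "3/10".
        numerator, denominator = val
        text = f"{numerator}/{denominator}"
    elif isinstance(val, float):
        # Assumed float style: four significant digits keeps most values short.
        text = f"{val:.4g}"
    else:
        text = str(val)
    # Values longer than width are not truncated, so the field can overflow.
    return text.rjust(width)


print(repr(formatStrSketch((3, 10))))   # '      3/10'
print(repr(formatStrSketch(0.123456)))  # '    0.1235'
```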

on_epoch_end(epoch, logs=None)

Writes out all the logs for this epoch and the last one at INFO logging level.

Parameters:
  • epoch (int) –

  • logs (dict | None) –

on_train_batch_end(batch, logs=None)

Write the loss info for the current batch at DEBUG level.

Parameters:
  • batch (int) –

  • logs (dict | None) –

on_test_batch_end(batch, logs=None)

Just emit a counter with the batch number at DEBUG level.

Parameters:
  • batch (int) –

  • logs (dict | None) –

class bpreveal.callbacks.ApplyAdaptiveCountsLoss(*args, **kwargs)

Implements the adaptive counts loss algorithm.

This updates the counts loss weights in your model on the fly, so that you can specify a target fraction of the loss that is due to counts, and the model will automatically update the weight. It currently does not update profile weights, so you can’t yet automatically balance the losses in a multitask model. (You can still manually specify them in the config JSON, of course!)

Parameters:
  • heads (list[dict]) – The heads list straight from the JSON config, except that each head may contain an “INTERNAL_counts-loss-weight-variable” entry. If present, that entry is the Keras Variable holding the head’s counts loss weight, and this callback updates it.

  • aggression (float) – A number from 0 to 1.

  • lrPlateauCallback – The learning rate plateau callback that your model is going to use.

  • earlyStopCallback – The early stopping callback that your model is going to use.

  • checkpointCallback – The checkpoint callback that your model is going to use.

aggression determines how aggressively the counts loss will be re-weighted. Lower values produce slower changes, and a value of 1.0 means that the old loss weight is completely discarded at each iteration. (If the newly calculated loss weight is ever off by more than a factor of two, it’s clamped to exactly twice (or half) the old weight, so that gross instability in the early stages doesn’t cause the model to explode.)

The three callbacks (lrPlateauCallback, earlyStopCallback, and checkpointCallback) are the other callbacks used in training, and the model must run them BEFORE this callback. This callback messes with their internal state (naughty, naughty!) because changing the loss weights could cause the model’s loss value to go up even though the model hasn’t gotten any worse.
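The re-weighting step described above can be sketched for a single head as follows. This is a hypothetical illustration, not the callback’s actual code: the formula for the ideal weight is simply the λ that would make counts exactly the target fraction of that head’s loss, and applying the factor-of-two clamp before blending with aggression is an assumed ordering.

```python
def updateCountsWeight(oldWeight: float, profileLoss: float, rawCountsLoss: float,
                       targetFraction: float, aggression: float) -> float:
    """One hypothetical re-weighting step for a single head's counts loss weight."""
    # Weight w that makes counts exactly targetFraction of the total (assumed form):
    #   w * C / (P + w * C) = f   =>   w = f * P / ((1 - f) * C)
    ideal = targetFraction * profileLoss / ((1.0 - targetFraction) * rawCountsLoss)
    # Keep the new value within a factor of two of the old weight.
    ideal = min(max(ideal, oldWeight / 2.0), oldWeight * 2.0)
    # aggression = 1.0 discards the old weight; smaller values move more slowly.
    return oldWeight + aggression * (ideal - oldWeight)


# Profile loss 30, raw counts loss 2, target: half of the loss from counts.
print(updateCountsWeight(10.0, 30.0, 2.0, 0.5, 1.0))   # 15.0
print(updateCountsWeight(10.0, 30.0, 2.0, 0.5, 0.5))   # 12.5
print(updateCountsWeight(10.0, 100.0, 1.0, 0.5, 1.0))  # 20.0 (clamped at 2x)
```

With a weight of 15 and a raw counts loss of 2, the weighted counts loss is 30, matching the profile loss, so counts account for half the total, as targeted.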

on_train_begin(logs=None)

Set up the initial guesses for \({\lambda}\).

Parameters:

logs (dict | None) – Ignored.

At the beginning of training, see which heads are using adaptive weights. For those heads, load in an initial guess for \({\lambda}\) based on the BPNet heuristic.

getLosses(epoch, headName)

Get what the losses actually were in a previous epoch using that epoch’s \({\lambda}\) values.

Parameters:
  • epoch (int) – The epoch for which you’d like losses.

  • headName (str) – The head for which you’d like losses.

Returns:

A tuple with four losses: (profile, counts, val_profile, val_counts).

Return type:

tuple[float, float, float, float]

What were the profile, counts, val_profile, and val_counts losses (in that order) at the given epoch? The counts losses returned by this function are weighted by the counts loss weight that was in use at that epoch. This method assumes that the training and validation loss names match certain regexes, which are based on layer names. If someone got really goofy with head names (e.g., one head named “x” and another named “profile_x”), this could get messed up.
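The head-name pitfall can be illustrated in a few lines. The log-key format <output>_<head>_loss and the output names profile and logcounts are assumptions for illustration, not necessarily bpreveal’s layer names. An unanchored re.search picks up keys belonging to other heads, while anchoring the whole key with re.fullmatch (and escaping the head name with re.escape) keeps each head separate.

```python
import re

# Hypothetical Keras log keys for two heads named "x" and "profile_x".
logs = {
    "profile_x_loss": 1.0, "logcounts_x_loss": 2.0,
    "profile_profile_x_loss": 3.0, "logcounts_profile_x_loss": 4.0,
    "val_profile_x_loss": 1.5, "val_logcounts_x_loss": 2.5,
    "val_profile_profile_x_loss": 3.5, "val_logcounts_profile_x_loss": 4.5,
}


def looseKeys(pattern: str) -> list[str]:
    """Keys where the pattern appears anywhere (unanchored search)."""
    return sorted(k for k in logs if re.search(pattern, k))


def exactKey(pattern: str) -> str:
    """The single key that the whole pattern matches (anchored fullmatch)."""
    hits = [k for k in logs if re.fullmatch(pattern, k)]
    if len(hits) != 1:
        raise KeyError(f"{pattern!r} matched {hits}")
    return hits[0]


def getHeadLosses(headName: str) -> tuple[float, ...]:
    """(profile, counts, val_profile, val_counts) for one head, via exact keys."""
    h = re.escape(headName)
    patterns = (rf"profile_{h}_loss", rf"logcounts_{h}_loss",
                rf"val_profile_{h}_loss", rf"val_logcounts_{h}_loss")
    return tuple(logs[exactKey(p)] for p in patterns)


print(looseKeys("profile_x_loss"))   # six keys, most from the wrong head
print(getHeadLosses("x"))            # (1.0, 2.0, 1.5, 2.5)
print(getHeadLosses("profile_x"))    # (3.0, 4.0, 3.5, 4.5)
```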

whatWouldValLossBe(epoch)

Determine a previous epoch’s validation loss but use the current \({\lambda}\) values to do it.

Parameters:

epoch (int) – The epoch number where you’d like to know what the loss would have been.

Returns:

A float giving the corrected validation loss at that epoch.

Had we been using the current \({\lambda}\) in a previous epoch, what would its validation loss have been?
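A minimal sketch of that correction, assuming the total validation loss is the sum over heads of the logged profile loss plus the λ-weighted counts loss, and that (as getLosses states) the logged counts loss already includes the λ in use at that epoch. The function name and the dictionary layout are hypothetical.

```python
def whatWouldValLossBeSketch(perHead: dict[str, tuple[float, float]],
                             lambdaThen: dict[str, float],
                             lambdaNow: dict[str, float]) -> float:
    """Recompute a past epoch's total validation loss under the current lambdas.

    perHead maps a head name to (val_profile, val_counts) as logged at that
    epoch; val_counts is assumed to already include lambdaThen.
    """
    total = 0.0
    for name, (valProfile, valCounts) in perHead.items():
        rawCounts = valCounts / lambdaThen[name]   # undo the old weight
        total += valProfile + lambdaNow[name] * rawCounts
    return total


# One head: profile 3.0, counts 2.0 logged with lambda = 4; lambda is now 10.
print(whatWouldValLossBeSketch({"x": (3.0, 2.0)}, {"x": 4.0}, {"x": 10.0}))  # 8.0
```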

resetCallbacks()

Manipulate the other callbacks so they don’t break when \({\lambda}\) changes.

This is the squirreliest method here. The other callbacks that track model progression record the loss of the model at the record-setting epoch. But the definition of loss itself changes during training, so we need to update their stored idea of what the model’s loss was at that epoch.

For example, consider the following scenario:

epoch  raw_loss  loss_weight  scaled_loss
1      10        1            10
2      9         1            9
3      8         1            8
4      7         2            14
5      6         2            12
6      5         2            10

A callback that was tracking the loss, looking for the minimum, would claim that the best epoch was epoch 3, where the scaled loss was 8. We need to go into that callback and say, “No, with our current loss weight, the loss on epoch 3 would actually have been 16,” so that the callback treats the loss of 14 on epoch 4 as an improvement.
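The scenario can be replayed with a toy minimum tracker. As in the table, the whole loss is treated as raw_loss times loss_weight; the tracker is a hypothetical stand-in for a best-loss-tracking callback, not the code of EarlyStopping, ModelCheckpoint, or ReduceLROnPlateau. Without rescaling, only epochs 1 to 3 count as improvements; re-expressing the stored record under the current weight lets every epoch register.

```python
rawLosses = [10, 9, 8, 7, 6, 5]   # raw_loss column of the scenario
weights = [1, 1, 1, 2, 2, 2]      # loss_weight column


def improvingEpochs(rawLosses: list[float], weights: list[float], rescale: bool) -> list[int]:
    """Epochs that a minimum-tracking callback would count as improvements."""
    best = float("inf")
    bestEpoch = None
    improved = []
    for epoch, (raw, weight) in enumerate(zip(rawLosses, weights), start=1):
        if rescale and bestEpoch is not None:
            # Re-express the stored record using the current loss weight.
            best = rawLosses[bestEpoch - 1] * weight
        scaled = raw * weight
        if scaled < best:
            best, bestEpoch = scaled, epoch
            improved.append(epoch)
    return improved


print(improvingEpochs(rawLosses, weights, rescale=False))  # [1, 2, 3]
print(improvingEpochs(rawLosses, weights, rescale=True))   # [1, 2, 3, 4, 5, 6]
```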

on_epoch_end(epoch, logs=None)

Update the other callbacks and calculate a new \({\lambda}\).

Parameters:
  • epoch (int) – The epoch number that just finished.

  • logs (dict | None) – The history logs from the last epoch.