callbacks

A set of callbacks useful for training a model.

class bpreveal.callbacks.FixLossCallback(*args, **kwargs)

Fixes the loss terms in the logs to reflect the metrics, not the actual loss values.

This is because of a stupid quirk in Keras. If you use a loss function as a metric too, the two will not report the same value, because of something to do with regularization. That is, if you do model.compile(loss=[fun], metrics=[fun]), then the loss value will be different from the metric. I hate it. It’s dumb. Worse, the value is close to the metric value, so the discrepancy is easy to miss.

Since the way BPReveal tracks the components of the losses is by using different metrics, the assumption is that loss = sum(metrics). Instead of figuring out what regularization means, I simply redefine the loss in the logs for each epoch by overwriting the loss (and val_loss) items in the logs dict with the sums of the metrics. While the model trains based on the actual (perverted) loss, the callbacks see only the idealized loss value that this function stuffs into the logs dictionary.

Parameters:

heads (list[dict])

correctLosses(logs)

Get the corrected loss value and put it in logs.

Parameters:

logs (dict) – The logs from the current epoch or batch. This will be EDITED IN PLACE.

Returns:

Nothing, but does edit logs.

Return type:

None
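The pure-Python sketch below shows what correctLosses does to the logs. It is not BPReveal’s implementation. The metric key names and the `metricNames` argument are hypothetical, and the sketch assumes validation metrics carry Keras’s usual `val_` prefix. Batch logs contain no `val_` entries, so the same function handles both epoch and batch logs. Each total is only overwritten when all of its terms are present.

```python
def correct_losses(logs: dict, metricNames: list[str]) -> None:
    """Overwrite logs["loss"] and logs["val_loss"] with sums of metrics.

    metricNames is a hypothetical list of the metric keys that together
    make up the loss. Edits logs in place and returns nothing.
    """
    if all(name in logs for name in metricNames):
        logs["loss"] = sum(logs[name] for name in metricNames)
    valNames = ["val_" + name for name in metricNames]
    # Batch logs have no val_ entries, so this branch is skipped for them.
    if all(name in logs for name in valNames):
        logs["val_loss"] = sum(logs[name] for name in valNames)


logs = {"loss": 3.31, "profile_x_metric": 2.0, "counts_x_metric": 1.25,
        "val_loss": 3.9, "val_profile_x_metric": 2.5, "val_counts_x_metric": 1.25}
correct_losses(logs, ["profile_x_metric", "counts_x_metric"])
print(logs["loss"], logs["val_loss"])  # 3.25 3.75
```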

on_epoch_end(_, logs=None)

Update the logs.

Parameters:
  • logs (dict | None) – The logs from the current epoch

  • _ (int)

Return type:

None

on_train_batch_end(_, logs=None)

Update the logs.

Parameters:
  • logs (dict | None) – The logs from the current batch

  • _ (int)

Return type:

None

on_test_batch_end(_, logs=None)

Update the logs.

Parameters:
  • logs (dict | None) – The logs from the current batch

  • _ (int)

Return type:

None

class bpreveal.callbacks.ApplyAdaptiveCountsLoss(*args, **kwargs)

Implements the adaptive counts loss algorithm.

This updates the counts loss weights in your model on the fly, so that you can specify a target fraction of the loss that is due to counts, and the model will automatically update the weight. It currently does not update profile weights, so you can’t yet automatically balance the losses in a multitask model. (You can still manually specify them in the config json, of course!)

Parameters:
  • heads (list[dict]) – Taken straight from the JSON config, except that each head may contain an “INTERNAL_counts-loss-weight-variable” entry. If present, this entry is the Keras Variable that holds that head’s counts loss weight, and this callback updates it.

  • aggression (float) – A number from 0 to 1.

  • lrPlateauCallback (keras.callbacks.ReduceLROnPlateau) – The learning rate plateau callback that your model is going to use.

  • earlyStopCallback (keras.callbacks.EarlyStopping) – The early stopping callback that your model is going to use.

  • checkpointCallback (keras.callbacks.ModelCheckpoint) – The checkpoint callback that your model is going to use.

aggression determines how aggressively the counts loss will be re-weighted. Lower values give slower changes, and a value of 1.0 means that the old loss weight is completely discarded at each iteration. (If the newly calculated loss weight is ever off by more than a factor of two, it’s clamped to exactly twice (or half) the old weight, so that gross instability in the early stages doesn’t cause the model to explode.)

The three callbacks are the others used in the model, and the model should run them BEFORE this callback is executed. This callback messes with their internal state (naughty, naughty!) because changing the loss weights could cause the model’s loss value to go up even though the model hasn’t gotten any worse.
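The re-weighting rule can be sketched as a single update step. This sketch rests on stated assumptions and is not BPReveal’s code:

- The target weight comes from solving \(\lambda C / (P + \lambda C) = f\) for \(\lambda\). Here \(P\) is the profile loss, \(C\) is the unweighted counts loss, and \(f\) is a hypothetical `targetFrac` parameter giving the desired counts fraction.
- The blend with the old weight is assumed to be linear in `aggression`.
- The factor-of-two clamp is applied to the blended result.

```python
def update_counts_weight(oldWeight: float, profileLoss: float,
                         rawCountsLoss: float, targetFrac: float,
                         aggression: float) -> float:
    """One hypothetical adaptive step for a head's counts loss weight."""
    # Weight that would make counts exactly targetFrac of this head's loss:
    # lam * C / (P + lam * C) == f  =>  lam = f * P / ((1 - f) * C)
    targetWeight = targetFrac * profileLoss / ((1.0 - targetFrac) * rawCountsLoss)
    # Assumption: linear blend; aggression == 1.0 discards the old weight.
    newWeight = (1.0 - aggression) * oldWeight + aggression * targetWeight
    # Clamp to within a factor of two of the old weight.
    return min(max(newWeight, oldWeight / 2.0), oldWeight * 2.0)


print(update_counts_weight(1.0, 10.0, 1.0, 0.5, 1.0))  # 2.0 (clamped from 10)
```

In this sketch, a full-aggression step toward a target of 10 is capped at 2, while a gentler `aggression=0.1` lands near 1.9 without hitting the clamp.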

on_train_begin(logs=None)

Set up the initial guesses for \({\lambda}\).

Parameters:

logs (dict | None) – Ignored.

Return type:

None

At the beginning of training, see which heads are using adaptive weights. For those heads, load in an initial guess for \({\lambda}\) based on the BPNet heuristic.

getLosses(epoch, headName)

Get what the losses actually were in a previous epoch using that epoch’s \({\lambda}\) values.

Parameters:
  • epoch (int) – The epoch for which you’d like losses.

  • headName (str) – The head for which you’d like losses.

Raises:

ValueError – If any of the loss terms were unrecognized.

Returns:

A tuple with four losses: (profile, counts, val_profile, val_counts).

Return type:

tuple[float, float, float, float]

What were the profile, counts, val_profile, and val_counts losses (in that order) at the given epoch? The counts losses returned from this function are weighted by the counts weight in use at the given epoch. This method assumes that the training and validation loss names match some regexes, which are based on layer names. If someone got really goofy with head names (e.g., one head named “x” and another named “profile_x”), then this could get messed up.
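The head-name hazard is easiest to see with substring matching. The sketch below is a hypothetical illustration, not the regexes BPReveal uses. It invents log keys of the form `<layer>_loss` for two heads named “x” and “profile_x”. An unanchored search for head “x” also picks up the other head’s keys, including its counts loss. A pattern that must match the whole key does not have this problem.

```python
import re

# Hypothetical log keys for heads "x" and "profile_x" (naming is assumed).
keys = ["profile_x_loss", "counts_x_loss",
        "profile_profile_x_loss", "counts_profile_x_loss",
        "val_profile_x_loss", "val_counts_x_loss",
        "val_profile_profile_x_loss", "val_counts_profile_x_loss"]


def naive_matches(head: str) -> list[str]:
    # Unanchored search: "profile_x_loss" also occurs inside keys that
    # belong to the head named "profile_x".
    pat = re.compile("profile_" + re.escape(head) + "_loss")
    return [k for k in keys if pat.search(k)]


def anchored_matches(head: str) -> list[str]:
    # The whole key must match; only an optional "val_" prefix is allowed.
    pat = re.compile(r"(val_)?profile_" + re.escape(head) + r"_loss")
    return [k for k in keys if pat.fullmatch(k)]


print(naive_matches("x"))     # six keys, including counts_profile_x_loss
print(anchored_matches("x"))  # ['profile_x_loss', 'val_profile_x_loss']
```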

whatWouldValLossBe(epoch)

Determine a previous epoch’s validation loss but use the current \({\lambda}\) values to do it.

Parameters:

epoch (int) – The epoch number where you’d like to know what the loss would have been.

Returns:

A float giving the corrected validation loss at that epoch.

Return type:

float

Had we been using the current \({\lambda}\) in a previous epoch, what would its validation loss have been?
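whatWouldValLossBe amounts to undoing each head’s old counts weight and applying the current one. The minimal sketch below assumes a hypothetical `EpochRecord` for each head that stores three things: the validation profile loss, the weighted validation counts loss, and the \(\lambda\) in use that epoch. BPReveal’s own bookkeeping may differ. Epochs are 0-based here, matching the index Keras passes to `on_epoch_end`.

```python
from dataclasses import dataclass


@dataclass
class EpochRecord:
    """Hypothetical per-epoch, per-head record (not BPReveal's storage)."""
    valProfile: float    # validation profile loss
    valCounts: float     # validation counts loss, weighted by that epoch's lambda
    countsWeight: float  # lambda in use during that epoch


def what_would_val_loss_be(history: dict[str, list[EpochRecord]],
                           currentWeights: dict[str, float],
                           epoch: int) -> float:
    """Re-score a past epoch's validation loss with today's lambdas."""
    total = 0.0
    for headName, records in history.items():
        rec = records[epoch]
        rawCounts = rec.valCounts / rec.countsWeight  # undo the old weight
        total += rec.valProfile + rawCounts * currentWeights[headName]
    return total


history = {"x": [EpochRecord(1.0, 0.5, 0.5), EpochRecord(1.0, 2.0, 1.0)]}
print(what_would_val_loss_be(history, {"x": 2.0}, 0))  # 3.0
```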

resetCallbacks()

Manipulate the other callbacks so they don’t break when \({\lambda}\) changes.

This is the squirreliest method here. The other callbacks that track model progression track the loss of the model at the record-setting epoch. But the definition of loss itself is changing during the training, so we need to update their stored idea of what the model’s loss was during the training.

For example, consider the following scenario:

epoch  raw_loss  loss_weight  scaled_loss
1      10        1            10
2      9         1            9
3      8         1            8
4      7         2            14
5      6         2            12
6      5         2            10

A callback that was tracking the loss, looking for the minimum, would claim that the best epoch was epoch 3, where the scaled loss was 8. We need to go into that callback and say, “No, with our current loss weight, the loss on epoch 3 would actually have been 16,” so that the callback thinks that the loss of 14 on epoch 4 is an improvement.

Return type:

None
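The scenario in the table above can be replayed in a few lines. In this sketch, `BestTracker` is a hypothetical stand-in for a monitoring callback. It keeps its record in a `best` attribute, the same attribute name that Keras’s built-in EarlyStopping, ReduceLROnPlateau, and ModelCheckpoint callbacks use. With `reset=True`, the record-setting epoch is re-scored with the current weight before each comparison. That is the idea behind resetCallbacks, simplified here to scale the whole loss the way the table does.

```python
class BestTracker:
    """Hypothetical stand-in for a callback that remembers the lowest loss."""

    def __init__(self) -> None:
        self.best = float("inf")
        self.bestEpoch = None  # 0-based index of the record-setting epoch

    def update(self, epoch: int, loss: float) -> bool:
        if loss < self.best:
            self.best, self.bestEpoch = loss, epoch
            return True
        return False


rawLosses = [10, 9, 8, 7, 6, 5]
weights = [1, 1, 1, 2, 2, 2]


def replay(reset: bool) -> list[int]:
    """Return the 1-based epochs the tracker counts as improvements."""
    tracker = BestTracker()
    improved = []
    for i, (raw, w) in enumerate(zip(rawLosses, weights)):
        if reset and tracker.bestEpoch is not None:
            # Re-score the record-setting epoch using the current weight.
            tracker.best = rawLosses[tracker.bestEpoch] * w
        if tracker.update(i, raw * w):
            improved.append(i + 1)
    return improved


print(replay(reset=False))  # [1, 2, 3]
print(replay(reset=True))   # [1, 2, 3, 4, 5, 6]
```

Without the reset, only epochs 1 to 3 count as improvements. With it, every epoch does, because the raw loss falls every epoch.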

on_epoch_end(epoch, logs=None)

Update the other callbacks and calculate a new \({\lambda}\).

Parameters:
  • epoch (int) – The epoch number that just finished.

  • logs (dict | None) – The history logs from the last epoch.

Return type:

None

class bpreveal.callbacks.DisplayCallback(*args, **kwargs)

Replaces the TensorFlow progress bar logger with lots of printing to stderr.

Parameters:
  • trainBatchGen (H5BatchGenerator) – The training batch generator.

  • valBatchGen (H5BatchGenerator) – The validation batch generator.

  • plateauCallback (keras.callbacks.ReduceLROnPlateau) – The plateau callback, used to access the LR schedule.

  • earlyStopCallback (keras.callbacks.EarlyStopping) – The EarlyStopping callback, used to see how long we have left.

  • adaptiveLossCallback (ApplyAdaptiveCountsLoss) – The adaptive loss callback, used to read \({\lambda}\) values.

epochNumber = 0

What is the currently-running epoch number?

batchNumber = 0

What is the currently-running training batch number?

printLocationsEpoch = {}

For a given data type, what row should it be printed on in the epoch pane? For example, “val_loss” might go on row 5.

printLocationsBatch = {}

What row should each data type go on in the batch pane?

multipliers = {}

For a given data type, what constant should it be multiplied by for display? This is used to weight profile losses by profile-loss-weight.

prevEpochLogs = None

The logs from the last epoch.

lastBatchTime: float

When did the last batch happen?

lastBatchEndTime = 0

When did the last batch that we printed finish?

lastValBatchTime: float

When did the last validation batch happen?

lastValBatchEndTime = 0

When did the last validation batch that we printed finish?

lastEpochEndTime = None

When did the last epoch finish?

lastEpochStartTime = None

When did the current epoch start?

numEpochs: int

What is the maximum number of training epochs?

curEpochWaitTime = None

How long between the end of the last epoch and the start of this one?

maxLen = 0

Of all the data types, what is the length of the longest name? Used to calculate column positions.

trainBeginTime: float

When did the whole training process start?

firstBatchTime: float

When did we see our first batch of this epoch?

firstValBatchTime: float

When did we see our first validation batch of this epoch?

curEpochStartTime: float

When did the current epoch start?

numBatches: int

How many training batches per epoch?

numValBatches: int

How many validation batches per epoch?

ignoreMetrics: list[str]

Names of metrics (i.e., loss terms) that we don’t want to print.
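The timing attributes above are enough to estimate how much of the current epoch is left. The helper below is hypothetical and not part of BPReveal. It assumes that `firstBatchTime` and `lastBatchTime` are the times at which the first and the most recent batches finished. It also assumes `batchNumber` is the 0-based index of the most recent batch, so `batchNumber` batch intervals have elapsed between the two timestamps.

```python
import math


def estimate_epoch_remaining(firstBatchTime: float, lastBatchTime: float,
                             batchNumber: int, numBatches: int) -> float:
    """Seconds left in the epoch, extrapolated from the mean batch time."""
    if batchNumber == 0:
        # Only one timestamp so far, so there is no rate to extrapolate.
        return math.nan
    secondsPerBatch = (lastBatchTime - firstBatchTime) / batchNumber
    batchesLeft = numBatches - 1 - batchNumber
    return secondsPerBatch * batchesLeft


print(estimate_epoch_remaining(100.0, 110.0, 5, 11))  # 10.0
```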

on_train_begin(logs=None)

Just loads in the total number of epochs.

Parameters:

logs (dict | None) – The logs passed in by Keras (if any).

Return type:

None

on_epoch_begin(epoch, logs=None)

Just sets the timers up, so you can check how long an epoch took at the end.

Parameters:
  • epoch (int) – The epoch number

  • logs (dict | None) – The logs passed in by Keras (if any).

Return type:

None

formatStr(val)

Formats an object to be 11 characters wide.

If val is a tuple of two integers, it is formatted as a ratio.

Parameters:

val (str | int | float | tuple[int, int]) – The thing to format

Returns:

An 11 character wide formatted string.

Return type:

str
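One way to meet formatStr’s fixed-width contract is sketched below. The format codes are assumptions, not BPReveal’s choices: right-aligned numbers, four significant digits for floats, and a 5/5 split around the slash for ratios. Four significant digits in `g` format keep even a value like -1.235e-100 within 11 characters. Integers wider than 11 digits would overflow the column, and strings are padded or truncated to fit.

```python
from __future__ import annotations


def format_str(val: str | int | float | tuple[int, int]) -> str:
    """Render val in 11 characters (a sketch, not BPReveal's code)."""
    if isinstance(val, tuple):
        num, den = val
        # Ratio: 5 characters, a slash, then 5 characters.
        return f"{num:>5d}/{den:<5d}"
    if isinstance(val, float):
        # At most 11 characters, e.g. "-1.235e-100".
        return f"{val:>11.4g}"
    if isinstance(val, int):
        return f"{val:>11d}"
    # Strings: pad to 11, or truncate anything longer.
    return f"{val:>11.11s}"


print(repr(format_str((3, 10))))  # '    3/10   '
```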

on_epoch_end(epoch, logs=None)

Writes out all the logs for this epoch and the last one at INFO logging level.

Parameters:
  • epoch (int) – The epoch number

  • logs (dict | None) – The logs for the current epoch

Return type:

None

on_train_batch_end(batch, logs=None)

Write the loss info for the current batch at DEBUG level.

Parameters:
  • batch (int) – The batch number

  • logs (dict | None) – The logs for the current batch

Return type:

None

on_test_batch_end(batch, logs=None)

Just emit a counter with the batch number at DEBUG level.

Parameters:
  • batch (int) – The batch number

  • logs (dict | None) – The logs for the current batch

Return type:

None

bpreveal.callbacks.getCallbacks(earlyStop, outputPrefix, plateauPatience, heads, trainBatchGen, valBatchGen)

Return a set of callbacks for your model.

Parameters:
  • earlyStop (int) – The early-stopping-patience from the config file.

  • outputPrefix (str) – The output-prefix for the model, including directory.

  • plateauPatience (int) – The learning-rate-plateau-patience from the config file.

  • heads (list[dict]) – The heads list for your model, to which adaptive loss \({\lambda}\) tensors have been added.

  • trainBatchGen (H5BatchGenerator) – The batch generator for training. Just used to see how many batches there will be.

  • valBatchGen (H5BatchGenerator) – The batch generator for validation. Just used to see how many batches there will be.

Returns:

A tuple of Keras callbacks that you should use to train your model.

Return type:

tuple[FixLossCallback, keras.callbacks.EarlyStopping, keras.callbacks.ModelCheckpoint, keras.callbacks.ReduceLROnPlateau, ApplyAdaptiveCountsLoss, DisplayCallback]

The returned callbacks are:

FixLossCallback

Replace the loss values in the logs with the sums of the metrics.

EarlyStopping

Stop training if the validation loss hasn’t improved for a while.

ModelCheckpoint

Write a checkpoint file every time the validation loss improves.

ReduceLROnPlateau

If validation loss hasn’t improved for a while, decrease the learning rate.

ApplyAdaptiveCountsLoss

Implement the adaptive counts loss algorithm.

DisplayCallback

Write log files in a format that is compatible with showTrainingProgress.