callbacks
A set of callbacks useful for training a model.
- class bpreveal.callbacks.FixLossCallback(*args, **kwargs)
Fixes the loss terms in the logs to reflect the metrics, not the actual loss values.
This is because of a stupid quirk in Keras. If you use a loss function as a metric, the two will not give the same values, because of something something regularization. That is, if you do
model.compile(loss=[fun], metrics=[fun]), then the loss value will be different from the metric. I hate it. It's dumb. Worse, the loss is close to the metric value. Since BPReveal tracks the components of the loss using separate metrics, it assumes that loss = sum(metrics). Instead of figuring out what regularization means, I simply redefine the loss in the logs for each epoch and batch by overwriting the
loss (and val_loss) items in the logs dict with the sums of the metrics. While the model trains based on the actual (perverted) loss, the callbacks see only the idealized loss value that this function stuffs into the logs dictionary.
- Parameters:
heads (list[dict])
- correctLosses(logs)
Get the corrected loss value and put it in logs.
- Parameters:
logs (dict) – The logs from the current epoch or batch. This will be EDITED IN PLACE.
- Returns:
Nothing, but does edit logs.
- Return type:
None
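The correction described above amounts to summing the component metrics and writing the sum back into the same dict. The following is a minimal sketch, not BPReveal's code: the helper name `correct_losses`, the `metric_names` parameter, and the assumption that validation metrics share the training names plus a `val_` prefix are all hypothetical.

```python
def correct_losses(logs: dict, metric_names: list[str]) -> None:
    """Overwrite logs["loss"] and logs["val_loss"] with sums of metrics, in place.

    metric_names is hypothetical: the names of the loss-component metrics,
    without any "val_" prefix.
    """
    # Training loss: sum whichever component metrics were logged.
    train_terms = [logs[name] for name in metric_names if name in logs]
    if train_terms and "loss" in logs:
        logs["loss"] = sum(train_terms)
    # Validation loss: assume the same metrics carry a "val_" prefix.
    val_terms = [logs["val_" + name] for name in metric_names
                 if "val_" + name in logs]
    if val_terms and "val_loss" in logs:
        logs["val_loss"] = sum(val_terms)


logs = {"loss": 3.3, "profile_x": 1.0, "counts_x": 2.0,
        "val_loss": 4.4, "val_profile_x": 1.5, "val_counts_x": 2.5}
correct_losses(logs, ["profile_x", "counts_x"])
# logs["loss"] → 3.0, logs["val_loss"] → 4.0
```

Only keys that are already present are overwritten, so batch-level logs (which have no `val_loss`) are left without one.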
- on_epoch_end(_, logs=None)
Update the logs.
- Parameters:
_ (int)
logs (dict | None)
- Return type:
None
- on_train_batch_end(_, logs=None)
Update the logs.
- Parameters:
_ (int)
logs (dict | None)
- Return type:
None
- on_test_batch_end(_, logs=None)
Update the logs.
- Parameters:
_ (int)
logs (dict | None)
- Return type:
None
- class bpreveal.callbacks.ApplyAdaptiveCountsLoss(*args, **kwargs)
Implements the adaptive counts loss algorithm.
This updates the counts loss weights in your model on the fly, so that you can specify a target fraction of the loss that is due to counts, and the model will automatically update the weight. It currently does not update profile weights, so you can’t yet automatically balance the losses in a multitask model. (You can still manually specify them in the config json, of course!)
- Parameters:
heads (list[dict]) – Taken straight from the JSON config, except with "INTERNAL_counts-loss-weight-variable" entries in each head. If present, these variables are the Keras Variables that contain the loss weights. These will be updated in this callback.
aggression (float) – A number from 0 to 1.
lrPlateauCallback (keras.callbacks.ReduceLROnPlateau) – The learning rate plateau callback that your model is going to use.
earlyStopCallback (keras.callbacks.EarlyStopping) – The early stopping callback that your model is going to use.
checkpointCallback (keras.callbacks.ModelCheckpoint) – The checkpoint callback that your model is going to use.
aggression determines how aggressively the counts loss will be re-weighted. Lower values indicate slower changes, and a value of 1.0 means that the old loss weight is completely discarded at each iteration. (If the newly calculated loss weight is ever off by more than a factor of two, it's clamped to exactly twice (or half) the old weight, so that gross instability in the early stages doesn't cause the model to explode.)
The three callbacks are the other callbacks used to train the model, and the model should run them BEFORE this callback is executed. This callback messes with their internal state (naughty, naughty!) because changing the loss weights could cause the model's loss value to go up even though the model hasn't gotten any worse.
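The update rule above is described only in words. One plausible reading, written as a minimal sketch: choose the counts weight λ that would make counts a fraction f of the total loss (solving λC / (P + λC) = f gives λ = fP / ((1 - f)C), where P is the profile loss and C the unweighted counts loss), clamp it to within a factor of two of the old weight, then blend with the old weight using aggression. The closed form, the linear blend, and the function and parameter names are assumptions, not BPReveal's confirmed implementation.

```python
def update_counts_weight(old_weight: float, profile_loss: float,
                         raw_counts_loss: float, target_fraction: float,
                         aggression: float) -> float:
    """Return a new counts-loss weight (hypothetical helper, not BPReveal's API).

    raw_counts_loss is the counts loss before multiplying by any weight.
    """
    if not 0.0 < target_fraction < 1.0:
        raise ValueError("target_fraction must be strictly between 0 and 1")
    if not 0.0 <= aggression <= 1.0:
        raise ValueError("aggression must be between 0 and 1")
    # Solve w * C / (P + w * C) = f for w.
    ideal = target_fraction * profile_loss / ((1.0 - target_fraction) * raw_counts_loss)
    # Keep the new estimate within a factor of two of the old weight.
    clamped = min(max(ideal, old_weight / 2.0), old_weight * 2.0)
    # aggression = 1.0 discards the old weight entirely; smaller values move slowly.
    return (1.0 - aggression) * old_weight + aggression * clamped


# The ideal weight here is 9.0, but the clamp limits it to twice the old weight.
new_weight = update_counts_weight(1.0, 9.0, 1.0, 0.5, 1.0)
# new_weight → 2.0
```

Clamping before blending means that even with aggression = 1.0 the weight can at most double or halve per epoch, which matches the stability guarantee described above.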
- on_train_begin(logs=None)
Set up the initial guesses for \({\lambda}\).
- Parameters:
logs (dict | None) – Ignored.
- Return type:
None
At the beginning of training, see which heads are using adaptive weights. For those heads, load in an initial guess for \({\lambda}\) based on the BPNet heuristic.
- getLosses(epoch, headName)
Get what the losses actually were in a previous epoch using that epoch’s \({\lambda}\) values.
- Parameters:
epoch (int) – The epoch for which you’d like losses.
headName (str) – The head for which you’d like losses.
- Returns:
A tuple with four losses: (profile, counts, val_profile, val_counts).
- Return type:
tuple[float, float, float, float]
What were the profile, counts, val_profile, and val_counts (in that order) losses at the given epoch? The counts losses returned from this function are weighted by the counts weight that was in use at the given epoch. This method assumes that the validation and training losses match some regexes, and these are based on layer names. If someone got really goofy with head names (e.g., one head named "x" and another named "profile_x"), then this could get messed up.
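To see why oddly chosen head names can confuse regex-based lookups, here is a small illustration. The log key names (`profile_x_loss` and so on) and the helper `find_loss_keys` are hypothetical and are not BPReveal's actual patterns; the sketch only shows how an unanchored search over-matches and how anchoring fixes it.

```python
import re


def find_loss_keys(logs: dict, pattern: str) -> list[str]:
    """Return, sorted, every log key in which the regex pattern is found."""
    return sorted(key for key in logs if re.search(pattern, key))


# Hypothetical log keys for a head named "x" and a head named "profile_x".
logs = {
    "profile_x_loss": 1.0,
    "profile_profile_x_loss": 2.0,
    "val_profile_x_loss": 1.5,
}

# An unanchored pattern also hits the other head and the validation key.
naive = find_loss_keys(logs, r"profile_x_loss")
# naive → ['profile_profile_x_loss', 'profile_x_loss', 'val_profile_x_loss']

# Anchoring both ends picks out exactly one key.
anchored = find_loss_keys(logs, r"^profile_x_loss$")
# anchored → ['profile_x_loss']
```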
- whatWouldValLossBe(epoch)
Determine a previous epoch’s validation loss but use the current \({\lambda}\) values to do it.
- Parameters:
epoch (int) – The epoch number where you’d like to know what the loss would have been.
- Returns:
A float giving the corrected validation loss at that epoch.
- Return type:
float
Had we been using the current \({\lambda}\) in a previous epoch, what would its validation loss have been?
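For a single head, re-scoring an old epoch reduces to undoing that epoch's counts weight and applying the current one. This is a minimal sketch under the assumption that the logged counts loss was already multiplied by the weight in use at that epoch; the function and parameter names are hypothetical, and a real multi-head model would sum this over heads.

```python
def what_would_val_loss_be(val_profile: float, val_counts_weighted: float,
                           weight_then: float, weight_now: float) -> float:
    """Re-score an old epoch's validation loss with today's counts weight.

    Hypothetical helper: dividing by weight_then recovers the raw counts loss.
    """
    raw_counts = val_counts_weighted / weight_then
    return val_profile + weight_now * raw_counts


# Profile loss 5, counts loss logged as 10 under weight 2, current weight 4.
rescored = what_would_val_loss_be(5.0, 10.0, 2.0, 4.0)
# rescored → 25.0
```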
- resetCallbacks()
Manipulate the other callbacks so they don’t break when \({\lambda}\) changes.
This is the squirreliest method here. The other callbacks that track model progression track the loss of the model at the record-setting epoch. But the definition of loss itself is changing during the training, so we need to update their stored idea of what the model’s loss was during the training.
For example, consider the following scenario:
epoch  raw_loss  loss_weight  scaled_loss
1      10        1            10
2      9         1            9
3      8         1            8
4      7         2            14
5      6         2            12
6      5         2            10
A callback that was tracking the loss, looking for the minimum, would claim that the best epoch was epoch 3, where the scaled loss was 8. We need to go into that callback and say, "No, with our current loss weight, the loss on epoch 3 would actually have been 16," so that the callback thinks that the loss of 14 on epoch 4 is an improvement.
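The arithmetic in this scenario can be checked directly. Keras's EarlyStopping, ModelCheckpoint and ReduceLROnPlateau each keep their record value in a `best` attribute; the sketch below uses `SimpleNamespace` stand-ins with that attribute, and whether BPReveal resets exactly this attribute is an assumption. The helper `rescored_best` is hypothetical.

```python
from types import SimpleNamespace

# Rows from the table above: (raw_loss, loss_weight) for epochs 1-6.
history = [(10, 1), (9, 1), (8, 1), (7, 2), (6, 2), (5, 2)]


def rescored_best(rows, current_weight):
    """Smallest loss among rows, re-scored with the current weight."""
    return min(raw * current_weight for raw, _ in rows)


# Stand-ins for the three monitoring callbacks.
# After epoch 3 they all recorded best = 8 (raw 8 times weight 1).
monitors = [SimpleNamespace(best=8) for _ in range(3)]

# The weight changes to 2 before epoch 4: re-score epochs 1-3 and reset.
for monitor in monitors:
    monitor.best = rescored_best(history[:3], current_weight=2)

epoch4_scaled = history[3][0] * history[3][1]  # 7 * 2 = 14
improved = all(epoch4_scaled < monitor.best for monitor in monitors)
# improved → True (14 < 16); without the reset, 14 < 8 would be False.
```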
- Return type:
None
- on_epoch_end(epoch, logs=None)
Update the other callbacks and calculate a new \({\lambda}\).
- Parameters:
epoch (int) – The epoch number that just finished.
logs (dict | None) – The history logs from the last epoch.
- Return type:
None
- class bpreveal.callbacks.DisplayCallback(*args, **kwargs)
Replaces the TensorFlow progress bar logger with lots of printing to stderr.
- Parameters:
trainBatchGen (H5BatchGenerator) – The training batch generator.
valBatchGen (H5BatchGenerator) – The validation batch generator.
plateauCallback (keras.callbacks.ReduceLROnPlateau) – The plateau callback, used to access the LR schedule.
earlyStopCallback (keras.callbacks.EarlyStopping) – The EarlyStopping callback, used to see how long we have left.
adaptiveLossCallback (ApplyAdaptiveCountsLoss) – The adaptive loss callback, used to read \({\lambda}\) values.
- epochNumber = 0
What is the currently-running epoch number?
- batchNumber = 0
What is the currently-running training batch number?
- printLocationsEpoch = {}
For a given data type, what row should it be printed on in the epoch pane? For example, “val_loss” might go on row 5.
- printLocationsBatch = {}
What row should each data type go on in the batch pane?
- multipliers = {}
For a given data type, what constant should it be multiplied by for display? This is used to weight profile losses by profile-loss-weight.
- prevEpochLogs = None
The logs from the last epoch.
- lastBatchTime: float
When did the last batch happen?
- lastBatchEndTime = 0
When did the last batch that we printed finish?
- lastValBatchTime: float
When did the last validation batch happen?
- lastValBatchEndTime = 0
When did the last validation batch that we printed finish?
- lastEpochEndTime = None
When did the last epoch finish?
- lastEpochStartTime = None
When did the current epoch start?
- numEpochs: int
What is the maximum number of training epochs?
- curEpochWaitTime = None
How long between the end of the last epoch and the start of this one?
- maxLen = 0
Of all the data types, what is the length of the longest name? Used to calculate column positions.
- trainBeginTime: float
When did the whole training process start?
- firstBatchTime: float
When did we see our first batch of this epoch?
- firstValBatchTime: float
When did we see our first validation batch of this epoch?
- curEpochStartTime: float
When did the current epoch start?
- numBatches: int
How many training batches per epoch?
- numValBatches: int
How many validation batches per epoch?
- ignoreMetrics: list[str]
Names of metrics (i.e., loss terms) that we don’t want to print.
- on_train_begin(logs=None)
Just loads in the total number of epochs.
- Parameters:
logs (dict | None)
- Return type:
None
- on_epoch_begin(epoch, logs=None)
Just sets the timers up, so you can check how long an epoch took at the end.
- Parameters:
epoch (int)
logs (dict | None)
- Return type:
None
- formatStr(val)
Formats an object to be 11 characters wide.
If val is a tuple of two ints, it is formatted as a ratio.
- Parameters:
val (str | int | float | tuple[int, int]) – The thing to format
- Returns:
An 11 character wide formatted string.
- Return type:
str
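A fixed-width formatter like this can be written with Python's format-spec mini-language. This is a hypothetical re-implementation for illustration: the choice of 4 significant digits for floats, right-justification, and truncation of long strings are assumptions, not BPReveal's documented behavior.

```python
def format_str(val) -> str:
    """Render val in exactly 11 characters (hypothetical re-implementation).

    A tuple (a, b) is shown as the ratio "a/b"; floats use 4 significant
    digits; anything longer than 11 characters is truncated.
    """
    if isinstance(val, tuple):
        text = f"{val[0]}/{val[1]}"
    elif isinstance(val, float):
        text = f"{val:.4g}"
    else:
        text = str(val)
    # Right-justify in an 11-wide field; the ".11" precision truncates long strings.
    return f"{text:>11.11}"


ratio = format_str((3, 10))
# ratio → '       3/10'
```

Using one width for every cell keeps the epoch and batch panes aligned regardless of whether a column holds a name, a number, or a batch counter.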
- on_epoch_end(epoch, logs=None)
Writes out all the logs for this epoch and the last one at INFO logging level.
- Parameters:
epoch (int)
logs (dict | None)
- Return type:
None
- on_train_batch_end(batch, logs=None)
Write the loss info for the current batch at DEBUG level.
- Parameters:
batch (int)
logs (dict | None)
- Return type:
None
- on_test_batch_end(batch, logs=None)
Just emit a counter with the batch number at DEBUG level.
- Parameters:
batch (int)
logs (dict | None)
- Return type:
None
- bpreveal.callbacks.getCallbacks(earlyStop, outputPrefix, plateauPatience, heads, trainBatchGen, valBatchGen)
Return a set of callbacks for your model.
- Parameters:
earlyStop (int) – The early-stopping-patience from the config file.
outputPrefix (str) – The output-prefix for the model, including directory.
plateauPatience (int) – The learning-rate-plateau-patience from the config file.
heads (list[dict]) – The heads list for your model, to which adaptive loss \({\lambda}\) tensors have been added.
trainBatchGen (H5BatchGenerator) – The batch generator for training. Just used to see how many batches there will be.
valBatchGen (H5BatchGenerator) – The batch generator for validation. Just used to see how many batches there will be.
- Returns:
A tuple of Keras callbacks that you should use to train your model.
- Return type:
tuple[FixLossCallback, keras.callbacks.EarlyStopping, keras.callbacks.ModelCheckpoint, keras.callbacks.ReduceLROnPlateau, ApplyAdaptiveCountsLoss, DisplayCallback]
The returned callbacks are:
- FixLossCallback
Overwrite the loss values in the logs with the sums of the metrics.
- EarlyStopping
Stop training if the validation loss hasn’t improved for a while.
- ModelCheckpoint
Write a checkpoint file every time the validation loss improves.
- ReduceLROnPlateau
If validation loss hasn’t improved for a while, decrease the learning rate.
- ApplyAdaptiveCountsLoss
Implement the adaptive counts loss algorithm.
- DisplayCallback
Write log files in a format that is compatible with showTrainingProgress.
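The order of the returned tuple matters because, as we understand Keras's callback dispatch, each hook is called on the callbacks in list order with the same logs dict. That is why FixLossCallback comes first (so the others see the corrected loss) and the monitoring callbacks come before ApplyAdaptiveCountsLoss. The sketch below simulates that dispatch with hypothetical stand-in classes; it does not use Keras or BPReveal.

```python
class FixLossStandIn:
    """Hypothetical stand-in: replaces val_loss with the sum of metrics."""

    def __init__(self, metric_names):
        self.metric_names = metric_names

    def on_epoch_end(self, epoch, logs):
        logs["val_loss"] = sum(logs["val_" + m] for m in self.metric_names)


class MonitorStandIn:
    """Hypothetical stand-in: records whatever val_loss it observes."""

    def __init__(self):
        self.seen = []

    def on_epoch_end(self, epoch, logs):
        self.seen.append(logs["val_loss"])


def dispatch_epoch_end(callbacks, epoch, logs):
    """Call on_epoch_end on each callback in list order, sharing one logs dict."""
    for cb in callbacks:
        cb.on_epoch_end(epoch, logs)


monitor = MonitorStandIn()
logs = {"val_loss": 99.0, "val_profile": 1.0, "val_counts": 2.0}
dispatch_epoch_end([FixLossStandIn(["profile", "counts"]), monitor], 0, logs)
# monitor.seen → [3.0]
```

Putting the monitor first instead would make it record the uncorrected 99.0.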