Adaptive Counts Loss

In the original BPNet, you had to specify a counts loss weight to use during training. This parameter controls how much of the loss comes from the counts component (mean squared error of log counts) and how much comes from the profile component (multinomial negative log-likelihood). That works, but you don't know beforehand what weight will give you the fraction you want. As an equation,

\[f \equiv \frac{\lambda c}{\lambda c + p}\]

where \(f\) is the fraction of the loss due to counts, \(c\) is the raw counts loss, \(p\) is the raw profile loss, and \({\lambda}\) is the counts-loss-weight parameter. If you want, say, ten percent of your loss to come from counts, you need to pick a value of \({\lambda}\) so that \(f = 0.1\). But since the values of \(c\) and \(p\) aren't known until training is underway, you have to guess.
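As a quick illustration of that relationship, here is a small Python sketch that computes \(f\) for a given weight and, going the other way, the weight needed for a target \(f\). The function names `counts_fraction` and `weight_for_fraction` are hypothetical (they are not part of bpreveal), and the raw losses \(c = 2\) and \(p = 400\) are made-up numbers chosen only to show the arithmetic.

```python
def counts_fraction(lam: float, c: float, p: float) -> float:
    """Fraction of the total loss due to counts: f = lam*c / (lam*c + p).

    Hypothetical helper for illustration only.
    """
    weighted_counts = lam * c
    return weighted_counts / (weighted_counts + p)


def weight_for_fraction(f_target: float, c: float, p: float) -> float:
    """Weight lam that makes counts_fraction(lam, c, p) == f_target.

    Hypothetical helper; obtained by solving the definition of f for lam.
    """
    if not 0.0 < f_target < 1.0:
        raise ValueError("f_target must be strictly between 0 and 1")
    return p * f_target / ((1.0 - f_target) * c)


# Made-up raw losses, chosen only to illustrate the arithmetic.
c, p = 2.0, 400.0
lam = weight_for_fraction(0.1, c, p)
print(round(lam, 4))                         # 22.2222
print(round(counts_fraction(lam, c, p), 4))  # 0.1
```

With these numbers a weight of about 22.2 gives ten percent counts loss, but if the real \(c\) and \(p\) turn out different, so does the required weight, which is exactly the guessing problem described above.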

The adaptive counts loss algorithm skirts this issue by updating \({\lambda}\) during training to match your desired \(f\). Starting with the value of counts-loss-weight you specify in the configuration file, the algorithm springs into action at the end of each epoch.

\[\begin{split}\lambda^\prime_{E+1} &= \frac{p f_{target}}{(1-f_{target}) c} \\ \lambda^{\prime\prime}_{E+1} &= \beta \lambda^\prime_{E+1} + (1 - \beta) \lambda_{E} \\ \gamma &\equiv \frac{\lambda^{\prime\prime}_{E+1}}{\lambda_E} \\ \lambda_{E+1} &= \begin{cases} 2 \lambda_E & \gamma > 2 \\ \lambda^{\prime\prime}_{E+1} & \frac{1}{2} \le \gamma \le 2 \\ \frac{\lambda_E}{2} & \gamma < \frac{1}{2} \end{cases}\end{split}\]

where \(\lambda_{E+1}\) is the \({\lambda}\) value for the next epoch and \(\lambda_E\) is the current one. \(\lambda^{\prime}\) is the value of \(\lambda\) that, given the current profile loss \(p\) and counts loss \(c\), would yield the target fraction \(f_{target}\). (\(f_{target}\) must lie strictly between 0 (only care about profile) and 1 (only care about counts); a typical value is 0.1.) \(\lambda^{\prime\prime}\) is a smoothed version of \(\lambda^{\prime}\), blended with the current \(\lambda_E\) according to the aggression parameter \(\beta\): \(\beta = 0\) means \(\lambda\) never changes, while \(\beta = 1\) means the previous value \(\lambda_E\) is ignored and \(\lambda^{\prime\prime} = \lambda^{\prime}\). The scripts in bpreveal use \(\beta = 0.3\). This value is currently not user-configurable, but it can be edited in callbacks.
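The first line of the update is just the definition of \(f\) solved for \(\lambda\). Setting \(f = f_{target}\) and rearranging:

\[\begin{split}f_{target} &= \frac{\lambda^\prime c}{\lambda^\prime c + p} \\ f_{target} (\lambda^\prime c + p) &= \lambda^\prime c \\ f_{target}\, p &= \lambda^\prime c (1 - f_{target}) \\ \lambda^\prime &= \frac{p f_{target}}{(1 - f_{target}) c}\end{split}\]

This is also why \(f_{target}\) must be strictly less than 1: at \(f_{target} = 1\) the denominator vanishes.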

Early in training, \(\lambda_E\) may be far from \(\lambda^{\prime\prime}_{E+1}\), so the update is clamped: \(\lambda_{E+1}\) differs from \(\lambda_E\) by at most a factor of two.
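Putting the pieces together, here is a minimal Python sketch of one end-of-epoch update following the equations above. It is not the code in bpreveal's callbacks; the function name `update_counts_weight` is hypothetical, and the demonstration holds \(c\) and \(p\) fixed across epochs, which real training will not do.

```python
def update_counts_weight(lam_current: float, c: float, p: float,
                         f_target: float, beta: float = 0.3) -> float:
    """One end-of-epoch update of the counts loss weight (hypothetical sketch)."""
    # lambda': the weight that would hit f_target exactly for the current losses.
    lam_prime = p * f_target / ((1.0 - f_target) * c)
    # lambda'': blend toward lambda'; beta = 0 never moves, beta = 1 jumps to lambda'.
    lam_smoothed = beta * lam_prime + (1.0 - beta) * lam_current
    gamma = lam_smoothed / lam_current
    # Clamp so the weight changes by at most a factor of two per epoch.
    if gamma > 2.0:
        return 2.0 * lam_current
    if gamma < 0.5:
        return lam_current / 2.0
    return lam_smoothed


# Demonstration with made-up losses held fixed across epochs (an assumption).
lam = 1.0
history = []
for epoch in range(5):
    lam = update_counts_weight(lam, c=2.0, p=400.0, f_target=0.1)
    history.append(round(lam, 2))
print(history)  # [2.0, 4.0, 8.0, 12.27, 15.25]
```

Starting from \(\lambda = 1\), the clamp limits the first three epochs to doubling, after which the distance to \(\lambda^\prime \approx 22.2\) shrinks by a factor of \(1 - \beta = 0.7\) each epoch. Note also that \(\gamma = \beta \lambda^\prime / \lambda_E + (1 - \beta)\), so with \(\beta = 0.3\) and positive losses \(\gamma \ge 0.7\); in this sketch only the upper (doubling) clamp can ever trigger at that setting.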