Assessing Generalization via Disagreement

Our paper “A Note on ‘Assessing Generalization of SGD via Disagreement’” was published in TMLR this week and serves as both a short reproduction and a review note. It engages with the claims in “Assessing Generalization of SGD via Disagreement” by Jiang et al. (2022), which received an ICLR 2022 spotlight. We would like to thank the authors for constructively engaging with our note on OpenReview.

TL;DR. We simplify the theoretical statements and proofs, showing them to be straightforward within a probabilistic context (unlike the original hypothesis-space view). Empirically, we ask whether the suggested theory might be impractical under distribution shifts because model calibration can deteriorate as prediction disagreement increases. We show this on CIFAR-10/CINIC-10 and ImageNet/PACS. This is precisely when the proposed coupling of test error and disagreement is of most interest, yet labels are needed to estimate the calibration on new datasets. The authors of “Assessing Generalization (of SGD)¹ via Disagreement” seem to agree with this: for the camera-ready version, they adjusted the paper to make clearer that the claims only concern in-distribution data.

Connecting Generalization and Disagreement

Several recent papers connect the generalization error of a model to the model’s prediction disagreement. The prediction disagreement is the model’s own estimate of the probability that it is wrong for a given sample \(x\): \[\mathbb{E}_{p(y|x)} \, [1 - p(y|x)],\] where \(p(y|x)\) is the model’s predictive distribution for \(x\). This quantity is also called the predicted error.
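As a minimal sketch (our own illustrative code, not taken from any of the papers; `predicted_error` and the example probabilities are ours), this quantity can be computed directly from a model’s softmax outputs:

```python
import numpy as np

def predicted_error(probs: np.ndarray) -> np.ndarray:
    """Prediction disagreement E_{p(y|x)}[1 - p(y|x)] for each sample.

    probs: softmax outputs of shape (num_samples, num_classes).
    """
    # sum_y p(y|x) * (1 - p(y|x)) = 1 - sum_y p(y|x)^2
    return 1.0 - np.sum(probs ** 2, axis=-1)

# A confident and an uncertain prediction:
probs = np.array([[0.9, 0.05, 0.05],
                  [0.4, 0.35, 0.25]])
print(predicted_error(probs))  # ≈ [0.185, 0.655]
```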

Prediction disagreement also approximates a model’s predictive entropy (uncertainty). This can be seen via a first-order Taylor expansion of Shannon’s information content \(-\log p(y|x)\) around 1, which gives \(1 - p(y|x)\), and then taking the expectation over \(p(y|x)\). Predictive entropy (and maximum class confidence) are well-known metrics in the OOD detection literature, yet this recent literature often does not compare to them: the task of detecting samples the model will be wrong on is essentially “very-near-OOD detection”.
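Spelled out as a short derivation (our own rephrasing): \[-\log p \;\ge\; 1 - p \quad \text{for } p \in (0, 1], \qquad \text{so} \qquad H\big[p(y \mid x)\big] = \mathbb{E}_{p(y|x)}\big[-\log p(y \mid x)\big] \;\ge\; \mathbb{E}_{p(y|x)}\big[1 - p(y \mid x)\big],\] with equality to first order around \(p(y|x) = 1\). That is, the prediction disagreement is a lower bound on, and a first-order approximation of, the predictive entropy.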

If we were to summarize the major (empirical) claim of these works, it would be that the prediction disagreement is a good proxy for the generalization error and can be used to estimate it. On average, we can trust a model’s predictive uncertainty to tell us when the model will be wrong for a given sample. That is, if we use the prediction disagreement to identify samples the model is likely wrong on, that fraction will likely be close to the generalization error we could measure if we had the actual labels.
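As a toy illustration (reusing the hypothetical `predicted_error` helper sketched above; `probs` and `labels` are assumed to be NumPy arrays of softmax outputs and ground-truth labels):

```python
# Label-free estimate of the generalization error vs. the labelled test error:
estimated_error = predicted_error(probs).mean()          # no labels needed
actual_error = (probs.argmax(axis=-1) != labels).mean()  # needs labels
# The empirical claim is that these two numbers are close, on average,
# when the model is well-calibrated.
```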

Please note that this is very simplified, and the various papers differ in their claims and approaches—however, this does capture the gist.

However, it is easy to misunderstand the claims of these recent papers:

It is tempting to extend this claim to data under distribution shift. Then we could use the prediction disagreement to estimate the generalization error on a new dataset, even when we do not have access to the labels of the new dataset and do not know how that dataset differs from the training data.

We have to be careful with this, however. The prediction disagreement can be a good proxy for the generalization error on the training data when the model is well-calibrated, but it is not necessarily a good proxy for the generalization error on new data. The reason is that the gap between prediction disagreement and generalization error is bounded by the model’s calibration, which depends on the data distribution. If the training data and the new data are different, the model’s calibration will be different, and while calibration might be good on the training data, it might be bad on the new data. (This is also illustrated in the figures below.)
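Concretely, the kind of relationship at play (paraphrasing loosely; see Jiang et al. (2022) and our note for the precise statements and assumptions) is a bound of the form \[\big|\,\text{test error} \;-\; \text{predicted error}\,\big| \;\le\; \text{calibration error},\] where the calibration error (the class-aggregated calibration error in Jiang et al. (2022)) is measured on the same distribution as the two errors. If calibration deteriorates under distribution shift, the bound loosens exactly where we would most like to use it.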

A Note on ‘Assessing Generalization of SGD via Disagreement’

Jiang et al. (2022) show that when a model is well-calibrated according to the proposed calibration metrics (class-wise and class-aggregated calibration), we can use the prediction disagreement to predict the generalization error. Furthermore, they find this also seems to hold under some distribution shifts. Follow-up work (Baek et al., 2022) looks at distribution shifts in more detail.

We empirically validate on CIFAR-10/CINIC-10 and ImageNet/PACS that the proposed calibration metrics worsen as prediction disagreement increases, which is also in line with Ovadia et al. (2019). Because the gap between prediction disagreement and generalization error increases accordingly, we cannot compute the calibration once on in-distribution data and then use the proposed theory for samples drawn under distribution shift. The authors agree with this and have added clarifications to the camera-ready version of their paper.
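For intuition, here is a rough sketch of how such a rejection plot can be produced (our own illustrative code, not the paper’s implementation; `probs` are ensemble-averaged softmax outputs, `labels` the ground-truth labels, and the binning choices are ours):

```python
import numpy as np

def expected_calibration_error(probs, labels, num_bins=15):
    """Standard top-label ECE (one of several possible binning variants)."""
    confidences = probs.max(axis=-1)
    correct = (probs.argmax(axis=-1) == labels).astype(float)
    bins = np.linspace(0.0, 1.0, num_bins + 1)
    ece = 0.0
    for lo, hi in zip(bins[:-1], bins[1:]):
        mask = (confidences > lo) & (confidences <= hi)
        if mask.any():
            ece += mask.mean() * abs(correct[mask].mean() - confidences[mask].mean())
    return ece

def calibration_rejection_curve(probs, labels, num_points=20):
    """Calibration of the retained samples as high-disagreement samples are rejected."""
    disagreement = 1.0 - np.sum(probs ** 2, axis=-1)  # predicted error per sample
    order = np.argsort(disagreement)                  # lowest disagreement first
    curve = []
    for frac in np.linspace(0.05, 1.0, num_points):
        keep = order[: max(1, int(frac * len(order)))]
        curve.append((disagreement[keep].mean(),
                      expected_calibration_error(probs[keep], labels[keep])))
    return curve  # (mean predicted error, ECE of retained samples) pairs
```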

This takeaway likely also applies to other works: as a model gets more uncertain about its predictions, i.e., as its prediction disagreement increases, it likely becomes less reliable and less well-calibrated.

Another takeaway of our paper is that a probabilistic notation (modeling the parameters with a parameter distribution) is easier to work with than a hypothesis-space formulation: we look at the proposed calibration metrics and theoretical results, and we simplify them quite a bit.

Figure 1: Rejection Plot of Calibration Metrics for Increasing Disagreement Under Distribution Shift (CINIC-10). Different calibration metrics (ECE, CWCE, CACE) vary across CINIC-10 for an ensemble of 25 Wide-ResNet-28-10 models trained on CIFAR-10, depending on the rejection threshold of the predicted error (disagreement rate). Thus, calibration cannot be assumed constant under distribution shift. The test error increases almost linearly with the predicted error (disagreement rate), leading the ‘GDE gap’ |Test Error − Predicted Error| to be almost flat, providing evidence for the empirical observations in Nakkiran & Bansal (2020); Jiang et al. (2022). The mean predicted error (disagreement rate) is shown on the x-axis. Left shows results for an ensemble using TOP (following Jiang et al. (2022)), and right for a regular deep ensemble without TOP. The regular deep ensemble is better calibrated but has higher test error overall and lower test error for samples with small predicted error.
Figure 2: Rejection Plot of Calibration Metrics for Increasing Disagreement In-Distribution (ImageNet). Different calibration metrics (ECE, CWCE, CACE) vary on ImageNet for an ensemble of 5 models trained on ImageNet, depending on the rejection threshold of the predicted error (disagreement rate). Again, calibration cannot be assumed constant, even for in-distribution data. The mean predicted error (disagreement rate) is shown on the x-axis. Left shows results for an ensemble using TOP (following Jiang et al. (2022)), and right for a regular deep ensemble without TOP.

Generalization Disagreement Equality

Jiang et al. (2022) specifically introduce the Generalization Disagreement Equality (GDE): a model satisfies GDE when its predicted error (prediction disagreement) equals the actual generalization error.

The predicted error is the error we would expect if the model’s predictions were the true label distribution for a sample: e.g., if the model predicts 80% probability for class A and 20% for class B, then the predicted accuracy \(\mathbb{E}_{p(y|x)} \, [p(y|x)]\) is 80% · 80% + 20% · 20% = 68%, and the predicted error (prediction disagreement) is 100% − 68% = 32%.
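The same arithmetic as a quick check in code (using the NumPy convention from the earlier sketches):

```python
probs = np.array([0.8, 0.2])
predicted_accuracy = np.sum(probs ** 2)             # 0.8*0.8 + 0.2*0.2 = 0.68
prediction_disagreement = 1.0 - predicted_accuracy  # 0.32
```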

We show that GDE immediately follows from the proposed calibration metrics (class-wise and class-aggregated calibration error). We also look at prior art and find that class-wise and class-aggregated calibration error have been implemented previously by Nixon et al. (2019) as “adaptive calibration error” and “static calibration error.” Moreover, Kumar et al. (2019) and Kull et al. (2019) have also introduced class-wise calibration as “marginal calibration error” and “classwise calibration”.
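For concreteness, here is a hedged sketch of a binned, class-wise calibration error in the spirit of these metrics (papers differ in binning schemes and in whether they average or sum over classes; this is our own illustrative code, not any paper’s reference implementation):

```python
import numpy as np

def classwise_calibration_error(probs, labels, num_bins=15):
    """Binned class-wise calibration error (conventions vary between papers).

    For each class k, samples are binned by the predicted probability of class k,
    and the average predicted probability is compared with the empirical
    frequency of class k within each bin.
    """
    num_classes = probs.shape[1]
    bins = np.linspace(0.0, 1.0, num_bins + 1)
    per_class_errors = []
    for k in range(num_classes):
        p_k = probs[:, k]
        is_k = (labels == k).astype(float)
        err_k = 0.0
        for lo, hi in zip(bins[:-1], bins[1:]):
            mask = (p_k > lo) & (p_k <= hi)
            if mask.any():
                err_k += mask.mean() * abs(is_k[mask].mean() - p_k[mask].mean())
        per_class_errors.append(err_k)
    # Averaging over classes gives a class-wise error; summing instead yields a
    # class-aggregated-style quantity (see the papers for the exact definitions).
    return float(np.mean(per_class_errors))
```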

For additional context, the recently accepted NeurIPS 2022 paper by Gruber and Büttner (2022) provides a comprehensive comparison of different calibration metrics.

Bayesian Perspective

From a Bayesian perspective, we usually look at epistemic uncertainty to tell us how reliable a model might be on new data. We also briefly look at these connections in our paper and find that the proposed theory does not disentangle aleatoric and epistemic uncertainty. (Examining the consequences of this is an interesting question for future work.)

Epistemic uncertainty tells us when the model “knows” that there are multiple interpretations, and it is not sure which one is correct. In contrast, aleatoric uncertainty tells us when the model “knows” the data is noisy. That is, epistemic uncertainty equals model disagreement (also known as BALD score), and aleatoric uncertainty equals data disagreement. Finally, predictive uncertainty (entropy) is the sum of epistemic and aleatoric uncertainty.
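As a hedged sketch (assuming `member_probs` holds the per-member softmax outputs of a deep ensemble or of MC-dropout samples; the function name and shapes are ours), the decomposition can be computed as:

```python
import numpy as np

def uncertainty_decomposition(member_probs, eps=1e-12):
    """Split predictive entropy into epistemic (BALD) and aleatoric parts.

    member_probs: shape (num_members, num_samples, num_classes).
    Returns (predictive_entropy, epistemic, aleatoric), each of shape (num_samples,).
    """
    mean_probs = member_probs.mean(axis=0)
    # Total (predictive) uncertainty: entropy of the averaged prediction.
    predictive_entropy = -np.sum(mean_probs * np.log(mean_probs + eps), axis=-1)
    # Aleatoric part: average entropy of the individual members' predictions.
    aleatoric = -np.sum(member_probs * np.log(member_probs + eps), axis=-1).mean(axis=0)
    # Epistemic part (BALD / mutual information): total minus aleatoric.
    epistemic = predictive_entropy - aleatoric
    return predictive_entropy, epistemic, aleatoric
```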

In other words, epistemic uncertainty is the uncertainty that we can reduce by training with more data while aleatoric uncertainty is the uncertainty we cannot reduce by training with more data (e.g., data noise).

We cannot trust the model’s predictions for samples with high epistemic uncertainty. If we want to be conservative, we need to assume the model’s prediction will be wrong under high epistemic uncertainty. Likewise, for high aleatoric uncertainty (data noise), the model’s (one-hot) prediction will also likely be wrong. Using predictive uncertainty (prediction disagreement) as the sum of the two can give us a lower bound on the “worst-case generalization error” but not necessarily more. Even when the model is confident, it could still be wrong.

Full Paper

For more details, see our full paper: the OpenReview link is here, and the arXiv link is here.

Acknowledgements

I would like to thank Yarin Gal, as well as the members of OATML in general for their continued feedback.

More from OATML

For more blog posts by OATML in Oxford, check out our group’s blog https://oatml.cs.ox.ac.uk/blog.html.


  1. While the paper’s official title is “Assessing Generalization of SGD via Disagreement” on arXiv and OpenReview, the paper itself is aptly titled “Assessing Generalization via Disagreement” because the results do not depend on SGD itself.↩︎

Follow me on Twitter @blackhc