ICLR 2023 Oral | Batch Norm層暴露TTA短板,開放環境下解決方案來了

機器之心pro 發佈 2024-04-30T08:24:35.694712+00:00

基於這些分析,本文最終提出銳度敏感且可靠的測試時熵最小化方法,通過抑制某些具有較大梯度 / 噪聲測試樣本對模型更新的影響,實現了穩定、高效的模型在線測試時自適應。

機器之心專欄

機器之心編輯部

測試時自適應(Test-Time Adaptation, TTA)方法在測試階段指導模型進行快速無監督 / 自監督學習,是當前用於提升深度模型分布外泛化能力的一種強有效工具。然而在動態開放場景中,穩定性不足仍是現有 TTA 方法的一大短板,嚴重阻礙了其實際部署。為此,來自華南理工大學、騰訊 AI Lab 及新加坡國立大學的研究團隊,從統一的角度對現有 TTA 方法在動態場景下不穩定原因進行分析,指出依賴於 Batch 的歸一化層是導致不穩定的關鍵原因之一,另外測試數據流中某些具有噪聲 / 大規模梯度的樣本容易將模型優化至退化的平凡解。基於此進一步提出銳度敏感且可靠的測試時熵最小化方法 SAR,實現動態開放場景下穩定、高效的測試時模型在線遷移泛化。本工作已入選 ICLR 2023 Oral (Top-5% among accepted papers)。

論文標題:Towards Stable Test-time Adaptation in Dynamic Wild World

  • 論文地址:https://openreview.net/forum?id=g2YraF75Tj
  • 開原始碼:https://github.com/mr-eggplant/SAR

什麼是 Test-Time Adaptation?

傳統機器學習技術通常在預先收集好的大量訓練數據上進行學習,之後固定模型進行推理預測。這種範式在測試與訓練數據來自相同數據分布時,往往取得十分優異的表現。但在實際應用中,測試數據的分布很容易偏離原始訓練數據的分布(distribution shift),例如在採集測試數據的時候:1)天氣的變化使得圖像中包含有雨雪、霧的遮擋;2)由於拍攝不當使得圖像模糊,或傳感器退化導致圖像中包含噪聲;3)模型基於北方城市採集數據進行訓練,卻被部署到了南方城市。以上種種情況十分常見,但對於深度模型而言往往是很致命的,因為在這些場景下其性能可能會大幅下降,嚴重製約了其在現實世界中(尤其是類似於自動駕駛等高風險應用)的廣泛部署。

圖 1 Test-Time Adaptation 示意圖(參考 [5])及其與現有方法特點對比

不同於傳統機器學習範式,如圖 1 所示在測試樣本到來後,Test-Time Adaptation (TTA) 首先基於該數據利用自監督或無監督的方式對模型進行精細化微調,而後再使用更新後的模型做出最終預測。典型的自 / 無監督學習目標包括:旋轉預測、對比學習、熵最小化等等。這些方法均展現出了優異的分布外泛化(Out-of-Distribution Generalization)性能。相較於傳統的 Fine-Tuning 以及 Unsupervised Domain Adaptation 方法,Test-Time Adaptation 能夠做到在線遷移,效率更高也更加普適。另外完全測試時適應方法 [2] 其可以針對任意預訓練模型進行適應,無需原始訓練數據也無需干涉模型原始的訓練過程。以上優點極大增強了 TTA 方法的現實通用性,再加上其展現出來的優異性能,使得 TTA 成為遷移、泛化等相關領域極為熱點的研究方向。

為什麼要 Wild Test-Time Adaptation?

儘管現有 TTA 方法在分布外泛化方面已表現出了極大的潛力,但這種優異的性能往往是在一些特定的測試條件下所獲得的,例如測試數據流在一段時間內的樣本均來自於同一種分布偏移類型、測試樣本的真實類別分布是均勻且隨機的,以及每次需要有一個 mini-batch 的樣本後才可以進行適應。但事實上,以上這些潛在假設在現實開放世界中是很難被一直滿足的。在實際中,測試數據流可能以任意的組合方式到來,而理想情況下模型不應對測試數據流的到來形式做出任何假設。如圖 2 所示,測試數據流完全可能遇到:(a)樣本來自不同的分布偏移(即混合樣本偏移);(b)樣本 batch size 非常小(甚至為 1);(c)樣本在一段時間內的真實類別分布是不均衡的且會動態變化的。本文將上述場景下的 TTA 統稱為 Wild TTA。但不幸的是,現有 TTA 方法在這些 Wild 場景下經常會表現得十分脆弱、不穩定,遷移性能有限,甚至可能損壞原始模型的性能。因此,若想真正實現 TTA 方法在實際場景中的大範圍、深度化應用部署,則解決 Wild TTA 問題即是其中不可避免的重要一環。

圖 2 模型測試時自適應中的動態開放場景

解決思路與技術方案

本文從統一角度對 TTA 在眾多 Wild 場景下失敗原因進行分析,進而給出解決方案。

1. 為何 Wild TTA 會不穩定?

(1)Batch Normalization (BN) 是導致動態場景下 TTA 不穩定的關鍵原因之一:現有 TTA 方法通常是建立在 BN 統計量自適應基礎之上的,即使用測試數據來計算 BN 層中的均值及標準差。然而,在 3 種實際動態場景中,BN 層內的統計量估計準確性均會出現偏差,從而引發不穩定的 TTA:

場景(a):由於 BN 的統計量實際上代表了某一種測試數據分布,使用一組統計量參數同時估計多個分布不可避免會獲得有限的性能,參見圖 3;

場景(b):BN 的統計量依賴於 batch size 大小,在小 batch size 樣本上很難得到準確的 BN 的統計量估計,參見圖 4;

場景(c):非均衡標籤分布的樣本會導致 BN 層內統計量存在偏差,即統計量偏向某一特定類別(該 batch 中占比較大的類別),參見圖 5;

為進一步驗證上述分析,本文考慮 3 種廣泛應用的模型(搭載不同的 Batch\Layer\Group Norm),基於兩種代表性 TTA 方法(TTT [1] 和 Tent [2])進行分析驗證。最終得出結論為:batch 無關的 Norm 層(Group 和 Layer Norm)一定程度上規避了 Batch Norm 局限性,更適合在動態開放場景中執行 TTA,其穩定性也更高。因此,本文也將基於搭載 Group\Layer Norm 的模型進行方法設計。

圖 3 不同方法和模型(不同歸一化層)在混合分布偏移下性能表現

圖 4 不同方法和模型(不同歸一化層)在不同 batch size 下性能表現。圖中陰影區域表示該模型性能的標準差,ResNet50-BN 和 ResNet50-GN 的標準差過小導致在圖中不顯著(下圖同)

圖 5 不同方法和模型(不同歸一化層)在在線不平衡標籤分布偏移下性能表現,圖中橫軸 Imbalance Ratio 越大代表的標籤不平衡程度越嚴重

(2)在線熵最小化易將模型優化至退化的平凡解,即將任意樣本預測到同一個類:根據圖 6 (a) 和 (b) 顯示,在分布偏移程度嚴重(level 5)時,在線自適應過程中突然出現了模型退化崩潰現象,即所有樣本(真實類別不同)被預測到同一類;同時,模型梯度的 範數在模型崩潰前後快速增大而後降至幾乎為 0,見圖 6(c),側面說明可能是某些大尺度 / 噪聲梯度破壞了模型參數,進而導致模型崩潰。

圖 6 在線測試時熵最小化中的失敗案例分析

2. 銳度敏感且可靠的測試時熵最小化方法

為了緩解上述模型退化問題,本文提出了銳度敏感且可靠的測試時熵最小化方法 (Sharpness-aware and Reliable Entropy Minimization Method, SAR)。其從兩個方面緩解這一問題:1)可靠熵最小化從模型自適應更新中移除部分產生較大 / 噪聲梯度的樣本;2)模型銳度優化使得模型對剩餘樣本中所產生的某些噪聲梯度不敏感。具體細節闡述如下:

可靠熵最小化:基於 Entropy 建立梯度選擇的替代判斷指標,將高熵樣本(包含圖 6 (d) 中區域 1 和 2 的樣本)排除在模型自適應之外不參與模型更新:

其中 x 表示測試樣本,Θ 表示模型參數, 表示指示函數, 表示樣本預測結果的熵, 為超參數。僅當 時樣本才會參與反向傳播計算。

銳度敏感的熵優化:通過可靠樣本選擇機制過濾後的樣本中,無法避免仍含有圖 6 (d) 區域 4 中的樣本,這些樣本可能產生噪聲 / 較大梯度繼續干擾模型。為此,本文考慮將模型優化至一個 flat minimum,使其能夠對噪聲梯度帶來的模型更新不敏感,即不影響其原始模型性能,優化目標為:

上述目標的最終梯度更新形式如下:

其中 受啟發於 SAM [4] 通過一階泰勒展開近似求解得到,具體細節可參見本論文原文與代碼。

至此,本文的總體優化目標為:

此外,為了防止極端條件下上述方案仍可能失敗的情況,進一步引入了一個模型復原策略:通過移動監測模型是否出現退化崩潰,決定在必要時刻對模型更新參數進行原始值恢復。

實驗評估

在動態開放場景下的性能對比

SAR 基於上述三種動態開放場景,即 a)混合分布偏移、b)單樣本適應和 c)在線不平衡類別分布偏移,在 ImageNet-C 數據集上進行實驗驗證,結果如表 1, 2, 3 所示。SAR 在三種場景中均取得顯著效果,特別是在場景 b)和 c)中,SAR 以 VitBase 作為基礎模型,準確率超過當前 SOTA 方法 EATA 接近 10%。

表 1 SAR 與現有方法在 ImageNet-C 的 15 種損壞類型混合場景下性能對比,對應動態場景 (a);以及和現有方法的效率對比

表 2 SAR 與現有方法在 ImageNet-C 上單樣本適應場景中的性能對比,對應動態場景 (b)

表 3 SAR 與現有方法在 ImageNet-C 上在線非均衡類別分布偏移場景中性能對比,對應動態場景(c)

消融實驗

與梯度裁剪方法的對比:梯度裁剪避免大梯度影響模型更新(甚至導致坍塌)的一種簡單且直接的方法。此處與梯度裁剪的兩個變種(即:by value or by norm)進行對比。如下圖所示,梯度裁剪對於梯度裁剪閾值 δ 的選取很敏感,較小的 δ 與模型不更新的結果相當,較大的 δ 又難以避免模型坍塌。相反,SAR 不需要繁雜的超參數篩選過程且性能顯著優於梯度裁剪。

圖 7 與梯度裁剪方法的在 ImageNet-C(shot nosise, level 5) 上在線不平衡標籤分布偏移場景中的性能對比。其中準確率是基於所有之前的測試樣本在線計算得出

不同模塊對算法性能的影響:如下表所示,SAR 的不同模塊協同作用,有效提升了動態開放場景下測試時模型自適應穩定性。

表 4 SAR 在 ImageNet-C (level 5) 上在線不平衡標籤分布偏移場景下的消融實驗

Loss 表面的銳度可視化:通過在模型權重增加擾動對損失函數可視化的結果如下圖所示。其中,SAR 相較於 Tent 在最低損失等高線內的區域(深藍色區域)更大,表明 SAR 獲得的解更加平坦,對於噪聲 / 較大梯度更加魯棒,抗干擾能力更強。

圖 8 熵損失表面可視化

結語

本文致力於解決在動態開放場景中模型在線測試時自適應不穩定的難題。為此,本文首先從統一的角度對已有方法在實際動態場景失效的原因進行分析,並設計完備的實驗對其進行深度驗證。基於這些分析,本文最終提出銳度敏感且可靠的測試時熵最小化方法,通過抑制某些具有較大梯度 / 噪聲測試樣本對模型更新的影響,實現了穩定、高效的模型在線測試時自適應。

參考文獻

[1] Yu Sun, Xiaolong Wang, Zhuang Liu, John Miller, Alexei Efros, and Moritz Hardt. Test-time training with self-supervision for generalization under distribution shifts. In International Conference on Machine Learning, pp. 9229–9248, 2020.

[2] Dequan Wang, Evan Shelhamer, Shaoteng Liu, Bruno Olshausen, and Trevor Darrell. Tent: Fully test-time adaptation by entropy minimization. In International Conference on Learning Representations, 2021.

[3] Shuaicheng Niu, Jiaxiang Wu, Yifan Zhang, Yaofo Chen, Shijian Zheng, Peilin Zhao, and Mingkui Tan. Efficient test-time model adaptation without forgetting. In International Conference on Machine Learning, pp. 16888–16905, 2022.

[4] Pierre Foret, Ariel Kleiner, Hossein Mobahi, and Behnam Neyshabur. Sharpness-aware minimization for efficiently improving generalization. In International Conference on Learning Representations, 2021.

[5] Tong Wu, Feiran Jia, Xiangyu Qi, Jiachen T. Wang, Vikash Sehwag, Saeed Mahloujifar, and Prateek Mittal. Uncovering adversarial risks of test-time adaptation. arXiv preprint arXiv:2301.12576, 2023.

關鍵字: