基於一致性的半監督語義分割方法:刷新多項SOTA,還有更好泛化性

機器之心pro 發佈 2022-08-10T09:15:40.134153+00:00

機器之心專欄機器之心編輯部在本工作中,來自阿德萊德大學、烏魯姆大學的研究者針對當前一致性學習出現的三個問題做了針對性的處理, 使得經典的 teacher-student 架構 (A.K.A Mean-Teacher) 在半監督圖像切割任務上得到了顯著的提升。

機器之心專欄

機器之心編輯部

在本工作中,來自阿德萊德大學、烏魯姆大學的研究者針對當前一致性學習出現的三個問題做了針對性的處理, 使得經典的 teacher-student 架構 (A.K.A Mean-Teacher) 在半監督圖像切割任務上得到了顯著的提升。

該研究已被計算機視覺頂會 CVPR 2022 大會接收,論文標題為《Perturbed and Strict Mean Teachers for Semi-supervised Semantic Segmentation》:

  • 文章地址:https://arxiv.org/abs/2111.12903
  • 代碼地址:https://github.com/yyliu01/PS-MT

背景

語義分割是一項重要的像素級別分類任務。但是由於其非常依賴於數據的特性(data hungary), 模型的整體性能會因為數據集的大小而產生大幅度變化。同時, 相比於圖像級別的標註, 針對圖像切割的像素級標註會多花費十幾倍的時間。因此, 在近些年來半監督圖像切割得到了越來越多的關注。

半監督分割的任務依賴於一部分像素級標記圖像和無標籤圖像 (通常來說無標籤圖像個數大於等於有標籤個數),其中兩種類型的圖像都遵從相同的數據分布。該任務的挑戰之處在於如何從未標記的圖像中提取額外且有用的訓練信號,以使模型的訓練能夠加強自身的泛化能力。

在當前領域內有兩個比較火熱的研究方向, 分別是自監督訓練(self-training) 和 一致性學習 (consistency learning)。我們的項目主要基於後者來進行。

一致性學習的介紹

簡單來說, 一致性學習(consistency learning)過程可以分為 3 步來描述: 1). 用不做數據增強的 「簡單」 圖像來給像素區域打上偽標籤, 2). 用數據增強 (或擾動) 之後的 「複雜」 圖片進行 2 次預測, 和 3). 用偽標籤的結果來懲罰增強之後的結果。

可是, 為什麼要進行這 3 步呢? 先用簡單圖像打標籤, 複雜圖像學習的意義在哪?

從細節來說, 如上圖所示, 假設我們有一個像素的分類問題 (在此簡化為 2 分類, 左下的三角和右上的圓圈) 。我們假設中間虛線為真實分布, 藍色曲線為模型的判別邊界。

在這個例子中, 假設這個像素的標籤是圓圈, 並且由 1). 得到的偽標籤結果是正確的 (y_tilde=Circ.)。在 2). 中如果像素的增強或擾動可以讓預測成三角類, 那麼隨著 3)步驟的懲罰, 模型的判別邊界會 (順著紅色箭頭) 挪向真實分布。由此, 模型的泛化能力得到加強。

由此得出, 在 1). 中使用 「簡單」 的樣本更容易確保偽標籤的正確性, 在 2). 時使用增強後的 「複雜」 樣本來確保預測掉在邊界的另一端來增強泛化能力。可是在實踐中,

1). 沒有經受過增強的樣本也很可能被判斷錯 (hard samples), 導致模型在學習過程中打的偽標籤正確性下降。

2). 隨著訓練的進行, 一般的圖像增強將不能讓模型做出錯誤判斷。這時, 一致性學習的效率會大幅度下降。

3). 被廣泛實用的半監督 loss 例如 MSE, 在切割任務里不能給到足夠的力量來有效的推動判別邊界。而 Cross-entropy 很容易讓模型過擬合錯誤標籤, 造成認知偏差 (confirmation bias)。

針對這三個問題, 我們提出了:

1). 新的基於一致性的半監督語義分割 MT 模型。通過新引入的 teacher 模型提高未標記訓練圖像的分割精度。同時, 用置信加權 CE 損失 (Conf-CE) 代替 MT 的 MSE 損失,從而實現更強的收斂性和整體上更好的訓練準確性。

2). 一種結合輸入、特徵和網絡擾動結合的數據增強方式,以提高模型的泛化能力。

3). 一種新型的特徵擾動,稱為 T-VAT。它基於 Teacher 模型的預測結果生成具有挑戰性的對抗性噪聲進一步加強了 student 模型的學習效率.

方法介紹

1). Dual-Teacher Architecture

我們的方法基於 Mean-Teacher, 其中 student 的模型基於反向傳播做正常訓練。在每個 iteration 結束後, student 模型內的參數以 expotional moving average (EMA)的方式轉移給 teacher 模型。

在我們的方法中, 我們使用了兩個 Teacher 模型。在做偽標籤時, 我們用兩個 teacher 預測的結果做一個 ensemble 來進一步增強偽標籤的穩定性。我們在每一個 epoch 的訓練內只更新其中一個 teacher 模型的參數, 來增加兩個 teacher 之間的 diversity。

由於雙 teacher 模型並沒有參加到反向傳播的運算中, 在每個 iteration 內他們只會消耗很小的運算成本來更新參數。

2). Semi-supervised Loss

在訓練中, teacher 模型的輸出經過 softmax 後的置信度代表著它對對應偽標籤的信心。置信度越高, 說明這個偽標籤潛在的準確率可能會更高。在我們的模型中, 我們首先對同一張圖兩個 teacher 的預測取平均值。然後通過最後的 confidence 作為權重, 對 student 模型的輸出做一個基於 cross-entropy 懲罰。同時, 我們會捨棄掉置信度過低的像素標籤, 因為他們是噪音的可能性會更大。

3). Teacher-based Virtual Adversarial Training (T-VAT)

Virtual Adversarial Training (VAT) 是半監督學習中常用的添加擾動的方式, 它以部分反向傳播的方式來尋找能最大化預測和偽標籤距離的噪音。

在我們的模型中, dual-teacher 的預測比學生的更加準確, 並且 (由於 EMA 的更新方式使其) 更加穩定。我們使用 teacher 模型替代 student 來尋找擾動性最強的對抗性噪音, 進而讓 student 的預測出錯的可能性加大, 最後達到增強一致性學習效率的目的。

4). 訓練流程

i). supervised part: 我們用 strong-augmentation 後的圖片通過 cross-entropy 來訓練 student 模型。

ii). unsupervised part: 我們首先餵給 dual-teacher 模型們一個 weak-augmentation 的圖片, 並且用他們 ensemble 的結果生成標籤。之後我們用 strong-augmentation 後的圖片餵給 student 模型。在通過 encoder 之後, 我們用 dual-teachers 來通過 T-VAT 尋找具有最強擾動性的噪音並且注入到 (student encoded 之後的) 特徵圖里, 並讓其 decoder 來做最終預測。

iii). 我們通過 dual-teachers 的結果用 conf-ce 懲罰 student 的預測

iv). 基於 student 模型的內部參數, 以 EMA 的方式更新一個 teacher 模型。

實驗

1). Compare with SOTAs.

Pascal VOC12 Dataset:

訓練 log 可視化連結:

https://wandb.ai/pyedog1976/PS-MT(VOC12)?workspace=user-pyedog1976

該數據集包含超過 13,000 張圖像和 21 個類別。它提供了 1,464 張高質量標籤的圖像用於訓練,1,449 圖像用於驗證,1,456 圖像用於測試。我們 follow 以往的工作, 使了 10582 張低質量標籤來做擴展學習, 並且使用了和相同的 label id。

Low-quality Experiments

該實驗從整個數據集中隨機 sample 不同 ratio 的樣本來當作訓練集 (其中包含高質量和低質量兩種標籤), 旨在測試模型在有不同數量的標籤時所展示的泛化能力。

在此實驗中, 我們使用了 DeeplabV3 + 當作架構, 並且用 ResNet50 和 ResNet101 得到了所有 ratio 的 SOTA。

High-quality Experiments

該實驗從數據集提供的高質量標籤內隨機挑取不同 ratio 的標籤, 來測試模型在極少標籤下的泛化能力。我們的模型在不同的架構下 (e.g., Deeplabv3+ and PSPNet) 都取得了最好的結果。

Cityscapes Dataset

訓練 log 可視化連結:

https://wandb.ai/pyedog1976/PS-MT(City)?workspace=user-pyedog1976

Cityscapes 是城市駕駛場景數據集,其中包含 2,975 張訓練圖像、500 張驗證圖像和 1,525 張測試圖像。數據集中的每張圖像的解析度為 2,048 ×1,024,總共有 19 個類別。

在 2021 年之前, 大多數方法用 712x712 作為訓練的 resolution, 並且拿 Cross-entropy 當作 supervised 的 loss function。在最近, 越來越多的方式傾向於用大 resolution (800x800)當作輸入, OHEM 當作 supervised loss function。為了公平的對比之前的工作, 我們分別對兩種 setting 做了單獨的訓練並且都拿到了 SOTA 的結果。

2). Ablation Learnings.

我們使用 VOC 數據集中 1/8 的 ratio 來進行消融實驗。原本的 MT 我們依照之前的工作使用了 MSE 的 loss 方式。可以看到, conf-CE 帶來了接近 3 個點的巨大提升。在這之後, T-VAT (teacher-based virtual adversarial training)使 student 模型的一致性學習更有效率, 它對兩個架構帶來了接近 1% 的提升。最後, dual-teacher 的架構給兩個 backbone 分別帶來了 0.83% 和 0.84% 的提升。

同時我們對比了多種針對 feature 的擾動的方法, 依次分別為不使用 perturbation, 使用 uniform sample 的噪音, 使用原本的 VAT 和我們提出的 T-VAT。T-VAT 依然帶來了最好的結果。

3). Improvements over Supervised Baseline.

我們的方法相較於相同架構但只使用 label part 的數據集的結果有了巨大提升。以 Pascal VOC12 為例, 在 1/16 的比率中 (即 662 張標記圖像), 我們的方法分別 (在 ResNet50 和 ResNet101 中) 超過了基於全監督訓練的結果 6.01% 和 5.97%。在其他 ratio 上,我們的方法也顯示出一致的改進。

總結

在本文中,我們提出了一種新的基於一致性的半監督語義分割方法。在我們的貢獻中,我們引入了一個新的 MT 模型,它基於多個 teacher 和一個 student 模型,它顯示了對促進一致性學習的未標記圖像更準確的預測,使我們能夠使用比原始 MT 的 MSE 更嚴格的基於置信度的 CE 來增強一致性學習的效率。這種更準確的預測還使我們能夠使用網絡、特徵和輸入圖像擾動的具有挑戰性的組合,從而顯示出更好的泛化性。

此外,我們提出了一種新的對抗性特徵擾動 (T-VAT),進一步增強了我們模型的泛化性。

關鍵字: