清華開源遷移學習算法庫:基于PyTorch實現,支持輕松調用已有算法

五嘎子 4年前發布 | 776 次閱讀 Pytorch

  清華大學大數據研究中心機器學習研究部長期致力于遷移學習研究。近日,該課題部開源了一個基于 PyTorch 實現的高效簡潔遷移學習算法庫:Transfer-Learn。使用該庫,可以輕松開發新算法,或使用現有算法。

  項目地址:https://github.com/thuml/Transfer-Learning-Library

  目前,該項目發布了第一個子庫——領域自適應算法庫(DALIB),其支持的算法包括:

Domain Adversarial Neural Network (DANN)

Deep Adaptation Network (DAN)

Joint Adaptation Network (JAN)

Conditional Domain Adversarial Network (CDAN)

Maximum Classifier Discrepancy (MCD)

Margin Disparity Discrepancy (MDD)

  領域自適應背景介紹

  目前,深度學習模型在一部分計算機視覺、自然語言處理任務中超越了人類的表現,但是它們的成功通常依賴于大規模標記數據。在實際應用場景中,標記數據往往是稀缺的。

  解決標記數據稀缺問題的一個方法是通過計算機模擬生成訓練數據,例如使用計算機圖形學技術合成訓練數據(如下圖所示)。此外,還可以從相關的領域 “借用” 標記數據。

  但是,在此場景下,訓練數據和測試數據不再服從獨立同分布,使訓練得到的深度網絡準確率大打折扣。為了解決數據集偏移造成的泛化難題,領域自適應 (Domain Adaptation) 的概念被提出。

  領域自適應的目標是將機器學習模型在源領域 (Source) 學到的知識遷移到目標領域 (Target)。例如在計算機模擬生成訓練數據的例子中,合成數據是源領域,真實場景的數據是目標領域。領域自適應有效地緩解了深度學習對于人工標記數據的依賴,受到學術界和工業界的廣泛關注。目前已廣泛應用到圖像分類、圖像分割、目標檢測、情感分析、機器翻譯等眾多任務上。

  吳恩達曾說過:「在監督學習之后,遷移學習將引領下一輪機器學習技術商業化浪潮。」圖靈獎得主 Bengio 也認為遷移能力是深度學習進一步發展的基礎能力之一。隨著產品級機器學習應用進入數據稀缺領域,監督學習得到的尖端模型性能大打折扣,領域自適應變得越來越重要。

  研究現狀

  深度領域自適應方法主要包括以下三大類:

對抗訓練。領域對抗網絡 DANN 是最早的工作,它引入領域判別器,鼓勵特征提取器學習領域無關的特征。在 DANN 的基礎上衍生出了一系列方法,例如條件領域對抗網絡 CDAN、最大分類器差異 MCD。

理論啟發。通過嚴格的理論推導,得到可以顯式控制遷移學習泛化誤差的算法,如間隔分歧散度 MDD 等。


  DANN 網絡架構圖。

  MDD 網絡架構圖。

  上述方法在實驗數據上表現出良好的性能。然而目前學術界領域自適應方法的開源實現存在下述問題:

穩定性差。部分對抗訓練方法隨著訓練進行,準確率會大幅度下降。

  針對這些不足,深度領域自適應算法庫(DALIB)設計的初衷就是:用戶通過少數幾行代碼,即可將領域自適應算法應用到實際項目中,無需考慮領域自適應模塊的實現細節

  易用性

  DALIB 將現有領域自適應訓練代碼中的領域自適應損失函數分離出來,按照 PyTorch 交叉熵損失函數的形式進行封裝,以方便用戶使用。

  領域自適應損失函數也和模型架構進行了解耦,不依賴于具體的分類任務,所以算法庫很容易擴展到圖像分類以外的機器學習任務。

  如下所示,使用兩行代碼即可定義一個與任務無關的領域對抗損失函數:

  各種領域自適應損失函數中有一些公用的模塊,例如所有算法中都用到的分類器模塊、對抗訓練中用到的梯度反轉模塊和領域判別器模塊、統計距離中用到的核函數模塊等。

  這些公用模塊和提供的領域自適應損失函數是分離的。因此,在 DALIB 中,用戶可以像搭積木一樣,重新定制自己需要的領域自適應損失函數。例如,在核方法中,用戶可以自定義不同參數的高斯核函數或其他核函數,然后將其傳入到多核最大均值差異(MK-MMD)的計算中。

  目前,所有的模塊和損失函數均已提供詳細的 API 說明文檔:

  https://dalib.readthedocs.io/en/latest/

  穩定性

  領域自適應算法研究往往關注方法的創新性或理論價值,而忽視了工程實現中的穩定性和可復現性。在復現現有算法的過程中,出現了部分算法準確率不穩定的問題。DALIB 通過對數值計算方面的改進,解決了這些問題。(具體實現此處不再展開。)

  DALIB 在常見的領域自適應基準集上的測試準確率都比原論文匯報準確率高,在部分數據集上的準確率甚至高出 14%。下圖分別是 Office-31 和 VisDA-2017 三個基準集上的測試結果:

  Office-31 上不同算法的準確率。

  VisDA-2017 上不同算法的準確率。

  DALIB 算法庫提供了所支持的算法在 Office-31、Office-Home 和 VisDA-2017 三個基準集上的測試結果,以及完整的測試腳本。清華大學龍明盛老師課題組認為開源這一算法庫有助于更好地推進遷移學習方向的未來研究工作。

  未來的工作

  領域自適應算法子庫 DALIB 的下一個版本將支持領域自適應算法的各種復雜設定,包括部分集領域自適應任務(Partial Domain Adaptation)、開放集領域自適應任務(Open-Set Domain Adaptation)、通用域自適應任務(Universal Domain Adaptation)等。同時,還將支持多功能領域自適應算法(Versatile Domain Adaptation)。

  遷移學習算法庫 Transfer-Learn 目前還處于初期開發階段。該研究團隊表示,隨著遷移學習方向的不斷發展,今后 Transfer-Learn 算法庫將不斷跟進新工作中比較好的算法,不斷擴展優化,為遷移學習提供一個穩定可靠的評測基準。

  當前版本由龍明盛老師課題組的江俊廣、付博兩名同學維護。清華大學軟件學院、大數據系統軟件國家工程實驗室為研發該算法庫提供了強大的平臺支撐

 本文由用戶 五嘎子 自行上傳分享,僅供網友學習交流。所有權歸原作者,若您的權利被侵害,請聯系管理員。
 轉載本站原創文章,請注明出處,并保留原始鏈接、圖片水印。
 本站是一個以用戶分享為主的開源技術平臺,歡迎各類分享!
  轉載自:https://mp.weixin.qq.com/s/8-ZZQePbuzrS2VdmBB2mnw