機(jī)器之心分析師網(wǎng)絡(luò)
作者:周宇
編輯:H4O
本文重點(diǎn)探討分布式學(xué)習(xí)框架中針對(duì)隨機(jī)梯度下降(SGD)算法的拜占庭問(wèn)題。
分布式學(xué)習(xí)(Distributed Learning)是一種廣泛應(yīng)用的大規(guī)模模型訓(xùn)練框架。在分布式學(xué)習(xí)框架中,服務(wù)器通過(guò)聚合在分布式設(shè)備中訓(xùn)練的本地模型(local model)來(lái)利用各個(gè)設(shè)備的計(jì)算能力。分布式機(jī)器學(xué)習(xí)的典型架構(gòu)——參數(shù)服務(wù)器架構(gòu)中,包括一個(gè)服務(wù)器(稱(chēng)為參數(shù)服務(wù)器 - Parameter Server,PS)和多個(gè)計(jì)算節(jié)點(diǎn)(workers,也稱(chēng)為節(jié)點(diǎn) nodes)[1]。其中,隨機(jī)梯度下降(Stochastic Gradient Descent,SGD)是一種廣泛使用的、效果較好的分布式優(yōu)化算法。在每一輪中,每個(gè)計(jì)算節(jié)點(diǎn)根據(jù)不同的本地?cái)?shù)據(jù)集在它的設(shè)備上訓(xùn)練一個(gè)本地模型,并與服務(wù)器共享最終的參數(shù)。然后,服務(wù)器聚合不同計(jì)算節(jié)點(diǎn)的參數(shù),并通過(guò)與計(jì)算節(jié)點(diǎn)共享得到的組合參數(shù)來(lái)啟動(dòng)下一輪訓(xùn)練。關(guān)于基于 SGD 優(yōu)化的分布式框架的網(wǎng)絡(luò)結(jié)構(gòu)(包括:層數(shù)、類(lèi)型、大小等)在訓(xùn)練開(kāi)始之前由所有計(jì)算節(jié)點(diǎn)共同商定確認(rèn)。
近年來(lái),分布式學(xué)習(xí)的安全性越來(lái)越受到人們的關(guān)注,其中,最重要的就是拜占庭威脅模型。在拜占庭威脅模型中,計(jì)算節(jié)點(diǎn)可以任意和惡意地行事。機(jī)器之心在前期的文章中也探討過(guò)分布式學(xué)習(xí)中的拜占庭問(wèn)題,主要針對(duì)聯(lián)邦學(xué)習(xí)中的拜占庭問(wèn)題。在這篇文章中,我們重點(diǎn)探討的是分布式學(xué)習(xí)框架中針對(duì)隨機(jī)梯度下降(SGD)算法的拜占庭問(wèn)題。如圖 1 所示,在 SGD 學(xué)習(xí)框架中,一些惡意節(jié)點(diǎn)(Malicious worker)向服務(wù)器發(fā)送拜占庭梯度(Byzantine Gradient),而不是計(jì)算得到的真實(shí)梯度,而拜占庭梯度可以是任意值。惡意節(jié)點(diǎn)可以控制計(jì)算節(jié)點(diǎn)設(shè)備本身,也可以控制節(jié)點(diǎn)和服務(wù)器之間的通信。以 Algorithm 1 中提出的同步 SGD(sync-SGD)協(xié)議為例 [4]。攻擊者(惡意節(jié)點(diǎn))在使其效果最大化的時(shí)間內(nèi)(即在 Algorithm 1 的第 6 行和第 7 行之間)干擾進(jìn)程。在此期間,攻擊者可以將節(jié)點(diǎn) i 中的參數(shù)(p_i)^(t+1) 替換為任意值,然后將此任意值發(fā)送到服務(wù)器中。攻擊方法在設(shè)置參數(shù)值的方式上有所不同,而防御方法則試圖識(shí)別損壞的參數(shù)并丟棄它們。Algorithm 1 使用平均值(第 8 行中的 AggregationRule( ))聚合計(jì)算節(jié)點(diǎn)參數(shù)。
圖 1. SGD 學(xué)習(xí)框架工作流程 [3]
本文所討論的分布式學(xué)習(xí)的核心是這樣一個(gè)假設(shè):經(jīng)過(guò)訓(xùn)練的網(wǎng)絡(luò)參數(shù)是獨(dú)立同分布的(Independent and identically distributed,i.i.d.)
關(guān)鍵詞: 分布式 機(jī)器 學(xué)習(xí) 拜占庭