Artificial Intelligence

[논문 리뷰] Communication-Efficient Learning of Deep Networks from Decentralized Data

은최 2025. 12. 17. 16:03

논문 링크: 

 

Communication-Efficient Learning of Deep Networks from Decentralized Data

Modern mobile devices have access to a wealth of data suitable for learning models, which in turn can greatly improve the user experience on the device. For example, language models can improve speech recognition and text entry, and image models can automa

arxiv.org

 

출처: H. B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y Arcas,“Communication-efficient learning of deep networks from decentralized data,” arXiv:1602.05629, 2016.


요약

 

   모바일 기기들은 더 많은 데이터에 접근할 수 있기에 더 나은 사용자 경험을 제공할 수 있지만 이러한 데이터들은 프라이버시가 중요하고 그 양이 많다. 이에 따라 이 논문에선 Federated Learning이란 대안을 제시하는데 이는 훈련 데이터를 모바일 기기들에 분산시키고 각 기기들의 업데이트를 종합해 하나의 공유된 모델을 사용하는 방법이다. 이 논문에선 이 훈련 방법을 현실적으로 다루고 실험하여 평가한다. 이러한 평가 결과, communication round는 기존 방법 대비 10-100배 감소했다.


서론

 

  Federated Learning 기법은 각 기기의 학습 파라미터만 공유하기에 중앙 저장이 필요 없어 저장공간을 아낄 수 있다. 또 파라미터만 공유되기에 데이터의 보안도 동시에 챙길 수 있다. 이 논문에선 Federated Averaging 알고리즘을 제시하는데 이는 각 client들의 stochastic gradient descent(SGD)를 합친 후 평균하는 것이다. unbalanced, non-IID 데이터셋으로 실험하고 평가한 결과 필요한 통신 라운드의 수가 크게 감소했다. 

 

Federated Learning은 다음과 같은 특징을 갖는다.

  • 데이터 센터의 공개 데이터보단 실제 환경의 데이터에서 훈련시키는 것이 더 좋은 성능을 보인다. 
  • 데이터는 개인정보에 민감하고 그 양이 많기 때문에 중앙 데이터 센터에 저장하지 않는 것이 좋다. 
  • Supervised 과제에선 데이터의 label은 사용자 반응을 통해 자연스럽게 추론될 수 있다.

Privacy

   데이터를 중앙에 모으는 것보단 분산 시키는 것이 보안에 훨씬 유리하다. 데이터의 일부를 익명화해도 그것들이 모이면 추측이 가능해져 보안이 위협될 수 있다. Federated learning에서 필요한 정보는 특정 모델을 위한 최소한의 업데이트뿐이다. 알고리즘은 보내진 값들을 평균만 하기 때문에 데이터의 소스 또한 필요 없으며 데이터의 익명성을 보장할 수 있다. 

 

Federated Optimization

optimization과 관련해선 다음과 같은 특징들이 있다.

  • Non-IID, 데이터 셋은 사용자의 사용도에 따라 다르기에 이들은 전체를 대표하지 못한다.
  • Unbalanced, 어떤 사용자들은 다른 사용자들보다 서비스를 더 많이 이용하여 더 많은 데이터를 만들 수 있다. 
  • Massively distributed, 사용자 수는 각 사용자가 보유한 예시보다 훨씬 클 것으로 예상한다. 
  • Limited communication, 모바일 기기들은 자주 오프라인 상태거나 연결이 느리다. 

   또한 실험 환경은 Round에 따른 동기화 업데이트 계획을 가정하며 K명의 client, 매 라운드마다 선택하는 client의 비율 C를 가정한다. 각 라운드마다 선택된 client들은 서버로 업데이트를 전송하고 서버는 이 업데이트들을 global state에 적용한다. 이는 매 라운드마다 반복된다. 알고리즘은 다음과 같이 제시된다. 

\[
\min_{w \in \mathbb{R}^d} f(w), 
\quad f(w) = \frac{1}{n} \sum_{i=1}^{n} f_i(w)
\]
   여기서 f는 손실함수를 나타내며 알고리즘의 목표는 손실을 최소화하는 파라미터 w를 찾는 것이다. 

각 클라이언트 \(k\)는 데이터 집합 \(P_k\)를 가지고 있으며, 그 크기는 \(n_k = |P_k|\)이다.
\[
f(w) = \sum_{k=1}^{K} \frac{n_k}{n} F_k(w), 
\quad F_k(w) = \frac{1}{n_k} \sum_{i \in P_k} f_i(w)
\]

   Federated optimization에선 통신 비용이 연산 비용보다 높다. 이에 따라 이 논문에선 연산 비용을 오히려 늘려 통신 비용, 라운드 수를 줄이려고 한다. 이런 방법은 두 가지가 있는데 increased parallelism은 각 라운드에 독립적으로 작동하는 클라이언트 수를 증가시킨다. Increased computation은 클라이언트들이 각 라운드에 더 복잡한 연산을 할 수 있도록 한다. 


알고리즘

 

   이 논문에선 stochastic gradient descent(SGD)를 이용해 Federated Averaging 알고리즘을 설계한다. Single batch를 사용하면 연산이 효율적이지만 많은 라운드 수가 필요하기에 큰 사이즈의 batch를 사용한다. 이를 위해 전체 클라이언트 중 선택되는 비율인 C를 사용한다. C = 1은 full batch를 의미하며 이 baseline 알고리즘을 FederatedSGD(FedSGD)라고 한다. 

learning rate η, 클라이언트의 수 k, gradient는  \(g_k\)이며 서버는 이 gradient들을 모아 다음과 같이 파라미터를 업데이트 한다:
\[
w_{t+1} \leftarrow w_t - \eta \sum_{k=1}^{K} \frac{n_k}{n} g_k
\]
   이는 식은 아래와 동등하게 나타낼 수 있다. 처음엔 클라이언트 k가 전역 모델 \(w_t\)를 받고 자신의 로컬 gradient로 한 번의 경사하강을 수행하여 새로운 로컬 모델을 만든다. 그다음엔 서버가 클라이언트 k의 로컬 모델들을 모아 크기 비율로 가중 평균하여 전역 모델을 업데이트한다:
\[
w_{t+1}^k \leftarrow w_t - \eta g_k, 
\quad w_{t+1} \leftarrow \sum_{k=1}^{K} \frac{n_k}{n} w_{t+1}^k
\]

   한 번만 로컬 데이터로 gradient descent를 수행하는 FedSGD와 다르게 Federated Averaging(FedAvg)에서는 경사하강을 로컬에서 여러 번 반복한다:

\[
w_k \leftarrow w_k - \eta F_k(w_k)
\]
   연산량은 클라이언트의 비율 C, 각 클라이언트가 로컬 데이터셋이 훈련을 반복하는 횟수 E, 로컬 미니 배치 사이즈 B에 따라 결정된다. FedAvg에서 B =\(\infty\) (전체 데이터 셋이 미니 배치로 취급), E=1인 경우가 FedSGD에 해당하며 로컬에서의 업데이트 횟수는 \(u_k = \frac{E \cdot n_k}{B}\)로 나타낼 수 있다. 

 

   추가적으로 여기서 모델을 어떻게 설정했는지도 다루어진다. 일반적인 non-convex 상황에서 모델들을 평균하면 결과가 안 좋지만 이 알고리즘에선 두 모델을 동일한 random initializaton으로 시작한 후 서로 다른 데이터로 학습시켰다. Random initialization 한 것과 비교했을 때 loss가 적었으며 convex 한 결과가 나타났다. 


실험 결과

   

   MNIST 2NN과 CNN모델에 대해 실험했으며 클라이언트 비율 C에 따라 목표 test-set accuracy를 달성하는데 필요한 라운드 수를 나타낸 것이다. C가 증가할수록 라운드 수가 보통 감소하는 것을 확인할 수 있다. B=10인 경우가 B =\(\infty\)에 비해 C의 증가에 따른 통신 라운드 수 감소가 크다. 

 

Increasing computation per client

 

   위 표는 MNIST CNN, Shakespeare LSTM 두 모델에 대해 실험했으며 FedAvg의 E와 B값을 변화시키며 (연산량 변화) IID, non-IID 데이터에 대해서 통신 라운드 수를 정리한 것이다. Averaging 한 것이 일반 FedSGD보다 항상 더  좋은 결과를 얻었다. 또, 셰익스피어 데이터셋은 non-IID, unbalanced 경우 학습이 더 잘 되는 것을 확인할 수 있다. 이는 일부 역할이 상대적으로 많은 데이터 셋을 가지고 있기에 나타난 결과로 생각된다. FedAvg는 test-set 정확도가 수렴한 뒤에 training loss를 줄이는데도 효과적이다.

 

Can we over-optimize on the client datasets?

 

   이미 충분히 학습된 상태라면 추가적으로 학습 라운드를 늘려도 성능 향상이 이루어지지 않는다. 위 실험 결과는 lochal epoch E에 따른 loss 변화를 나타낸 것이다. 이 실험 결과는 연산량을 줄이는 것(E 감소, B증가)이 더 나은 loss를 얻게 해 줄 수 있음을 보여준다. 

 

CIFAR experiments

 

   위 표는 CIFAR 데이터 셋에서 목표 Accuracy를 달성하는데 필요한 통신 라운드 수를 나타낸다. SGD와 비교했을 때 FedSGD와 FedAvg는 더 적은 수의 라운드로 목표 정확도에 도달했음을 확인할 수 있다. 

 

 

   FedAvg는 FedSGD에 비해 더 적은 통신 라운드 수로도 정확도가 빠르게 상승하는 것을 확인할 수 있다. 

 

Large-scale LSTM experiments

 

   마지막으로 이 논문에선 실제 환경과 가까운 문제해결에 접근하고자 sns상의 공개 게시물을 데이터 셋으로 이용했다. Learning rate를 변경하며 실험한 결과, FedAvg는 FedSGD보다 더 적은 라운드 수로 목표 test-accuarcy를 달성했다.


결론

 

   실험을 통해 FedAvg는 적은 수의 라운드로도 좋은 성능을 달성할 수 있다는 것을 확인했다. 이는 다양한 모델 구조들에서 검증되었다. Differential privacy, multi-party computation과 같은 발전 방향이 제시되며 이들은 FedAvg와 잘 어우러진다.