Artificial Intelligence

[논문 리뷰] Split learning for health: Distributed deep learning without sharing raw patient data

은최 2025. 12. 18. 16:40

 

   이 논문은 6G, AI native 통신에서 자주 사용되는 Split Learning에 관해 소개한다. 이는 Federated Learning과 자주 비교되며 이에 대해 공부하고 정리해 보았다.

 

   쉽게 요약하자면 각 의료기관들은 연산을 많이 하지 못하고 데이터도 부족하니 중간 단계(cut layer)까지만 학습한 후 중앙 기관에 학습결과를 공유한다. 중앙에선 이 결과들을 모아 하나의 학습 모델을 만들고 각 기기들에게 공유한다. 이러면 작은 병원에서도 많은 데이터를 집합한 모델을 쓸 수 있고 데이터를 여러 기관끼리 직접 공유하지 않아도 된다. 이런 SplitNN은 연산량과 대역폭 면에서도 다른 방법보다 유리한 결과가 나타났다. 

 

논문 링크: 

 

Split learning for health: Distributed deep learning without sharing raw patient data

Can health entities collaboratively train deep learning models without sharing sensitive raw data? This paper proposes several configurations of a distributed deep learning method called SplitNN to facilitate such collaborations. SplitNN does not share raw

arxiv.org

 

출처: P. Vepakomma, O. Gupta, T. Swedish, and R. Raskar, “Split learning for health: Distributed deep learning without sharing raw patient data,” arXiv:1812.00564, Dec. 2018.


요약

 

이 논문은 의료 정보를 주고받기 위해 데이터 전체를 공유하지 않아도 되는 방법을 소개한다. 이 방법은 SplitNN이라고 하며 이 논문에선 다음과 같이 가정한다. 

  • 각 기관은 서로 다른 종류의 의료 정보를 갖는다. (한 기관은 방사선 사진, 다른 기관은 유전자 데이터 등을 보유)
  • 중앙 의료 기관이 다른 여러 지역 의료 기관하 협력하여 과제를 수행하는 경우
  • label을 공유하지 않고 학습한다. (환자가 어떤 병에 걸렸는지는 민감한 개인정보)

이 논문에선 이러한 splitNN의 성능과 자원 효율성을 federaed learning, large batch synchronous stochastic gradient descent 방법과 비교한다. 


서론

 

   의료 업계에선 환자들의 데이터 공유에 관한 동의를 얻기 어렵고 기관들 사이의 신뢰가 부족해 서로 협력이 어렵다. 각 의료기관들은 서로 다른 종류의 필요한 환자 데이터를 보유하고 있기에 데이터를 직접적으로 공유하지 않는 분산 머신러닝 방법이 요구된다. 여기선 데이터뿐만 아니라 모델 구조와 파라미터도 추가적으로 공유하지 않는 것이 보안에 더욱 유리하다. 작은 의료 기관들은 데이터와 진단 인력이 부족할 수 있지만 SplitNN을 통해 전체 모델에 조금씩 기여할 수 있다. 이뿐만 아니라 SplitNN은 기존의 방법보다 정확도를 높이고 연산량과 대역폭을 줄일 수 있다. 

 

   기존의 federated learning과 large batch synchronous stochastic gradient descent(SGD)는 다음과 같은 상황들에서 연구되지 않았다. 즉, Split Learning은 다음과 같은 상황들에서 쓰일 수 있다. 

  • vertically partitioned data, 서로 다른 기관이 다른 데이터 유형을 보유한 경우
  • distributed deep learning without label sharing, 환자의 질병 여부 같은 라벨을 공유하지 않고 학습해야 하는 경우
  • distributed semi-supervised learning, 일부 데이터에만 라벨이 있는 상황에서 학습해야 하는 경우
  • distributed multi-task learning, 여러 기관들이 동시에 서로 다른 과제를 학습해야하는 경우

주요 구성 방식

 

(a) Simple vanilla configuration for split learning

   각 클라이언트들은 딥 네트워크들은 중간 단계인 cut layer까지 학습한다. 이 출력은 서버로 전송되며 raw data 없이 나머지 학습이 마무리된다. 여기까지가 forward propagation이다. Gradient들은 서버의 마지막 레이어로부터 cut layer까지 back propagation 된다. cut layer의 gradient는 다시 클라이언트로 보내지며 나머지 back propagation은 클라이언트 측에서 이루어진다. 이 과정은 학습이 완료될 때까지 반복된다. 

 

(b) U-shaped configurations for split learning without label sharing

   라벨이 환자의 질병에 관한 민감한 정보를 포함하는 경우 이 모델이 사용된다. 여기선 클라이언트가 라벨을 서버에 전달하지 않아도 학습이 가능하다. 서버의 마지막 레이어 출력값을 클라이언트로 전송해 U자형 구조를 형성한다. 대부분의 레이어는 서버가 갖고 있으며 클라이언트들은 서버로부터 받은 값을 통해 back propagation을 진행한다. 결과적으로 이 구조에서 라벨은 오직 클라이언트에만 남는다. 

 

(c) Vertically partitioned data for split learning

   이 방법은 의료 기관들이 환자에 대해 서로 다른 종류의 데이터를 갖고 있을 때 사용된다. 예를 들어 영상 의학 기관에선 이미지 데이터를 cut layer까지 학습시키고 병리학 센터에서도 테스트 결과를 cut layer까지 학습시킨다. 두 센터의 결과는 결합되어 질병 진단 서버로 보내지며 여기서 나머지 학습이 완료된다. 학습이 끝날 때까지 이렇게 backward/forward propagation이 반복된다. 


자원 효율성

 

   위 그림은 연산량 증가에 따라 평가 정확도가 얼마나 증가하는지를 나타낸 그래프이다. 각 그림은 CIFAR 10,100 데이터셋과 VGG, Resnet-50 모델이 사용되었다. 이 논문에서 소개한 SplitNN은 Federated Avg, Large Batch SGD보다 적은 연산량으로 더 높은 정확도를 달성했다. 

또한 (a)의 경우 클라이언트당 연산량이 다른 방법보다 크게 감소한 것을 확인할 수 있다.

(b)의 경우에서도 split NN이 클라이언트당 필요한 연산 대역폭이 대체적으로 감소했다. 그러나 클라이언트 수가 적은 경우 federated learning이 더 적은 대역폭을 사용했다. CNN의 앞부분은 파라미터가 적고 cut layer로 연산을 분할하기 때문에 클라이언트들은 적은 자원으로도 학습에 참여할 수 있기에 SplitNN에서 좋은 결과가 나타났다. 


결론 및 발전 방향

 

   이 논문에서는 의료업계에서 사용되던 단순한 분산 학습 방식 대신 splitNN이라는 새로운 방법을 제안했다. 이는 federated learning과 large batch synchronous SGD보다 더 좋은 결과를 보였다. splitNN은 plug and play로 필요에 따른 용도 변경이 가능하며 여기서 다룬 소규모 병원에서의 학습에서 벗어나 대규모 환경으로의 확장이 가능하다. 또한 splitNN에선 최신 딥러닝 구조를 도입할 수 있으며 신경망에서의 압축 기법을 통해 자원을 더 많이 활용할 수 있다. 이를 통해 엣지 디바이스에서도 끊김 없는 매끄러운 분산 학습을 수행할 수 있다. 


추가 구성 방식

 

(a) Extended vanilla split learning

   각 클라이언트들의 결과가 또 다른 클라이언트를 거쳐 서버에 전달된다. 

 

(b) Configurations for multi-task split learning

   각 클라이언트에서 학습이 진행되며 이 결과들은 여러 서버에 전송된다. 각 서버는 다른 지도 학습 과제를 학습한다.

 

(c) Tor like configuration for multi-hop split learning

   여러 클라이언트들은 cut 레이어까지 학습하고 그 결과를 다음 클라이언트에게 전달한다. 이는 마지막 클라이언트가 자신의 출력값을 서버에 전달하며 마무리된다. 이는 'Tor' 구조와 유사하다.