2.1 分布式通信基础理论
分布式通信基础理论¶
学习目标¶
- 理解分布式通信中最核心的All-Reduce原理.
All-Reduce原理详解¶
- 分布式深度学习的通信依赖于规则的集群通信, 例如all-reduce, reduce-scatter, all-gather等. 因此高度优化的集群通信, 以及根据特点和通信拓扑选择合适的集群通信算法至关重要.
all-reduce做了什么?¶
- 如下图所示, 总共4个设备, 每个设备上有一个矩阵(为简单起见, 每一行标识一个元素), all-reduce操作的目的是, 让每个设备上的矩阵中的每一个位置的数值都是所有设备上对应位置的数值之和!
![](img/2_1.png)
- all-reduce可以通过reduce-scatter和all-gather这两个原子操作来实现. 基于ring通信可以高效的实现reduce-scatter和all-gather, 如下图所示:
![](img/2_2.png)
-
通过上图的展示, 可以看出, reduce-scatter的结果是每个设备保存一部分reduce之后的结果.
-
定义一些符号, 假设有p个设备(图中p=4), 假设整个矩阵大小是V, 那么reduce-scatter后, 每个设备上有V/p大小的数据块. 假设卡和卡之间通信带宽是β, 而且是双工通信(duplex), 即每个设备出口和入口带宽可以同时达到β, 所有设备的入口带宽总和是p * β, 所有设备的出口带宽总和也是p * β.
reduce-scatter的实现¶
- 高效实现一个集群通信的关键是如何充分利用设备和设备之间的带宽, 基于环状(ring)通信实现的集群通信算法就是这一思想的体现.
- 我们以reduce-scatter为例来看看环状通信算法是怎么工作的, 一共有p个设备, 每个设备上数据都划分为p份, 环状reduce-scatter一共需要p-1步才能完成.
![](img/2_3.png)
-
第1步: 每个设备都负责某一块V/p的数据并向左边的设备发送这块数据, 例如在上图中
- 第1个设备负责第2片数据并向第0个设备发送(即第4个设备)
- 第2个设备负责第3片数据并向第1个设备发送
- 第3个设备负责第4片数据并向第2个设备发送
- 第4个设备负责第1片数据并向第3个设备发送
-
每个设备收到右边设备的数据后, 就把收到的数据累加到本地对应位置的数据上去(通过逐渐变深的颜色表示数据累加的次数更多). 注意, 在这样的安排下, 每个设备的入口带宽和出口带宽都被用上了, 而且不会发生争抢带宽的事情.
-
第2 步: 同样见上图中
- 第1个设备把累加后的第3片数据向第0个设备发送(即第4个设备)
- 第2个设备把累加后的第4片数据向第1个设备发送
- 第3个设备把累加后的第1片数据向第2个设备发送
- 第4个设备把累加后的第2片数据向第3个设备发送
-
每个设备收到右边设备发过来的数据后, 就把收到的数据累加到本地对应位置的数据上去(累加后颜色更深).
-
第3步: 同样见上图中
- 第1个设备把累加后的第4片数据向第0个设备发送(即第4个设备)
- 第2个设备把累加后的第1片数据向第1个设备发送
- 第3个设备把累加后的第2片数据向第2个设备发送
- 第4个设备把累加后的第3片数据向第3个设备发送
-
每个设备收到右边设备发送过来的数据后, 就把收到的数据累加到对应位置的数据上去(累加后颜色更深). 经过p-1步之后, 每个设备上都有了一片所有设备上对应位置数据reduce之后的数据.
- 整个过程中, 每个设备向外发送了(p-1)V/p大小的数据, 也收到了(p-1)V/p大小的数据, 因为每个设备的出口, 入口带宽是p, 所以整个过程需要的时间等于(p-1)V/(pβ), 如果p足够大, 完成时间近似等于V/β. 神奇的是, 这个时间和设备数p无关. 当然, 在所有设备间传递的数据量是(p-1)V, 和设备数p成正比.
- 结论: 基于环状通信(Ring)的集群通信算法执行时间几乎和设备数无关, 但总通信量和设备数成正比.
all-gather的实现¶
- reduce-scatter执行完毕后, 再通过all-gather操作就可以实现all-reduce, 其中all-gather也可以通过环状通信算法来实现, 具体流程见下图所示:
![](img/2_4.png)
- 通信时间和通信量与reduce-scatter完全一样: 整个过程中, 每个设备向外发送了(p-1)V/p大小的数据, 也收到了(p-1)V/p大小的数据, 因为每个设备的出口, 入口带宽是p, 所以整个过程需要的时间等于(p-1)V/(pβ), 如果p足够大, 完成时间近似等于V/β.
- 注意: 在reduce-scatter中, V都是完整矩阵的数据量, 即reduce-scatter输入矩阵的数据量和all-gather输出矩阵的数据量.
Ring算法分析¶
通信量和冗余显存之间的关系¶
-
以reduce-scatter的实现中的图进行分析: 每个设备上的输入矩阵大小是V, 但经过reduce-scatter之后每个设备上只需要V/p大小的显存, 也就是(p-1)V/p的空间是冗余的, 因为总共有p个设备, 所有整个集群中有(p-1)V的显存是可以节省下来的. 注意, 每个设备冗余的显存恰好和每个设备的通信量一致, 所有设备冗余的显存和所有设备的总通信量一致.
-
以all-gather的实现中的图进行分析: 每个设备上的输入矩阵大小是V/p, 但经过all-gather之后每个设备上需要的显存是V, 而且每个设备上的矩阵大小和数值都完全一样. 也就是经过all-gather之后, 在设备之间产生了冗余, 不同的设备存储了一些完全一样的数据. 注意, 和reduce-scatter一样, 每个设备冗余的显存恰好和每个设备的通信量一致, 所有设备冗余的显存和所有设备的总通信量一致.
-
当然, 冗余量和通信量之间的等价关系不是偶然的, 正是因为这些通信才造成了设备之间数据的冗余. 因此, 当保持V不变时, 增大设备数p(可以称p为集群通信的并行宽度)时, 所有设备之间的通信量是正比增长的, 而且在所有设备上造成的冗余显存是正比例增长的. 当然, 完成一个特定的集群通信所需要的时间基本和并行宽度p无关.
-
因此, 增加并行宽度p是一个双刃剑, 一方面它让每个设备处理的数据量更小, 即V/p, 从而让计算的时间更短. 但另一方面, 它会需要更多的通信带宽(p-1)V, 以及更多的显存空间(p-1)V.
环状算法的最优性¶
- 首先提一个问题: 正在学习的你能不能想出比环状算法更好的集群算法实现? 答案是: 理论上不可能有更好的算法了!!!
- 我们已经分析过了要完成reduce-scatter和all-gather每个设备至少要向外发送(以及同时接收)的数据量是(p-1)V/p, 无论使用什么算法, 这个数据量是不可能更少了. 在这个数据量下, 最少需要多少时间呢? 出口带宽是β, 因此一张卡向外发送数据至少需要的时间是(p-1)V/(pβ), 这也正是环状算法需要的时间.
- 我们这里的通信时间只考虑传输带宽, 没有考虑每次传输都包含的延迟(latency). 当数据量V比较大时, 延迟可以忽略, 上文的分析就是成立的. 当V特别小, 或设备数p特别大时, 带宽β就变得不重要了, 反而是延迟比较关键, 这时更好地实现就不是环状算法了, 而应该使用树状通信.
- 注意: 因此英伟达NCCL里既实现了ring all-reduce算法, 也实现了double-tree all-reduce算法!!!