Python深度学习:基于PyTorch
上QQ阅读APP看书,第一时间看更新

1.7 广播机制

Numpy的Universal functions中要求输入的数组shape是一致的,当数组的shape不相等时,则会使用广播机制。不过,调整数组使得shape一样,需要满足一定的规则,否则将出错。这些规则可归纳为以下4条。

1)让所有输入数组都向其中shape最长的数组看齐,不足的部分则通过在前面加1补齐,如:

a:2×3×2

b:3×2

则b向a看齐,在b的前面加1,变为:1×3×2

2)输出数组的shape是输入数组shape的各个轴上的最大值;

3)如果输入数组的某个轴和输出数组的对应轴的长度相同或者某个轴的长度为1时,这个数组能被用来计算,否则出错;

4)当输入数组的某个轴的长度为1时,沿着此轴运算时都用(或复制)此轴上的第一组值。

广播在整个Numpy中用于决定如何处理形状迥异的数组,涉及的算术运算包括(+,-,*,/…)。这些规则说得很严谨,但不直观,下面我们结合图形与代码来进一步说明。

目的:A+B,其中A为4×1矩阵,B为一维向量(3,)。

要相加,需要做如下处理:

·根据规则1,B需要向看齐,把B变为(1,3)

·根据规则2,输出的结果为各个轴上的最大值,即输出结果应该为(4,3)矩阵,那么A如何由(4,1)变为(4,3)矩阵?B又如何由(1,3)变为(4,3)矩阵?

·根据规则4,用此轴上的第一组值(要主要区分是哪个轴),进行复制(但在实际处理中不是真正复制,否则太耗内存,而是采用其他对象如ogrid对象,进行网格处理)即可,详细处理过程如图1-4所示。

图1-4 Numpy广播规则示意图

代码实现:

import numpy as np
A = np.arange(0, 40,10).reshape(4, 1)
B = np.arange(0, 3)
print("A矩阵的形状:{},B矩阵的形状:{}".format(A.shape,B.shape))
C=A+B
print("C矩阵的形状:{}".format(C.shape))
print(C)

运行结果:

A矩阵的形状:(4, 1),B矩阵的形状:(3,)
C矩阵的形状:(4, 3)
[[ 0  1  2]
 [10 11 12]
 [20 21 22]
 [30 31 32]]