NumPyのsumのdtypeが変わる話

June 28, 2020, 8:08 a.m. edited June 28, 2020, 8:15 a.m.

#NumPy  #Python 

>>> a = np.array([100, 100, 100, 100], dtype=np.int8)

としたときに

>>> a.sum()

とすると、一見オーバーフローしそう(\(2^{(8-1)}=128<400\))だが、実際には

>>> a.sum()
400

とほとんどの環境では問題なく計算される。なぜなら、numpy.sumにて、

The dtype of a is used by default unless a has an integer dtype of less precision than the default platform integer.

とあるように、デフォルトでは元の配列のdtypeが使用されるが、もしもそれがデフォルトのプラットフォームの整数の精度より小さかったらその限りではない。したがって、私の64 bit環境では、

>>> a.dtype
dtype('int8')
>>> a.sum().dtype
dtype('int64')

と確かにnp.int64に変換されている1

また、

In that case, if a is signed then the platform integer is used while if a is unsigned then an unsigned integer of the same precision as the platform integer is used.

とあるように、符号なしでは

>>> b = np.array([200, 200, 200, 200], dtype=np.uint8)
>>> b.dtype
dtype('uint8')
>>> b.sum().dtype
dtype('uint64')

という挙動になる。

一方で、わざわざ説明に整数と明記されているように、浮動小数で試すと、

>>> c = np.array([0, 0.1, 0.2, 0.3], dtype=np.float16)
>>> c.dtype
dtype('float16')
>>> c.sum().dtype
dtype('float16')

と変わらないようである。


  1. そうなると32 bit環境ではnp.int32になりそうだけど、試せる環境がないのでわからない。。。