Understanding NumPy's axis through shape

Understanding NumPy's axis

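When reading the official NumPy documentation, you will notice that many functions have an optional parameter called axis in their parameter lists: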

Parameters

  • a array_like

    Array containing numbers whose mean is desired. If a is not an array, a conversion is attempted.

  • axis None or int or tuple of ints, optional

    Axis or axes along which the means are computed. The default is to compute the mean of the flattened array. New in version 1.7.0. If this is a tuple of ints, a mean is performed over multiple axes, instead of a single axis or all the axes as before.

  • dtype data-type, optional

    Type to use in computing the mean. For integer inputs, the default is float64; for floating point inputs, it is the same as the input dtype.

Excerpted from the official manual, numpy.mean: https://numpy.org/doc/stable/reference/generated/numpy.mean.html
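The excerpt says axis can also be a tuple of ints. Here is a minimal sketch of what that means, using a small 3-D array chosen only for illustration. With a tuple, every listed axis is averaged away at once. With the default axis=None, the whole array is flattened into one number.

```python
import numpy as np

arr = np.arange(24).reshape((2, 3, 4))   # values 0..23, shape (2, 3, 4)

# default axis=None: the mean of the flattened array, a single number
overall = np.mean(arr)

# a tuple of ints: axes 0 and 1 are both averaged away, only axis 2 (length 4) remains
per_last = np.mean(arr, axis=(0, 1))

print(overall)          # 11.5
print(per_last.shape)   # (4,)
print(per_last)         # [10. 11. 12. 13.]
```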

When I first ran into this parameter elsewhere, all I knew was that to process a 2-D array vertically (column by column) you set axis to 0. I had no idea why.

For example, to compute the sum of each column of the following 2-D array:

import numpy as np

arr = np.arange(12).reshape((3, 4))
print(arr)
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]

print(np.sum(arr, axis=0))  # set axis to 0 here
[12 15 18 21]
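For comparison, here is a short sketch that also sums along axis=1 and writes the axis=0 sum out by hand. Watch which number in the shape (3, 4) disappears in each case:

```python
import numpy as np

arr = np.arange(12).reshape((3, 4))   # shape (3, 4)

# axis=0: the 3 disappears, leaving one sum per column
col_sums = np.sum(arr, axis=0)
# the same sum written out: add the 3 slices along axis 0 (the rows)
col_sums_by_hand = arr[0] + arr[1] + arr[2]

# axis=1: the 4 disappears, leaving one sum per row
row_sums = np.sum(arr, axis=1)

print(col_sums, col_sums.shape)                     # [12 15 18 21] (4,)
print(np.array_equal(col_sums, col_sums_by_hand))   # True
print(row_sums, row_sums.shape)                     # [ 6 22 38] (3,)
```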

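Today I looked into it carefully and did a quick search online. I found the explanations out there very hard to follow (I'm looking at you, CSDN: posts written without thinking or rehashed from elsewhere, which only get more convoluted the more they explain).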

So I wrote this note to record my own understanding. I hope it helps readers who are still confused about this.

Observation

First, let's create a 4-D array:

import numpy as np
from icecream import ic

arr = np.arange(120).reshape((2, 3, 4, 5))
ic(arr.ndim)
ic(arr)

# Output:
ic| arr.ndim: 4
ic| arr: array(
[[[[  0,   1,   2,   3,   4],
   [  5,   6,   7,   8,   9],
   [ 10,  11,  12,  13,  14],
   [ 15,  16,  17,  18,  19]],

  [[ 20,  21,  22,  23,  24],
   [ 25,  26,  27,  28,  29],
   [ 30,  31,  32,  33,  34],
   [ 35,  36,  37,  38,  39]],

  [[ 40,  41,  42,  43,  44],
   [ 45,  46,  47,  48,  49],
   [ 50,  51,  52,  53,  54],
   [ 55,  56,  57,  58,  59]]],


 [[[ 60,  61,  62,  63,  64],
   [ 65,  66,  67,  68,  69],
   [ 70,  71,  72,  73,  74],
   [ 75,  76,  77,  78,  79]],

  [[ 80,  81,  82,  83,  84],
   [ 85,  86,  87,  88,  89],
   [ 90,  91,  92,  93,  94],
   [ 95,  96,  97,  98,  99]],

  [[100, 101, 102, 103, 104],
   [105, 106, 107, 108, 109],
   [110, 111, 112, 113, 114],
   [115, 116, 117, 118, 119]]]]
)
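Here is one way to locate the axes in this printout, as a sketch. Each index in arr[i, j, k, l] belongs to one axis, in order. This particular array was built with np.arange, so its values simply count up. As a result, moving one step along axis 0, 1, 2 or 3 adds 60, 20, 5 or 1. Those step sizes are a property of this example only, not of NumPy arrays in general.

```python
import numpy as np

arr = np.arange(120).reshape((2, 3, 4, 5))

# arr[i, j, k, l]: i runs along axis 0 (length 2), j along axis 1 (length 3),
# k along axis 2 (length 4), l along axis 3 (length 5)
i, j, k, l = 1, 2, 3, 4
print(arr[i, j, k, l])                # 119
# for this arange-based array only: one step adds 60, 20, 5, 1 per axis
print(60 * i + 20 * j + 5 * k + l)    # 119

# fixing all indices but one gives a run of values along the remaining axis
print(arr[0, 0, 0, :])   # along axis 3: [0 1 2 3 4]
print(arr[0, 0, :, 0])   # along axis 2: [ 0  5 10 15]
print(arr[0, :, 0, 0])   # along axis 1: [ 0 20 40]
print(arr[:, 0, 0, 0])   # along axis 0: [ 0 60]
```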

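Looking at it like this, a beginner will have a hard time seeing where the axes are. So let's apply numpy.sum() along each axis in turn and look at the shape of the result: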

ic(arr.shape)
ic(np.sum(arr, axis=0).shape)
ic(np.sum(arr, axis=1).shape)
ic(np.sum(arr, axis=2).shape)
ic(np.sum(arr, axis=3).shape)

# Output
ic| arr.shape: (2, 3, 4, 5)
ic| np.sum(arr, axis=0).shape: (3, 4, 5)
ic| np.sum(arr, axis=1).shape: (2, 4, 5)
ic| np.sum(arr, axis=2).shape: (2, 3, 5)
ic| np.sum(arr, axis=3).shape: (2, 3, 4)

Here we can see that:

  • with axis = 0, the 2 in the shape disappears
  • with axis = 1, the 3 in the shape disappears
  • with axis = 2, the 4 in the shape disappears
  • with axis = 3, the 5 in the shape disappears

Why do they disappear? Because sum adds up the values along that dimension, so in the output that dimension has been “collapsed”.
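You can check the “collapse” directly. Here is a minimal sketch for axis=1: summing along that axis is the same as adding up the slices arr[:, 0], arr[:, 1] and arr[:, 2]. The keepdims=True argument is an extra np.sum parameter not discussed above. It keeps the collapsed axis with length 1 instead of dropping it.

```python
import numpy as np

arr = np.arange(120).reshape((2, 3, 4, 5))

# axis 1 has length 3, so there are three slices along it, each of shape (2, 4, 5)
by_hand = arr[:, 0] + arr[:, 1] + arr[:, 2]
print(by_hand.shape)                                  # (2, 4, 5)
print(np.array_equal(by_hand, np.sum(arr, axis=1)))   # True

# keepdims=True keeps the collapsed axis with length 1 instead of dropping it
print(np.sum(arr, axis=1, keepdims=True).shape)       # (2, 1, 4, 5)
```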

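Once you see how the two correspond, it all becomes clear. The value of axis is simply the index of an axis within the shape. Passing axis to a function amounts to telling it which axis to work along.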
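You can write the rule as a one-line prediction: the result's shape is arr.shape with the entry at index axis removed. Here is a small sketch that checks the prediction for every axis of the 4-D array:

```python
import numpy as np

arr = np.arange(120).reshape((2, 3, 4, 5))

for axis in range(arr.ndim):
    # predicted shape: the original shape with the entry at index `axis` removed
    predicted = arr.shape[:axis] + arr.shape[axis + 1:]
    actual = np.sum(arr, axis=axis).shape
    # prints e.g. "0 (3, 4, 5) (3, 4, 5) True" for axis 0
    print(axis, predicted, actual, predicted == actual)
```

The same prediction holds for the other reductions listed at the end of this note, such as np.mean, np.max and np.min, because they also collapse the chosen axis.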

Quick exercise

Try to picture how each of the following results is obtained by “collapsing” the array of shape (2, 3, 4, 5):

ic| np.sum(arr, axis=0): array(
[[[ 60,  62,  64,  66,  68],
  [ 70,  72,  74,  76,  78],
  [ 80,  82,  84,  86,  88],
  [ 90,  92,  94,  96,  98]],

 [[100, 102, 104, 106, 108],
  [110, 112, 114, 116, 118],
  [120, 122, 124, 126, 128],
  [130, 132, 134, 136, 138]],

 [[140, 142, 144, 146, 148],
  [150, 152, 154, 156, 158],
  [160, 162, 164, 166, 168],
  [170, 172, 174, 176, 178]]]
)
ic| np.sum(arr, axis=0).shape: (3, 4, 5)

ic| np.sum(arr, axis=1): array(
[[[ 60,  63,  66,  69,  72],
  [ 75,  78,  81,  84,  87],
  [ 90,  93,  96,  99, 102],
  [105, 108, 111, 114, 117]],

 [[240, 243, 246, 249, 252],
  [255, 258, 261, 264, 267],
  [270, 273, 276, 279, 282],
  [285, 288, 291, 294, 297]]]
)
ic| np.sum(arr, axis=1).shape: (2, 4, 5)

ic| np.sum(arr, axis=2): array(
[[[ 30,  34,  38,  42,  46],
  [110, 114, 118, 122, 126],
  [190, 194, 198, 202, 206]],

 [[270, 274, 278, 282, 286],
  [350, 354, 358, 362, 366],
  [430, 434, 438, 442, 446]]]
)
ic| np.sum(arr, axis=2).shape: (2, 3, 5)

ic| np.sum(arr, axis=3): array(
[[[ 10,  35,  60,  85],
  [110, 135, 160, 185],
  [210, 235, 260, 285]],

 [[310, 335, 360, 385],
  [410, 435, 460, 485],
  [510, 535, 560, 585]]]
)
ic| np.sum(arr, axis=3).shape: (2, 3, 4)
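To check an answer to the exercise, pick any position in a result and rebuild it by hand. Put the index back together with ':' in place of the collapsed axis, then sum that slice. Here is a sketch for two of the results above:

```python
import numpy as np

arr = np.arange(120).reshape((2, 3, 4, 5))

# Result of axis=2 at position [1, 2, 0]:
# it came from the slice arr[1, 2, :, 0], with ':' where axis 2 was
print(arr[1, 2, :, 0])                # [100 105 110 115]
print(np.sum(arr[1, 2, :, 0]))        # 430
print(np.sum(arr, axis=2)[1, 2, 0])   # 430

# Result of axis=0 at position [2, 3, 4]: the slice arr[:, 2, 3, 4]
print(arr[:, 2, 3, 4])                # [ 59 119]
print(np.sum(arr, axis=0)[2, 3, 4])   # 178
```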

Addendum 1

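A brief aside: we have found that the value passed as axis is just an axis index. Let's stretch that idea a little further. Are negative indices allowed too?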

In the numpy.sort documentation https://numpy.org/doc/stable/reference/generated/numpy.sort.html the default value of axis is -1. This means it sorts along the innermost (last) axis by default:

Parameters

  • a array_like

    Array to be sorted.

  • axis int or None, optional

    Axis along which to sort. If None, the array is flattened before sorting. The default is -1, which sorts along the last axis.

  • kind {‘quicksort’, ‘mergesort’, ‘heapsort’, ‘stable’}, optional

    Sorting algorithm. The default is ‘quicksort’. Note that both ‘stable’ and ‘mergesort’ use timsort or radix sort under the covers and, in general, the actual implementation will vary with data type. The ‘mergesort’ option is retained for backwards compatibility. Changed in version 1.15.0: The ‘stable’ option was added.

  • order str or list of str, optional

    When a is an array with fields defined, this argument specifies which fields to compare first, second, etc. A single field can be specified as a string, and not all fields need be specified, but unspecified fields will still be used, in the order in which they come up in the dtype, to break ties.
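Here is a minimal sketch of negative axis values. It reuses the 4-D array from earlier, plus a small 2-D array b made up for the sorting part. For a 4-D array, -1 refers to the last axis and -4 to the first. As the excerpt says, axis=None flattens the array before sorting.

```python
import numpy as np

arr = np.arange(120).reshape((2, 3, 4, 5))

# Negative axis values count from the end of the shape, like negative list indices
print(np.sum(arr, axis=-1).shape)   # (2, 3, 4), same as axis=3
print(np.sum(arr, axis=-4).shape)   # (3, 4, 5), same as axis=0

b = np.array([[3, 8, 2],
              [9, 1, 5]])
print(np.sort(b).tolist())             # default axis=-1, each row: [[2, 3, 8], [1, 5, 9]]
print(np.sort(b, axis=0).tolist())     # each column: [[3, 1, 2], [9, 8, 5]]
print(np.sort(b, axis=None).tolist())  # flattened first: [1, 2, 3, 5, 8, 9]
```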

Addendum 2

Some commonly used functions that take axis, for you to practice on later:

numpy.max
numpy.mean
numpy.min
numpy.repeat
numpy.sort
numpy.delete
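Here is a sketch to start that practice, using a small 2-D array a chosen for illustration. Only max, min and mean collapse the axis. For repeat, delete and sort, axis still names a position in the shape, but that entry grows, shrinks or stays the same instead of disappearing.

```python
import numpy as np

a = np.arange(6).reshape((2, 3))   # [[0, 1, 2], [3, 4, 5]], shape (2, 3)

# Reductions: the chosen axis disappears from the shape
print(np.max(a, axis=0))    # [3 4 5]        shape (3,)
print(np.min(a, axis=1))    # [0 3]          shape (2,)
print(np.mean(a, axis=0))   # [1.5 2.5 3.5]

# Non-reductions: axis picks the direction, the dimension stays
print(np.repeat(a, 2, axis=0).shape)       # (4, 3): each row repeated twice
print(np.delete(a, 1, axis=1).tolist())    # [[0, 2], [3, 5]]: column 1 removed
print(np.sort(a, axis=0).shape)            # (2, 3): shape unchanged
```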

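This time I didn't re-typeset the post in an editor. I pasted it straight from Typora, which saved me a lot of trouble. Personally I think it's clearer than most of what I found online, though that may sound like blowing my own horn.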

I hope this article helps.