Python机器学习库笔记(2)——matplotlib.pyplot

matplotlib是Python数据科学中常用的绘图工具库,我们常用到的是其中一个叫pyplot的子库,它能让我们作出像matlab中的图。掌握matplotlib.pyplot是机器学习项目中数据可视化的关键。

基本用法

1
2
3
4
5
6
7
8
9
10
11
12
13
import numpy as np
import matplotlib.pyplot as plt

# 在IPython、Jupyter Notebook有用,可以省略plt.show()
%matplotlib inline

# 测试用例
x = np.linspace(-5, 5, 100)
y = 2*x + 1
# 最简单的作图
plt.plot(x, y)
# 在Pycharm等IDE中显示图形
plt.show()

设置中文

当不改变参数时,matplotlib作的图上的中文会变成方框,因此需要手动修改字体。

1
2
3
import matplotlib as mpl
mpl.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False # 解决负号变方框的问题

图形实例

通常先创建一个figure()实例,之后可以操作这个图形的名称、坐标轴、图例等。

figure()可以接收多个参数,以下列出几个常用的参数:

  • num:指定画布的编号,如果使用pycharm等IDE,它也是画布窗口的名称。
  • figsize:指定画布的大小。
  • dpi:指定图形的分辨率。
  • facecolor:指定画布的背景色。
1
plt.figure(num=0, figsize=(5, 5), dpi=72, facecolor='k')

当调用figure()时,除了创建画布对象外,还会获得当前的绘图区域,之后进行的操作默认在当前绘图区域进行。如果figure()不传入任何参数,实际上是在subplot(111)上作图。

折线图

plt.plot()前两个参数接收X轴和Y轴的变量,此外还可以接受多个参数以改变图形主体的外观:

  • color(可简写为c):指定曲线的颜色,可以用常用颜色的英文(或其首字母),也可以制定RGB值(如#00FFFF)。
  • marker:指定数据点的类型,如o代表圆点,s代表方块,x代表叉号,等等。
  • linewidth:指定曲线的宽度。
  • linestyle:指定曲线的类型,如-代表实线,--代表虚线,-.代表点划线,等等。
  • alpha:指定曲线的透明度。
1
2
3
4
x = np.sort(np.random.rand(100))
y = np.sort(np.random.rand(100))

plt.plot(x, y, c='r', linewidth=5)

用于设定外观的的参数可以用简写,例如下面这个"bx--"参数相当于color='b', marker='x', linestyle='--'

1
2
# plt.plot(x, y, color='b', marker='x', linestyle='--')
plt.plot(x, y, 'bx--')

散点图

plt.scatter()来绘制散点图,除了上述在折线图可用的参数外,散点图还有其他一些可选参数:

  • s:点的大小,可以设置为与数据数组大小相同的数组来指定每个点的大小。(同样,颜色也可以这样指定)
1
2
3
4
5
x = np.random.randn(50)
y = np.random.randn(50)
area = np.random.randint(10, 100, 50)

plt.scatter(x, y, s=area)

柱状图

plt.bar()用于绘制柱状图,它还可以接受如下的参数:

  • height:用数组指定每个柱子的高度。
  • width:指定柱子的宽度。
  • align:设置对齐方式,默认是center
1
2
3
4
5
n = 5
x = np.arange(n)
y = np.random.randint(1, 10, n)

plt.bar(x, y, width=.5)

在一个坐标系中显示多组数据,可以用如下的方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
n = 5
x = np.arange(n)
y1 = np.random.randint(1, 10, n)
y2 = np.random.randint(1, 10, n)
y3 = np.random.randint(1, 10, n)

# 计算每个柱子的宽度
total_width = 1 #总长度包含每条柱子长度和一条空白长度
width = total_width / 4 # y种类数量+1

plt.bar(x - width, y1, width=width, label='y1')
plt.bar(x, y2, width=width, label='y2')
plt.bar(x + width, y3, width=width, label='y3')
plt.legend(loc='best')

箱线图

plt.boxplot()可以用于绘制箱线图,它默认的构造方法已经足够使用。以下列出一些其他可能用到的参数:

  • vert:布尔型,指定图的方向,True为纵向,False为横向。
  • patch_artist:布尔型,指定四分位框内是否填充,True为填充。
1
2
3
x = np.random.randn(100).reshape(20, 5)

plt.boxplot(x)

饼图

plt.pie()用于绘制饼图,它可接收如下参数:

  • explode:当要把某一部分凸出时,将对应的值设为0.1就够。
  • labels:对应每个部分的标签。
  • autopct:显示百分比的格式。
  • shadow:是否显示阴影。
  • startangle:起始角度,从X轴起逆时针旋转的度数。
1
2
3
4
5
6
x = [10, 30, 20, 40]
explode = [0, 0.1, 0, 0]
labels = ['A', 'B', 'C', 'D']

plt.pie(x, explode=explode, labels=labels, autopct='%1.1f%%', shadow=False, startangle=90)
plt.axis('equal') # 加上这句,饼图呈圆形,否则为椭圆

热图

plt.imshow()可以用于生成热图,分析变量间相关性时常用此图。imshow()主要接收如下参数:

  • x:即数据,可以是二维浮点型(灰度)、三维浮点型或unit8(RGB)或者四维浮点型或unit8(RGBA)数组。
  • cmap:用于指定颜色图谱,默认为RGB(A)色彩空间,其余用的比较多的有grayjet等。
  • interpolation:插值方法,默认为nearest(不同版本可能有差异),用其他方法可以平滑色块边缘。

通常还会定义一个plt.colorbar()用于标识颜色。

  • shrink:设置Bar的长度(比例)。
1
2
3
4
x = np.random.rand(10, 10)

plt.imshow(x, cmap=plt.cm.summer, interpolation="bilinear")
plt.colorbar(shrink=.5)

显示图像

plt.imshow()也能够显示栅格图像,一个很常见的例子是用于显示MNIST的手写数字。

1
2
3
4
5
a = np.random.rand(1024).reshape(32, 32)

plt.imshow(a, cmap=plt.cm.binary, interpolation="nearest")
plt.title('Bedrock')
plt.axis("off") # 关闭坐标轴

图形的其他元素

坐标轴

坐标轴可以手动设置范围、刻度等。以X轴为例:

  • plt.xlim:设置坐标轴的范围。
  • plt.xlabel:设置坐标轴的名称。
  • plt.xticks:设置坐标轴的刻度,可以接收一个数组,也可以接受一对数组及其对位的刻度。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
x = np.linspace(-5, 5, 100)
y = 2*x + 1

plt.figure()
plt.plot(x, y)
# 坐标轴范围
plt.xlim((-2, 1))
plt.ylim((-3, 3))
# 坐标轴名称
plt.xlabel("X axis")
plt.ylabel("Y axis")
# 坐标轴刻度
plt.xticks(np.arange(-2., 1.5, 0.5))
plt.yticks([-2.5, 0., 2.5], ["low", "mid", "high"])

利用plt.gca()(Get Current Axis)可以对当前绘图区域的坐标轴进行更多操作,例如将两条坐标轴置于中间变成十字形。

  • ax.set_title:设置图表标题。
  • ax.set_xlabel:设置坐标轴的名称。
  • ax.set_xticks:设置坐标轴的刻度。
  • ax.set_xticklabels:设置上面所设置刻度对应的标签。
  • ax.spines['left']:一幅图共有四个坐标轴,分别用leftrighttopbottom表示。
  • ax.spines['left'].set_color:设置该轴颜色,将颜色设置为none可以隐藏该轴。
  • ax.spines['left'].set_position:设置坐标轴的位置,用'data'将坐标轴移动到第二个参数指定的数据处。如设置到坐标原点即ax.spines['left'].set_position(('data', 0))
  • ax.xaxis.set_ticks_position:设置刻度相对于坐标轴的位置。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
x = np.linspace(-5, 5, 100)
y = np.sin(x)

plt.figure()
ax = plt.gca()
# 设置标题
ax.set_title('$sin(x)$', color='k', fontsize=20)
# 隐藏上、右边框
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
# 将左、下轴移至中间
ax.spines['bottom'].set_position(('data', 0))
ax.spines['left'].set_position(('data', 0))
ax.plot(x, y)

标题、图例、网格

一幅完整的图还包括标题、图例。

标题可以通过plt.title()方法设置,同时可以传入参数设置标题的颜色、字体等。

图例可以通过plt.legend()方法显示,前提是创建图形时传入了label参数。还可以接收loc参数来指定图例的位置,一般指定为best即可,图例会被自动放在最合适的位置。

如果要显示网格,可以通过plt.grid()方法显示。通过传入axis参数可以指定显示'x''y'或者'both'方向上的网格线。网格线也可以使用color等参数。

1
2
3
4
5
6
7
8
9
10
11
12
x = np.linspace(-5, 5, 100)
y1 = 2*x + 1
y2 = x**2 - 3*x + 1

plt.plot(x, y1, "b-", label='line1')
plt.plot(x, y2, "r--", label='line2')
# 设置标题
plt.title('Title', fontsize=20)
# 显示图例
plt.legend(loc='best')
# 显示网格
plt.grid(axis='y', color='grey', linestyle=':')

标注

当需要特别标注出图上一个点时,先需要标出该点,过该点向坐标轴引一条垂线,再用箭头将其标注出来。matplotlib提供了annotate()的方法来让我们实现这一功能。

annotate()第一个参数接收一组字符串作为显示的标注文字,其余部分参数如下:

  • xy:被标注的点坐标。
  • xytext:标注文字的坐标。
  • arrowprops:字典类型的参数,定义箭头的属性。
  • arrowstylearrowprops的一个键,用于设置箭头的形状,如'->'
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
x = np.linspace(-5, 5, 100)
y = x**2 + 1
plt.plot(x, y)
plt.xlim(-5, 5)
plt.ylim(0, 25)

# 以(-3,10)这个点为例
x0 = -3
y0 = x0**2 + 1
# 标注该点
plt.scatter(x0, y0, s=50, color='k')
# 向X轴引垂线
plt.plot([x0, x0], [y0, 0], 'k--')
# 向Y轴引垂线
plt.plot([-5, x0], [y0, y0], 'k--')
# 用箭头引出该点
plt.annotate("Here", xy=(-3, 10), xytext=(-2, 15), arrowprops=dict(arrowstyle='->'))

绘制子图

subplot()用法

pyplot提供了subplot这一方法来在一幅画布中绘制多个子图。plt.subplot(m,n,k)表示将画布分成m行n列,此图位于第k个位置(先从左往右、再从上往下数)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
x = np.linspace(-5, 5, 100)
y1 = 1/(1+np.exp(-x))
y2 = np.tanh(x)
y3 = np.maximum(0, x)
y4 = np.maximum(0.1*x, x)

plt.figure()

ax1 = plt.subplot(2,2,1)
ax1.plot(x, y1, 'r-')

ax2 = plt.subplot(2,2,2)
ax2.plot(x, y2, 'g:')

ax3 = plt.subplot(2,2,3)
ax3.plot(x, y3, 'b*')

ax4 = plt.subplot(2,2,4)
ax4.plot(x, y4, 'y-.')

plt.subplot(m,n,k)的参数中的逗号可以省略。并且每幅子图并不一定要大小一致,例如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
plt.figure()

plt.subplot(211)
plt.plot(x, y1)
plt.title('Sigmoid') # 添加子图的标题

plt.subplot(234)
plt.plot(x, y2)
plt.title('tanh')

plt.subplot(235)
plt.plot(x, y3)
plt.title('ReLu')

plt.subplot(236)
plt.plot(x, y4)
plt.title('PReLu')

plt.suptitle('Title') # 添加总图的标题

subplots()用法

plt.subplots()是另一种创建子图的方法。返回的类型是元组,第一个是画布对象,第二个是子图的集合。

传入sharexsharey参数可以决定子图的X轴、Y轴的范围是否一致。

1
2
3
4
5
fig, ax = plt.subplots(2, 2, sharex=True, sharey=False)
ax[0][0].plot(x, y1)
ax[0][1].plot(x, y2)
ax[1][0].plot(x, y3)
ax[1][1].plot(x, y4)

共享坐标轴

有时需要在一幅子图中展示两组数据,可以用twinx()或者twiny()的方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
x = np.linspace(-5, 5, 100)
y1 = 2*x + 1
y2 = x**2 - 3*x + 1

fig = plt.figure()
ax1 = plt.subplot(111)

ax2 = ax1.twinx() # 共享X轴

ax1.plot(x, y1, 'r-')
ax2.plot(x, y2, 'b--')

ax1.set_xlabel('TwinX axis')
ax1.set_ylabel('Y1 axis', color='r')
ax2.set_ylabel('Y2 axis', color='b')