好久没写了,来填坑,matplotlib相关的tips第二篇。包括网格化子图、坐标轴反向、savefig以及stacked bar.

1. matplotlib.gridspec

首先是子图网格化,比较简单但高效的一种方式,相对subplot更好用。样例如下,

1
2
3
4
5
6
7
8
9
from matplotlib import gridspec
# initial grid 2 x 2
gs = gridspec.GridSpec(2, 2, width_ratios=[1,1], height_ratios=[1,1])
ax0 = plt.subplot(gs[0])
xxxxxx
as1 = plt.subplot(gs[1])
xxxxxx

此处width_ratiosheight_ratios参数用于调整横向和纵向的子图之间的比例关系,效果见Tip2.

2. matplotlib的坐标轴翻转 (反向)

坐标轴反向主要出现在image相关的问题中。image通常以(row,column,channel)的数据结构出现,其y轴方向 (Vertical)与我们习惯的纵坐标镜像对称。因此,在处理图像的过程中,需要对坐标进行镜像对称,或者更简单的,对坐标轴进行翻转 (invert)。

  • 图像矩阵翻转
    numpy包中的flipud方法可以用于图像的翻转,是常用方法之一。
  • 坐标轴翻转
    matplotlib.Axis类中提供了invert_xaxisinvert_yaxis两个方法,下面结合第一点的gridspec给出样例。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import matplotlib.pyplot as plt
# display and invert
plt.rcParams["figure.figsize"] = [10.0, 24.0]
from matplotlib import gridspec
gs = gridspec.GridSpec(1,3, width_ratios=[1,1,1])
# raw direction
ax0 = plt.subplot(gs[0])
ax0.imshow(img)
ax0.set_xlabel("Horizontal",fontsize=12)
ax0.set_ylabel("Vertical",fontsize=12)
plt.title("Raw direction",fontsize=12)
# Y invert
ax1 = plt.subplot(gs[1])
ax1.imshow(img)
ax1.invert_yaxis()
ax1.set_xlabel("Horizontal",fontsize=12)
plt.title("Y invert",fontsize=12)
# X invert
ax2 = plt.subplot(gs[2])
ax2.imshow(img)
ax2.invert_xaxis()
ax2.set_xlabel("Horizontal",fontsize=12)
plt.title("X invert",fontsize=12)

其输出的图像如下图所示

3. savefig问题

保存图像是最近遇到的问题,主要涉及分辨率和去白边的问题。当然,如果能得到矢量图,尽量做矢量图。。。savefig的形式如下,

1
2
3
4
xxxxx # 做图的相关程序...
# save
plt.savefig(name,dpi=300,bbox_inches="tight")

此处的dpi表示“digits per inch”,数值越大分辨率越高;而bbox_inches用于控制输出的margin,如果实例化为”tight”可以去掉白边。

4. stacked bar

Stacked bar属于直方图的一种,考虑到每个bin内部的样本可能会细分为某些类或者某些区间,为了更好的描述bin内部样本的分布,可以采用stacked bar plot. 在matplotlib的pyplot类中提供了bar方法用于做直方图,该方法的bottom参数用于vertical方向的stacked bar,而leftright参数可以用于horitontal方向的stacked bar。样例代码如下,

1
2
3
4
5
6
7
8
9
10
11
12
13
plt.rcParams["figure.figsize"] = [7.0, 5.0]
bins = np.arange(0.5, 2.5+0.5, 0.5)
data = np.array([[2, 4, 3, 2, 4],[3, 2, 2, 3, 2], [3, 2 ,3, 1, 3]])
data = data.T
plt.bar(bins, data[:,0], width=0.2)
plt.bar(bins, data[:,1], width=0.2, bottom=data[:,0])
plt.bar(bins, data[:,2], width=0.2, bottom=data[:,0]+data[:,1])
plt.xlabel("Bins",fontsize=12)
plt.ylabel("Bars",fontsize=12)
plt.legend(["data-1", "data-2", "data-3"],loc=2,fontsize=12)
plt.ylim([0,12])

其输出结果如下图