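I needed 3D plots to visualize gradient descent, so I worked through this material and am keeping these notes for future reference.
This article is a translation of the "mplot3d tutorial" page of the Matplotlib 2.0.2 documentation.
To draw 3D plots with matplotlib, you generally add a new type of axes, Axes3D, to the figure: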
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
Here ax is the Axes3D object, that is, a 3D coordinate system added to the figure, as shown in the figure below.
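As a small sketch of the step above, the following shows two ways to obtain a 3D axes object. The second form, plt.subplots with subplot_kw, is not part of the original tutorial text and is included here as an alternative; the "Agg" backend line is an assumption made so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; remove for interactive use
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older releases

# Option 1 (as in the tutorial): add a 3D subplot to an existing figure.
fig1 = plt.figure()
ax1 = fig1.add_subplot(111, projection='3d')

# Option 2 (alternative): create the figure and the 3D axes in one call.
fig2, ax2 = plt.subplots(subplot_kw={'projection': '3d'})

# Both objects are 3D axes and therefore have z-axis helpers.
for ax in (ax1, ax2):
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
```

In an interactive session, drop the matplotlib.use line and call plt.show() to see the empty 3D axes.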
3D plots fall into the following categories:
Line plots
Axes3D.plot(xs, ys, *args, **kwargs)
Plots 2D or 3D data.
Argument | Description
xs, ys | x, y coordinates of vertices
zs | z value(s), either one for all points or one for each point.
zdir | Which direction to use as z ('x', 'y' or 'z') when plotting a 2D set.
Other keyword arguments are passed on to plot(). For example:
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import matplotlib.pyplot as plt
mpl.rcParams['legend.fontsize'] = 10
fig = plt.figure()
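# add_subplot works on old and new Matplotlib; fig.gca(projection='3d') fails on recent releases
ax = fig.add_subplot(111, projection='3d')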
theta = np.linspace(-4 * np.pi, 4 * np.pi, 100)
z = np.linspace(-2, 2, 100)
r = z**2 + 1
x = r * np.sin(theta)
y = r * np.cos(theta)
ax.plot(x, y, z, label='parametric curve')
ax.legend()
plt.show()
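The resulting figure is shown below.
This example shows the basic steps of plotting with matplotlib: import the required modules, create a figure object, set up the 3D axes, create the independent variables, write out the functional relationship, and draw the plot.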
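Since these notes were motivated by gradient descent, here is a minimal sketch that applies the same steps to a descent path. The objective f(x, y) = x^2 + 2y^2, the starting point, the learning rate, and the iteration count are all assumptions chosen for illustration; none of them comes from the tutorial.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; remove for interactive use
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older releases


def f(x, y):
    """Hypothetical objective: a simple convex bowl."""
    return x**2 + 2.0 * y**2


def grad_f(x, y):
    """Analytic gradient of f."""
    return np.array([2.0 * x, 4.0 * y])


# Run plain gradient descent and record every iterate.
point = np.array([4.0, 3.0])   # assumed starting point
learning_rate = 0.1            # assumed step size
path = [point.copy()]
for _ in range(30):
    point = point - learning_rate * grad_f(*point)
    path.append(point.copy())
path = np.array(path)

# Independent variables and the function values along the path.
xs, ys = path[:, 0], path[:, 1]
zs = f(xs, ys)

# Create the figure, set up the 3D axes, and draw the path as a 3D line.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot(xs, ys, zs, marker='o', label='gradient descent path')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('f(x, y)')
ax.legend()
```

Each marker is one iterate, so the spacing between markers gives a rough picture of how quickly the steps shrink near the minimum.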
Scatter plots
Axes3D.scatter(xs, ys, zs=0, zdir='z', s=20, c=None, depthshade=True, *args, **kwargs)
Argument | Description
xs, ys | Positions of data points.
zs | Either an array of the same length as xs and ys or a single value to place all points in the same plane. Default is 0.
zdir | Which direction to use as z ('x', 'y' or 'z') when plotting a 2D set.
s | Size in points^2. It is a scalar or an array of the same length as x and y.
c | A color. c can be a single color format string, or a sequence of color specifications of length N, or a sequence of N numbers to be mapped to colors using the cmap and norm specified via kwargs (see below). Note that c should not be a single numeric RGB or RGBA sequence because that is indistinguishable from an array of values to be colormapped. c can be a 2-D array in which the rows are RGB or RGBA, however, including the case of a single row to specify the same color for all points.
depthshade | Whether or not to shade the scatter markers to give the appearance of depth. Default is True.
Other keyword arguments are passed on to scatter(). For example:
'''
==============
3D scatterplot
==============
Demonstration of a basic scatterplot in 3D.
'''
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
def randrange(n, vmin, vmax):
    '''
    Helper function to make an array of random numbers having shape (n, )
    with each number distributed Uniform(vmin, vmax).
    '''
    return (vmax - vmin)*np.random.rand(n) + vmin

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
n = 100
# For each set of style and range settings, plot n random points in the box
# defined by x in [23, 32], y in [0, 100], z in [zlow, zhigh].
for c, m, zlow, zhigh in [('r', 'o', -50, -25), ('b', '^', -30, -5)]:
    xs = randrange(n, 23, 32)
    ys = randrange(n, 0, 100)
    zs = randrange(n, zlow, zhigh)
    ax.scatter(xs, ys, zs, c=c, marker=m)
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')
plt.show()
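This script draws two groups of scatter points, one shown as blue triangles and the other as red circles. The helper function randrange returns an array of numbers uniformly distributed over [vmin, vmax). The result is shown in the figure below.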
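The parameter table above says that c may also be a sequence of N numbers mapped to colors through cmap. The sketch below illustrates that case by coloring each point by its z value. The 'viridis' colormap, the fixed random seed, and the data ranges are assumptions chosen for illustration.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; remove for interactive use
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older releases

# Reproducible random points in a box (ranges are arbitrary).
rng = np.random.default_rng(0)
n = 100
xs = rng.uniform(23, 32, n)
ys = rng.uniform(0, 100, n)
zs = rng.uniform(-50, -5, n)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# c is an array of n numbers, so each point is colored through the colormap.
sc = ax.scatter(xs, ys, zs, c=zs, cmap='viridis', s=30, marker='o')

# The colorbar shows how z values map to colors.
fig.colorbar(sc, shrink=0.5, label='z value')
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')
```

Compared with the two fixed colors in the tutorial example, this makes the height of each point readable from its color as well as from its position.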
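Wireframe plots
Axes3D.plot_wireframe(X, Y, Z, *args, **kwargs)
Draws a 3D wireframe plot. The rstride and cstride parameters set the stride used to sample the input data. They cannot be combined with rcount and ccount, which cap how many rows and columns of the input are sampled to build the wireframe; using both kinds of argument raises an error.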
Argument | Description
X, Y, Z | Data values as 2D arrays
rstride | Array row stride (step size), defaults to 1
cstride | Array column stride (step size), defaults to 1
rcount | Use at most this many rows, defaults to 50
ccount | Use at most this many columns, defaults to 50
Other keyword arguments are passed on to LineCollection, and the call returns a Line3DCollection instance. For example:
'''
=================
3D wireframe plot
=================
A very basic demonstration of a wireframe plot.
'''
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# Grab some test data.
X, Y, Z = axes3d.get_test_data(0.05)
# Plot a basic wireframe.
ax.plot_wireframe(X, Y, Z, rstride=10, cstride=10)
plt.show()
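The resulting figure is shown below.
Here rstride and cstride control the sampling density: with a value of 10, only every 10th row and column of the data is drawn. Setting them to 1 produces a much denser wireframe, as shown in the figure below.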
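To make the stride versus count distinction concrete, the sketch below draws the same test data with count-based sampling and then checks that mixing a stride with a count is rejected. The ValueError type is what the Matplotlib releases I am aware of raise; treat it as an assumption for your installed version.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; remove for interactive use
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d

X, Y, Z = axes3d.get_test_data(0.05)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Count-based sampling: use at most 10 rows and 10 columns of the input.
wire = ax.plot_wireframe(X, Y, Z, rcount=10, ccount=10)

# Mixing a stride with a count is not allowed.
mixed_raises = False
try:
    ax.plot_wireframe(X, Y, Z, rstride=10, rcount=10)
except ValueError:
    mixed_raises = True

print('stride + count rejected:', mixed_raises)
```

Strides are convenient when you know the grid spacing, while counts give a predictable number of lines regardless of the input size.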
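Surface plots
Axes3D.plot_surface(X, Y, Z, *args, **kwargs)
By default the surface is drawn in shades of a single solid color, but color mapping is also supported through the cmap argument.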
Argument | Description
X, Y, Z | Data values as 2D arrays
rstride | Array row stride (step size)
cstride | Array column stride (step size)
rcount | Use at most this many rows, defaults to 50
ccount | Use at most this many columns, defaults to 50
color | Color of the surface patches
cmap | A colormap for the surface patches.
facecolors | Face colors for the individual patches
norm | An instance of Normalize to map values to colors
vmin | Minimum value to map
vmax | Maximum value to map
shade | Whether to shade the facecolors
Other arguments are passed on to Poly3DCollection. For example:
'''
======================
3D surface (color map)
======================
Demonstrates plotting a 3D surface colored with the coolwarm color map.
The surface is made opaque by using antialiased=False.
Also demonstrates using the LinearLocator and custom formatting for the
z axis tick labels.
'''
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
import numpy as np
fig = plt.figure()
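# add_subplot works on old and new Matplotlib; fig.gca(projection='3d') fails on recent releases
ax = fig.add_subplot(111, projection='3d')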
# Make data.
X = np.arange(-5, 5, 0.25)
Y = np.arange(-5, 5, 0.25)
X, Y = np.meshgrid(X, Y)
R = np.sqrt(X**2 + Y**2)
Z = np.sin(R)
# Plot the surface.
surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm,
                       linewidth=0, antialiased=False)
# Customize the z axis.
ax.set_zlim(-1.01, 1.01)
ax.zaxis.set_major_locator(LinearLocator(10))
ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))
# Add a color bar which maps values to colors.
fig.colorbar(surf, shrink=0.5, aspect=5)
plt.show()
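The resulting figure is shown below.
As the figure shows, cm is used for color mapping. Changing the step passed to arange to 0.01 gives the figure below, in which the surface is much smoother and more finely detailed.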
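The parameter table above lists rcount and ccount as "use at most this many rows/columns, defaults to 50". If that cap applies in your Matplotlib version and style (an assumption worth checking), refining the grid alone may not refine the drawn surface, and the counts need to be raised as well. The sketch below uses np.linspace with 120 samples per axis instead of arange; both the sample count and the explicit rcount/ccount values are assumptions chosen for illustration.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; remove for interactive use
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older releases

# Finer grid than the tutorial's 0.25 step: 120 samples per axis.
samples = 120
x = np.linspace(-5, 5, samples)
y = np.linspace(-5, 5, samples)
X, Y = np.meshgrid(x, y)
Z = np.sin(np.sqrt(X**2 + Y**2))

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Ask for every row and column of the grid to be used when drawing.
surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm,
                       rcount=samples, ccount=samples,
                       linewidth=0, antialiased=False)
fig.colorbar(surf, shrink=0.5, aspect=5)
```

Larger counts mean more polygons, so drawing and rotating the figure gets slower as the grid grows.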
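Combining 2D and 3D plots
This article is mainly about drawing 3D plots, and the plot types above cover most needs. The rest of it introduces a few additional techniques. The first is drawing 2D data inside a 3D plot. The code and figure follow.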
"""
=======================
Plot 2D data on 3D plot
=======================
Demonstrates using ax.plot's zdir keyword to plot 2D data on
selective axes of a 3D plot.
"""
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import matplotlib.pyplot as plt
fig = plt.figure()
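# add_subplot works on old and new Matplotlib; fig.gca(projection='3d') fails on recent releases
ax = fig.add_subplot(111, projection='3d')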
# Plot a sin curve using the x and y axes.
x = np.linspace(0, 1, 100)
y = np.sin(x * 2 * np.pi) / 2 + 0.5
ax.plot(x, y, zs=0, zdir='z', label='curve in (x,y)')
# Plot scatterplot data (20 2D points per colour) on the x and z axes.
colors = ('r', 'g', 'b', 'k')
x = np.random.sample(20*len(colors))
y = np.random.sample(20*len(colors))
c_list = []
for c in colors:
    # extend (not append) keeps c_list flat: one color entry per point
    c_list.extend([c]*20)
# By using zdir='y', the y value of these points is fixed to the zs value 0
# and the (x,y) points are plotted on the x and z axes.
ax.scatter(x, y, zs=0, zdir='y', c=c_list, label='points in (x,z)')
# Make legend, set axes limits and labels
ax.legend()
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_zlim(0, 1)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
# Customize the view angle so it's easier to see that the scatter points lie
# on the plane y=0
ax.view_init(elev=20., azim=-35)
plt.show()
The code above also shows how to set the value range of each axis and how to set the axis labels.
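As a further sketch of the zdir keyword, the example below draws the same 2D curve three times, once on each coordinate plane through the origin. The curve and the labels are taken from or modeled on the tutorial example; drawing it on all three planes is my own illustration.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; remove for interactive use
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older releases

# A 2D curve with both coordinates in [0, 1].
t = np.linspace(0, 1, 100)
curve = np.sin(t * 2 * np.pi) / 2 + 0.5

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# zdir names the axis that is held fixed at the value zs;
# the 2D data is then drawn on the remaining two axes.
for zdir in ('x', 'y', 'z'):
    ax.plot(t, curve, zs=0, zdir=zdir, label='curve on plane %s=0' % zdir)

ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_zlim(0, 1)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.legend()
```

Setting all three limits to the same range keeps the three copies of the curve visually comparable.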
Adding text
Axes3D.text(x, y, z, s, zdir=None, **kwargs)
When plotting, we may need to add a text annotation at a specific position. Here is an example:
'''
======================
Text annotations in 3D
======================
Demonstrates the placement of text annotations on a 3D plot.
Functionality shown:
- Using the text function with three types of 'zdir' values: None,
  an axis name (ex. 'x'), or a direction tuple (ex. (1, 1, 0)).
- Using the text function with the color keyword.
- Using the text2D function to place text on a fixed position on the ax object.
'''
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
fig = plt.figure()
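# add_subplot works on old and new Matplotlib; fig.gca(projection='3d') fails on recent releases
ax = fig.add_subplot(111, projection='3d')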
# Demo 1: zdir
zdirs = (None, 'x', 'y', 'z', (1, 1, 0), (1, 1, 1))
xs = (1, 4, 4, 9, 4, 1)
ys = (2, 5, 8, 10, 1, 2)
zs = (10, 3, 8, 9, 1, 8)
for zdir, x, y, z in zip(zdirs, xs, ys, zs):
    label = '(%d, %d, %d), dir=%s' % (x, y, z, zdir)
    ax.text(x, y, z, label, zdir)
# Demo 2: color
ax.text(9, 0, 0, "red", color='red')
# Demo 3: text2D
# Placement 0, 0 would be the bottom left, 1, 1 would be the top right.
ax.text2D(0.05, 0.95, "2D Text", transform=ax.transAxes)
# Tweaking display region and labels
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.set_zlim(0, 10)
ax.set_xlabel('X axis')
ax.set_ylabel('Y axis')
ax.set_zlabel('Z axis')
plt.show()
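A common use of text in 3D is labeling one specific data point. The sketch below marks a single point and labels it with its coordinates, then adds a caption fixed to the axes with text2D. The point (2, 3, 5) and both strings are made up for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; remove for interactive use
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older releases

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Hypothetical point of interest.
px, py, pz = 2.0, 3.0, 5.0
ax.scatter([px], [py], [pz], color='red')

# Text placed in data coordinates; it moves with the point when the view rotates.
label = ax.text(px, py, pz, '  (2, 3, 5)', zdir=None, color='red')

# Text placed in axes coordinates; (0, 0) is bottom left and (1, 1) is top right.
note = ax.text2D(0.05, 0.95, 'fixed caption', transform=ax.transAxes)

ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.set_zlim(0, 10)
```

The leading spaces in the label string are a simple way to offset the text from the marker without computing an offset in data units.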