Python 中的 griddata() 函数。这是科学计算中一个非常重要且常用的函数,用于将不规则分布的散点数据插值到规则网格上。
1. 函数概述
griddata() 位于 scipy.interpolate 模块中。它的核心功能是:给定一组散乱点 (x, y) 及其对应的值 z,计算这些散乱数据在指定规则网格点上的插值。
基本语法
from scipy.interpolate import griddata
grid_z = griddata(points, values, xi, method='linear', fill_value=nan, rescale=False)
参数说明
| 参数 | 说明 |
|---|---|
points |
形状为 (n, D) 的数组,表示 n 个数据点在 D 维空间中的坐标。对于 2D 数据,通常形状为 (n, 2),即 (x, y) 坐标对。 |
values |
长度为 n 的一维数组,表示每个数据点对应的值 z。 |
xi |
需要插值的目标网格点坐标。有两种指定方式: 1. 元组 (x_grid, y_grid),其中 x_grid 和 y_grid 是二维数组(由 meshgrid 生成)。 2. 形状为 (m, D) 的数组,表示 m 个目标点的坐标。 |
method |
插值方法,可选: - 'linear':线性插值(默认,基于 Delaunay 三角剖分) - 'nearest':最近邻插值 - 'cubic':三次插值(仅适用于 2D 数据) |
fill_value |
用于填充插值区域外点的值,默认为 nan。 |
rescale |
如果为 True,则在插值前将点重新缩放至单位立方体。适用于各维度量纲差异大的情况。 |
2. 典型使用场景
场景 1:将测量数据网格化
假设你在野外不规则地点测量了海拔高度(散乱数据),想要得到一张规则网格上的海拔等高线图。
场景 2:数据可视化
散乱的数据点难以直接绘制平滑的等高线图或曲面图,griddata() 可以将其转换为规则网格数据,便于使用 contourf、pcolormesh、plot_surface 等函数可视化。
场景 3:数据重采样
将不规则采样的数据重新采样到标准网格上,以便与其他数据集进行比较或进行进一步分析。
3. 详细示例
示例 1:基础二维插值
python
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import griddata
# 1. 生成模拟的散乱数据点
np.random.seed(42)
n_points = 100
x = np.random.uniform(-2, 2, n_points)
y = np.random.uniform(-2, 2, n_points)
z = x * np.exp(-x**2 - y**2) # 计算每个点的函数值
# 2. 创建规则的目标网格
grid_x, grid_y = np.mgrid[-2:2:100j, -2:2:100j] # 100x100 的网格
# 3. 使用 griddata 进行插值
grid_z_linear = griddata((x, y), z, (grid_x, grid_y), method='linear')
grid_z_nearest = griddata((x, y), z, (grid_x, grid_y), method='nearest')
grid_z_cubic = griddata((x, y), z, (grid_x, grid_y), method='cubic')
# 4. 可视化
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
# 原始散点数据
sc1 = axes[0, 0].scatter(x, y, c=z, s=20, cmap='viridis')
axes[0, 0].set_title('Original Scattered Data')
plt.colorbar(sc1, ax=axes[0, 0])
# 插值结果
im1 = axes[0, 1].imshow(grid_z_linear.T, extent=(-2,2,-2,2), origin='lower', cmap='viridis')
axes[0, 1].set_title('Linear Interpolation')
plt.colorbar(im1, ax=axes[0, 1])
im2 = axes[0, 2].imshow(grid_z_nearest.T, extent=(-2,2,-2,2), origin='lower', cmap='viridis')
axes[0, 2].set_title('Nearest Interpolation')
plt.colorbar(im2, ax=axes[0, 2])
im3 = axes[1, 0].imshow(grid_z_cubic.T, extent=(-2,2,-2,2), origin='lower', cmap='viridis')
axes[1, 0].set_title('Cubic Interpolation')
plt.colorbar(im3, ax=axes[1, 0])
# 等高线图
contour = axes[1, 1].contourf(grid_x, grid_y, grid_z_linear, levels=20, cmap='viridis')
axes[1, 1].set_title('Contour Plot (Linear)')
plt.colorbar(contour, ax=axes[1, 1])
# 3D 曲面图(可选)
from mpl_toolkits.mplot3d import Axes3D
ax3d = fig.add_subplot(2, 3, 6, projection='3d')
ax3d.plot_surface(grid_x, grid_y, grid_z_linear, cmap='viridis', alpha=0.8)
ax3d.scatter(x, y, z, c='red', s=10, alpha=1)
ax3d.set_title('3D Surface with Original Points')
plt.tight_layout()
plt.show()
示例 2:处理缺失值(外推区域)
python
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import griddata
# 创建一个有"空洞"的数据分布
x = np.random.uniform(-3, 3, 200)
y = np.random.uniform(-3, 3, 200)
z = np.sin(np.sqrt(x**2 + y**2))
# 移除中心区域的一些点,模拟数据缺失
mask = (x**2 + y**2) < 1
x_masked, y_masked, z_masked = x[~mask], y[~mask], z[~mask]
# 插值到网格
grid_x, grid_y = np.mgrid[-3:3:100j, -3:3:100j]
grid_z = griddata((x_masked, y_masked), z_masked, (grid_x, grid_y),
method='linear', fill_value=np.nan)
# 可视化:NaN 区域会显示为空白
plt.figure(figsize=(10, 4))
plt.subplot(121)
plt.scatter(x_masked, y_masked, c=z_masked, s=10, cmap='coolwarm')
plt.title('Data with Hole')
plt.subplot(122)
plt.imshow(grid_z.T, extent=(-3,3,-3,3), origin='lower',
cmap='coolwarm', alpha=0.7)
plt.colorbar()
plt.title('Interpolation with NaN Fill')
plt.show()
4. 重要注意事项
1. 外推问题
griddata() 默认不进行外推 。对于落在散点凸包(convex hull)外的网格点,会填充为 fill_value(默认为 nan)。如果需要外推,可以考虑:
-
使用
RBFInterpolator(SciPy 1.7+) -
使用
LinearNDInterpolator或CloughTocher2DInterpolator并设置fill_value为外推值 -
手动扩展数据边界
2. 性能考虑
-
对于大量数据点(>10^5),
griddata()可能较慢,因为它需要计算 Delaunay 三角剖分 -
method='cubic'比'linear'和'nearest'计算代价更高 -
如果需要对同一组散点多次插值到不同网格,建议使用
LinearNDInterpolator等类,避免重复计算三角剖分
3. 替代方案
# 更高效的替代(适用于多次插值)
from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator
# 创建插值器
interp_linear = LinearNDInterpolator(list(zip(x, y)), z)
interp_nearest = NearestNDInterpolator(list(zip(x, y)), z)
# 然后可以重复使用
grid_z1 = interp_linear(grid_x, grid_y)
grid_z2 = interp_nearest(grid_x, grid_y)
4. 数据准备
-
确保
points中没有重复点,否则可能导致插值问题 -
检查是否有
nan或inf在values中 -
对于三维或更高维插值,只能使用
'linear'或'nearest'方法
是的,您理解得很对!griddata() 本质上就是一种"预测器":基于已知的散点数据,预测任意新位置的值。
1. 最简单的比喻:天气预报
想象一下:
-
您在全国各地有气象站(散点
points) -
每个气象站测量了温度(已知值
values) -
您想知道任意一个新地点 的温度(预测)
# 比喻:气象站数据
气象站位置 = [(北京_x, 北京_y), (上海_x, 上海_y), ...] # 这是 points
气象站温度 = [30, 35, ...] # 这是 values
# 您想知道南京的温度
南京位置 = (南京_x, 南京_y) # 这是 xi
南京温度 = griddata(气象站位置, 气象站温度, 南京位置) # 预测!
2. 实际使用案例:根据少数测量点预测整个区域
场景:土壤污染检测
python
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import griddata
# 步骤1:已知数据 - 在几个点测量了污染物浓度
# 这些是您的"测量点"
known_points = np.array([
[10, 20], # 点1坐标 (x=10, y=20)
[30, 40], # 点2
[50, 60], # 点3
[70, 80], # 点4
[90, 10] # 点5
])
# 在这些点测得的污染物浓度 (ppm)
known_values = np.array([1.2, 2.5, 1.8, 3.0, 0.9])
print("已知测量点:")
for i, (point, value) in enumerate(zip(known_points, known_values)):
print(f" 点{i+1}: 位置{point}, 浓度={value}ppm")
# 步骤2:您想知道这些位置的值(预测!)
# 比如:工厂附近、居民区等关键位置
locations_to_predict = np.array([
[15, 25], # 位置A
[45, 55], # 位置B
[85, 15] # 位置C
])
# 步骤3:使用 griddata 预测
predicted_values = griddata(known_points, known_values,
locations_to_predict,
method='linear')
print("\n预测结果:")
for loc, val in zip(locations_to_predict, predicted_values):
print(f" 位置{loc}: 预测浓度 = {val:.2f}ppm")
# 步骤4:可视化
plt.figure(figsize=(10, 4))
# 子图1:显示测量点
plt.subplot(121)
plt.scatter(known_points[:, 0], known_points[:, 1],
c=known_values, s=200, cmap='Reds', edgecolors='black')
plt.colorbar(label='污染物浓度 (ppm)')
plt.scatter(locations_to_predict[:, 0], locations_to_predict[:, 1],
c='blue', s=100, marker='x', label='待预测点')
plt.title('测量点(红色)和待预测点(蓝色)')
plt.xlabel('X坐标')
plt.ylabel('Y坐标')
plt.legend()
# 子图2:生成整个区域的预测图(热点图)
plt.subplot(122)
# 创建密集网格来预测整个区域
grid_x, grid_y = np.mgrid[0:100:100j, 0:100:100j]
grid_z = griddata(known_points, known_values, (grid_x, grid_y),
method='linear', fill_value=0)
# 显示预测的热点图
plt.imshow(grid_z.T, extent=(0, 100, 0, 100), origin='lower',
cmap='Reds', alpha=0.7)
plt.colorbar(label='预测浓度 (ppm)')
plt.scatter(known_points[:, 0], known_points[:, 1],
c='black', s=50, label='测量点')
plt.title('整个区域的污染物浓度预测')
plt.xlabel('X坐标')
plt.ylabel('Y坐标')
plt.legend()
plt.tight_layout()
plt.show()
3. 更简单的例子:理解输入输出
python
import numpy as np
from scipy.interpolate import griddata
# ========== 例子1:最简单的预测 ==========
print("=== 例子1:从3个点预测新点 ===")
# 已知3个点的位置和值(就像知道3个地方的房价)
points = np.array([
[0, 0], # 点A:市中心
[10, 0], # 点B:近郊
[0, 10] # 点C:远郊
])
values = np.array([50000, 30000, 20000]) # 房价(元/平米)
# 您想预测这些位置的价格
new_locations = np.array([
[5, 0], # 位置1:市中心和近郊之间
[0, 5], # 位置2:市中心和远郊之间
[5, 5] # 位置3:中间位置
])
# 预测!
predictions = griddata(points, values, new_locations, method='linear')
for i, (loc, pred) in enumerate(zip(new_locations, predictions)):
print(f"位置{loc}: 预测房价 = ¥{pred:,.0f}/平米")
# ========== 例子2:实际应用 - 海拔高度预测 ==========
print("\n=== 例子2:登山海拔预测 ===")
# 已知几个地点的海拔高度
mountain_points = np.array([
[0, 0, 100], # 山脚:海拔100米
[1, 0, 500], # 山腰
[0, 1, 300], # 另一个方向
[1, 1, 800] # 接近山顶
])
# griddata需要分开坐标和值
coords = mountain_points[:, :2] # 只取x,y坐标
heights = mountain_points[:, 2] # 海拔值
# 想预测的登山路径
path_points = np.array([
[0.2, 0.3], # 路径点1
[0.5, 0.5], # 路径点2
[0.8, 0.8] # 路径点3
])
predicted_heights = griddata(coords, heights, path_points, method='linear')
print("登山路径海拔预测:")
for i, (point, height) in enumerate(zip(path_points, predicted_heights)):
print(f" 路径点{i+1} (位置{point}): 海拔约{height:.0f}米")
4. griddata() 的工作流程总结
输入:
已知点坐标 → points = [(x1,y1), (x2,y2), ...]
这些点的值 → values = [z1, z2, ...]
想知道的位置 → xi = [(x_new1,y_new1), (x_new2,y_new2), ...]
处理:
griddata() 内部:
1. 根据已知点构建三角网格
2. 找到每个新点在哪个三角形里
3. 用三角形顶点的值计算新点的值(插值)
输出:
预测值 → [z_new1, z_new2, ...]
5. 什么时候用 griddata()?什么时候用机器学习?
| 特点 | griddata() |
机器学习模型 |
|---|---|---|
| 数据量 | 小到中等(几百到几千点) | 可处理大数据 |
| 原理 | 几何插值(基于距离) | 统计学习 |
| 速度 | 预测快,但构建网格慢 | 训练慢,预测快 |
| 外推 | 不能外推(只在内插) | 可以外推(但有风险) |
| 使用场景 | 空间数据插值、测量数据处理 | 复杂模式识别、预测 |
简单选择指南:
-
如果您的数据是空间/地理数据 ,且点之间变化平滑 → 用
griddata() -
如果关系复杂、有噪声、需要预测未来 → 用机器学习
6. 实际工程中的应用模板
python
def predict_at_new_locations(known_data, new_locations, method='linear'):
"""
使用已知数据预测新位置的值
参数:
known_data: list of [(x1,y1,z1), (x2,y2,z2), ...] 或 (points, values)
new_locations: list of [(x_new1,y_new1), (x_new2,y_new2), ...]
返回:
预测值列表
"""
# 准备数据
if isinstance(known_data, tuple):
points, values = known_data
else:
known_data = np.array(known_data)
points = known_data[:, :2] # 前两列是坐标
values = known_data[:, 2] # 第三列是值
new_locations = np.array(new_locations)
# 预测
predictions = griddata(points, values, new_locations, method=method)
return predictions
# 使用示例
# 已知的传感器读数
sensor_data = [
[1, 1, 25.5], # (x,y,温度)
[2, 3, 26.0],
[4, 2, 24.8],
[3, 5, 25.2]
]
# 想知道这些位置的情况
query_points = [
[1.5, 2.0], # 位置A
[3.0, 3.5] # 位置B
]
predicted_temps = predict_at_new_locations(sensor_data, query_points)
print(f"位置A预测温度: {predicted_temps[0]:.1f}°C")
print(f"位置B预测温度: {predicted_temps[1]:.1f}°C")
总结回答您的问题
是的,griddata() 就是一个预测器:
-
输入已知数据 :
(在哪里, 值多少) -
输入想知道的位置 :
(新地点在哪里) -
输出预测值 :
(新地点的预测值)
就像您有了一些城市的温度数据,可以预测中间城市的温度一样简单直观!