griddata

Python 中的 griddata() 函数。这是科学计算中一个非常重要且常用的函数,用于将不规则分布的散点数据插值到规则网格上

1. 函数概述

griddata() 位于 scipy.interpolate 模块中。它的核心功能是:给定一组散乱点 (x, y) 及其对应的值 z,计算这些散乱数据在指定规则网格点上的插值。

基本语法

复制代码
from scipy.interpolate import griddata

grid_z = griddata(points, values, xi, method='linear', fill_value=nan, rescale=False)

参数说明

参数 说明
points 形状为 (n, D) 的数组,表示 n 个数据点在 D 维空间中的坐标。对于 2D 数据,通常形状为 (n, 2),即 (x, y) 坐标对。
values 长度为 n 的一维数组,表示每个数据点对应的值 z
xi 需要插值的目标网格点坐标。有两种指定方式: 1. 元组 (x_grid, y_grid),其中 x_gridy_grid 是二维数组(由 meshgrid 生成)。 2. 形状为 (m, D) 的数组,表示 m 个目标点的坐标。
method 插值方法,可选: - 'linear':线性插值(默认,基于 Delaunay 三角剖分) - 'nearest':最近邻插值 - 'cubic':三次插值(仅适用于 2D 数据)
fill_value 用于填充插值区域外点的值,默认为 nan
rescale 如果为 True,则在插值前将点重新缩放至单位立方体。适用于各维度量纲差异大的情况。

2. 典型使用场景

场景 1:将测量数据网格化

假设你在野外不规则地点测量了海拔高度(散乱数据),想要得到一张规则网格上的海拔等高线图。

场景 2:数据可视化

散乱的数据点难以直接绘制平滑的等高线图或曲面图,griddata() 可以将其转换为规则网格数据,便于使用 contourfpcolormeshplot_surface 等函数可视化。

场景 3:数据重采样

将不规则采样的数据重新采样到标准网格上,以便与其他数据集进行比较或进行进一步分析。

3. 详细示例

示例 1:基础二维插值

python 复制代码
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import griddata

# 1. 生成模拟的散乱数据点
np.random.seed(42)
n_points = 100
x = np.random.uniform(-2, 2, n_points)
y = np.random.uniform(-2, 2, n_points)
z = x * np.exp(-x**2 - y**2)  # 计算每个点的函数值

# 2. 创建规则的目标网格
grid_x, grid_y = np.mgrid[-2:2:100j, -2:2:100j]  # 100x100 的网格

# 3. 使用 griddata 进行插值
grid_z_linear = griddata((x, y), z, (grid_x, grid_y), method='linear')
grid_z_nearest = griddata((x, y), z, (grid_x, grid_y), method='nearest')
grid_z_cubic = griddata((x, y), z, (grid_x, grid_y), method='cubic')

# 4. 可视化
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# 原始散点数据
sc1 = axes[0, 0].scatter(x, y, c=z, s=20, cmap='viridis')
axes[0, 0].set_title('Original Scattered Data')
plt.colorbar(sc1, ax=axes[0, 0])

# 插值结果
im1 = axes[0, 1].imshow(grid_z_linear.T, extent=(-2,2,-2,2), origin='lower', cmap='viridis')
axes[0, 1].set_title('Linear Interpolation')
plt.colorbar(im1, ax=axes[0, 1])

im2 = axes[0, 2].imshow(grid_z_nearest.T, extent=(-2,2,-2,2), origin='lower', cmap='viridis')
axes[0, 2].set_title('Nearest Interpolation')
plt.colorbar(im2, ax=axes[0, 2])

im3 = axes[1, 0].imshow(grid_z_cubic.T, extent=(-2,2,-2,2), origin='lower', cmap='viridis')
axes[1, 0].set_title('Cubic Interpolation')
plt.colorbar(im3, ax=axes[1, 0])

# 等高线图
contour = axes[1, 1].contourf(grid_x, grid_y, grid_z_linear, levels=20, cmap='viridis')
axes[1, 1].set_title('Contour Plot (Linear)')
plt.colorbar(contour, ax=axes[1, 1])

# 3D 曲面图(可选)
from mpl_toolkits.mplot3d import Axes3D
ax3d = fig.add_subplot(2, 3, 6, projection='3d')
ax3d.plot_surface(grid_x, grid_y, grid_z_linear, cmap='viridis', alpha=0.8)
ax3d.scatter(x, y, z, c='red', s=10, alpha=1)
ax3d.set_title('3D Surface with Original Points')

plt.tight_layout()
plt.show()

示例 2:处理缺失值(外推区域)

python 复制代码
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import griddata

# 创建一个有"空洞"的数据分布
x = np.random.uniform(-3, 3, 200)
y = np.random.uniform(-3, 3, 200)
z = np.sin(np.sqrt(x**2 + y**2))

# 移除中心区域的一些点,模拟数据缺失
mask = (x**2 + y**2) < 1
x_masked, y_masked, z_masked = x[~mask], y[~mask], z[~mask]

# 插值到网格
grid_x, grid_y = np.mgrid[-3:3:100j, -3:3:100j]
grid_z = griddata((x_masked, y_masked), z_masked, (grid_x, grid_y),
                  method='linear', fill_value=np.nan)

# 可视化:NaN 区域会显示为空白
plt.figure(figsize=(10, 4))
plt.subplot(121)
plt.scatter(x_masked, y_masked, c=z_masked, s=10, cmap='coolwarm')
plt.title('Data with Hole')

plt.subplot(122)
plt.imshow(grid_z.T, extent=(-3,3,-3,3), origin='lower',
           cmap='coolwarm', alpha=0.7)
plt.colorbar()
plt.title('Interpolation with NaN Fill')
plt.show()

4. 重要注意事项

1. 外推问题

griddata() 默认不进行外推 。对于落在散点凸包(convex hull)外的网格点,会填充为 fill_value(默认为 nan)。如果需要外推,可以考虑:

  • 使用 RBFInterpolator(SciPy 1.7+)

  • 使用 LinearNDInterpolatorCloughTocher2DInterpolator 并设置 fill_value 为外推值

  • 手动扩展数据边界

2. 性能考虑

  • 对于大量数据点(>10^5),griddata() 可能较慢,因为它需要计算 Delaunay 三角剖分

  • method='cubic''linear''nearest' 计算代价更高

  • 如果需要对同一组散点多次插值到不同网格,建议使用 LinearNDInterpolator 等类,避免重复计算三角剖分

3. 替代方案

复制代码
# 更高效的替代(适用于多次插值)
from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator

# 创建插值器
interp_linear = LinearNDInterpolator(list(zip(x, y)), z)
interp_nearest = NearestNDInterpolator(list(zip(x, y)), z)

# 然后可以重复使用
grid_z1 = interp_linear(grid_x, grid_y)
grid_z2 = interp_nearest(grid_x, grid_y)

4. 数据准备

  • 确保 points 中没有重复点,否则可能导致插值问题

  • 检查是否有 naninfvalues

  • 对于三维或更高维插值,只能使用 'linear''nearest' 方法

是的,您理解得很对!griddata() 本质上就是一种"预测器":基于已知的散点数据,预测任意新位置的值。

1. 最简单的比喻:天气预报

想象一下:

  • 您在全国各地有气象站(散点 points

  • 每个气象站测量了温度(已知值 values

  • 您想知道任意一个新地点 的温度(预测

复制代码
# 比喻:气象站数据
气象站位置 = [(北京_x, 北京_y), (上海_x, 上海_y), ...]  # 这是 points
气象站温度 = [30, 35, ...]  # 这是 values

# 您想知道南京的温度
南京位置 = (南京_x, 南京_y)  # 这是 xi
南京温度 = griddata(气象站位置, 气象站温度, 南京位置)  # 预测!

2. 实际使用案例:根据少数测量点预测整个区域

场景:土壤污染检测

python 复制代码
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import griddata

# 步骤1:已知数据 - 在几个点测量了污染物浓度
# 这些是您的"测量点"
known_points = np.array([
    [10, 20],  # 点1坐标 (x=10, y=20)
    [30, 40],  # 点2
    [50, 60],  # 点3
    [70, 80],  # 点4
    [90, 10]   # 点5
])

# 在这些点测得的污染物浓度 (ppm)
known_values = np.array([1.2, 2.5, 1.8, 3.0, 0.9])

print("已知测量点:")
for i, (point, value) in enumerate(zip(known_points, known_values)):
    print(f"  点{i+1}: 位置{point}, 浓度={value}ppm")

# 步骤2:您想知道这些位置的值(预测!)
# 比如:工厂附近、居民区等关键位置
locations_to_predict = np.array([
    [15, 25],  # 位置A
    [45, 55],  # 位置B
    [85, 15]   # 位置C
])

# 步骤3:使用 griddata 预测
predicted_values = griddata(known_points, known_values, 
                            locations_to_predict, 
                            method='linear')

print("\n预测结果:")
for loc, val in zip(locations_to_predict, predicted_values):
    print(f"  位置{loc}: 预测浓度 = {val:.2f}ppm")

# 步骤4:可视化
plt.figure(figsize=(10, 4))

# 子图1:显示测量点
plt.subplot(121)
plt.scatter(known_points[:, 0], known_points[:, 1], 
            c=known_values, s=200, cmap='Reds', edgecolors='black')
plt.colorbar(label='污染物浓度 (ppm)')
plt.scatter(locations_to_predict[:, 0], locations_to_predict[:, 1],
            c='blue', s=100, marker='x', label='待预测点')
plt.title('测量点(红色)和待预测点(蓝色)')
plt.xlabel('X坐标')
plt.ylabel('Y坐标')
plt.legend()

# 子图2:生成整个区域的预测图(热点图)
plt.subplot(122)
# 创建密集网格来预测整个区域
grid_x, grid_y = np.mgrid[0:100:100j, 0:100:100j]
grid_z = griddata(known_points, known_values, (grid_x, grid_y), 
                  method='linear', fill_value=0)

# 显示预测的热点图
plt.imshow(grid_z.T, extent=(0, 100, 0, 100), origin='lower',
           cmap='Reds', alpha=0.7)
plt.colorbar(label='预测浓度 (ppm)')
plt.scatter(known_points[:, 0], known_points[:, 1], 
            c='black', s=50, label='测量点')
plt.title('整个区域的污染物浓度预测')
plt.xlabel('X坐标')
plt.ylabel('Y坐标')
plt.legend()

plt.tight_layout()
plt.show()

3. 更简单的例子:理解输入输出

python 复制代码
import numpy as np
from scipy.interpolate import griddata

# ========== 例子1:最简单的预测 ==========
print("=== 例子1:从3个点预测新点 ===")

# 已知3个点的位置和值(就像知道3个地方的房价)
points = np.array([
    [0, 0],  # 点A:市中心
    [10, 0], # 点B:近郊
    [0, 10]  # 点C:远郊
])
values = np.array([50000, 30000, 20000])  # 房价(元/平米)

# 您想预测这些位置的价格
new_locations = np.array([
    [5, 0],   # 位置1:市中心和近郊之间
    [0, 5],   # 位置2:市中心和远郊之间
    [5, 5]    # 位置3:中间位置
])

# 预测!
predictions = griddata(points, values, new_locations, method='linear')

for i, (loc, pred) in enumerate(zip(new_locations, predictions)):
    print(f"位置{loc}: 预测房价 = ¥{pred:,.0f}/平米")

# ========== 例子2:实际应用 - 海拔高度预测 ==========
print("\n=== 例子2:登山海拔预测 ===")

# 已知几个地点的海拔高度
mountain_points = np.array([
    [0, 0, 100],    # 山脚:海拔100米
    [1, 0, 500],    # 山腰
    [0, 1, 300],    # 另一个方向
    [1, 1, 800]     # 接近山顶
])

# griddata需要分开坐标和值
coords = mountain_points[:, :2]  # 只取x,y坐标
heights = mountain_points[:, 2]  # 海拔值

# 想预测的登山路径
path_points = np.array([
    [0.2, 0.3],  # 路径点1
    [0.5, 0.5],  # 路径点2
    [0.8, 0.8]   # 路径点3
])

predicted_heights = griddata(coords, heights, path_points, method='linear')

print("登山路径海拔预测:")
for i, (point, height) in enumerate(zip(path_points, predicted_heights)):
    print(f"  路径点{i+1} (位置{point}): 海拔约{height:.0f}米")

4. griddata() 的工作流程总结

复制代码
输入:
   已知点坐标 → points = [(x1,y1), (x2,y2), ...]
   这些点的值 → values = [z1, z2, ...]
   想知道的位置 → xi = [(x_new1,y_new1), (x_new2,y_new2), ...]

处理:
   griddata() 内部:
   1. 根据已知点构建三角网格
   2. 找到每个新点在哪个三角形里
   3. 用三角形顶点的值计算新点的值(插值)

输出:
   预测值 → [z_new1, z_new2, ...]

5. 什么时候用 griddata()?什么时候用机器学习?

特点 griddata() 机器学习模型
数据量 小到中等(几百到几千点) 可处理大数据
原理 几何插值(基于距离) 统计学习
速度 预测快,但构建网格慢 训练慢,预测快
外推 不能外推(只在内插) 可以外推(但有风险)
使用场景 空间数据插值、测量数据处理 复杂模式识别、预测

简单选择指南

  • 如果您的数据是空间/地理数据 ,且点之间变化平滑 → 用 griddata()

  • 如果关系复杂、有噪声、需要预测未来 → 用机器学习

6. 实际工程中的应用模板

python 复制代码
def predict_at_new_locations(known_data, new_locations, method='linear'):
    """
    使用已知数据预测新位置的值
    
    参数:
    known_data: list of [(x1,y1,z1), (x2,y2,z2), ...] 或 (points, values)
    new_locations: list of [(x_new1,y_new1), (x_new2,y_new2), ...]
    
    返回:
    预测值列表
    """
    # 准备数据
    if isinstance(known_data, tuple):
        points, values = known_data
    else:
        known_data = np.array(known_data)
        points = known_data[:, :2]  # 前两列是坐标
        values = known_data[:, 2]   # 第三列是值
    
    new_locations = np.array(new_locations)
    
    # 预测
    predictions = griddata(points, values, new_locations, method=method)
    
    return predictions

# 使用示例
# 已知的传感器读数
sensor_data = [
    [1, 1, 25.5],  # (x,y,温度)
    [2, 3, 26.0],
    [4, 2, 24.8],
    [3, 5, 25.2]
]

# 想知道这些位置的情况
query_points = [
    [1.5, 2.0],  # 位置A
    [3.0, 3.5]   # 位置B
]

predicted_temps = predict_at_new_locations(sensor_data, query_points)
print(f"位置A预测温度: {predicted_temps[0]:.1f}°C")
print(f"位置B预测温度: {predicted_temps[1]:.1f}°C")

总结回答您的问题

是的,griddata() 就是一个预测器:

  1. 输入已知数据(在哪里, 值多少)

  2. 输入想知道的位置(新地点在哪里)

  3. 输出预测值(新地点的预测值)

就像您有了一些城市的温度数据,可以预测中间城市的温度一样简单直观!

相关推荐
拓端研究室1 小时前
2026年医药行业展望报告:创新、出海、AI医疗与商业化|附220+份报告PDF、数据、可视化模板汇总下载
大数据·人工智能
shayudiandian1 小时前
模型压缩与量化:让AI更轻更快
人工智能
LeonIter1 小时前
用回归分析为短剧APP“号脉”:我们如何找到留存的关键驱动力与产品迭代优先级?
人工智能·数据挖掘·回归
后端小张1 小时前
【AI学习】深入探秘AI之神经网络的奥秘
人工智能·深度学习·神经网络·opencv·学习·机器学习·自然语言处理
说私域2 小时前
社群经济视域下智能名片链动2+1模式商城小程序的商业价值重构
人工智能·小程序·重构·开源
xu_yule3 小时前
算法基础(数论)—费马小定理
c++·算法·裴蜀定理·欧拉定理·费马小定理·同余方程·扩展欧几里得定理
girl-07264 小时前
2025.12.28代码分析总结
算法
NAGNIP6 小时前
GPT-5.1 发布:更聪明,也更有温度的 AI
人工智能·算法
NAGNIP7 小时前
激活函数有什么用?有哪些常用的激活函数?
人工智能·算法
骚戴7 小时前
2025 Python AI 实战:零基础调用 LLM API 开发指南
人工智能·python·大模型·llm·api·ai gateway