AlignTwoPolyDatas 基于ICP算法的配准和相机视角切换

一:主要的知识点

1、说明

本文只是教程内容的一小段,因博客字数限制,故进行拆分。主教程链接:vtk教程------逐行解析官网所有Python示例-CSDN博客

2、知识点纪要

本段代码主要涉及的有①ICP模型配准,②配准结果的检测,③OBB包围盒

二:代码及注释

python 复制代码
#!/usr/bin/env python
# -*- coding: utf-8 -*-
 
import math
from pathlib import Path
 
# noinspection PyUnresolvedReferences
import vtkmodules.vtkInteractionStyle
# noinspection PyUnresolvedReferences
import vtkmodules.vtkRenderingOpenGL2
from vtkmodules.vtkCommonColor import vtkNamedColors
from vtkmodules.vtkCommonCore import (
    VTK_DOUBLE_MAX,
    vtkPoints
)
from vtkmodules.vtkCommonCore import (
    VTK_VERSION_NUMBER,
    vtkVersion
)
from vtkmodules.vtkCommonDataModel import (
    vtkIterativeClosestPointTransform,
    vtkPolyData
)
from vtkmodules.vtkCommonTransforms import (
    vtkLandmarkTransform,
    vtkTransform
)
from vtkmodules.vtkFiltersGeneral import (
    vtkOBBTree,
    vtkTransformPolyDataFilter
)
from vtkmodules.vtkFiltersModeling import vtkHausdorffDistancePointSetFilter
from vtkmodules.vtkIOGeometry import (
    vtkBYUReader,
    vtkOBJReader,
    vtkSTLReader
)
from vtkmodules.vtkIOLegacy import (
    vtkPolyDataReader,
    vtkPolyDataWriter
    )
from vtkmodules.vtkIOPLY import vtkPLYReader
from vtkmodules.vtkIOXML import vtkXMLPolyDataReader
from vtkmodules.vtkInteractionWidgets import (
    vtkCameraOrientationWidget,
    vtkOrientationMarkerWidget
)
from vtkmodules.vtkRenderingAnnotation import vtkAxesActor
from vtkmodules.vtkRenderingCore import (
    vtkActor,
    vtkDataSetMapper,
    vtkRenderWindow,
    vtkRenderWindowInteractor,
    vtkRenderer
)
 
 
def main():
    colors = vtkNamedColors()
    src_fn = "Data/thingiverse/Grey_Nurse_Shark.stl"
    tgt_fn = "Data/greatWhite.stl"
    print('Loading source:', src_fn)
    source_polydata = read_poly_data(src_fn)
    # Save the source polydata in case the alignment process does not improve
    # segmentation.
    original_source_polydata = vtkPolyData()
    original_source_polydata.DeepCopy(source_polydata)
 
    print('Loading target:', tgt_fn)
    target_polydata = read_poly_data(tgt_fn)
 
    # If the target orientation is markedly different, you may need to apply a
    # transform to orient the target with the source.
    # For example, when using Grey_Nurse_Shark.stl as the source and
    # greatWhite.stl as the target, you need to transform the target.
    trnf = vtkTransform()
    if Path(src_fn).name == 'Grey_Nurse_Shark.stl' and Path(tgt_fn).name == 'greatWhite.stl':
        trnf.RotateY(90)
 
    tpd = vtkTransformPolyDataFilter()
    tpd.SetTransform(trnf)
    tpd.SetInputData(target_polydata)
    tpd.Update()
 
    renderer = vtkRenderer()
    render_window = vtkRenderWindow()
    render_window.AddRenderer(renderer)
    interactor = vtkRenderWindowInteractor()
    interactor.SetRenderWindow(render_window)
 
    """
    对集合 A 中的每个点 a,找到 集合 B 中离它最近的点 b,记录这个最近距离。然后取 这些最近距离中的最大值。
    对集合 B 中的每个点 b,同样在 A 中找到最近的点 a,取最大值。
    两者的最大值,就是 Hausdorff 距离
    """
    distance = vtkHausdorffDistancePointSetFilter()
    distance.SetInputData(0, tpd.GetOutput())
    distance.SetInputData(1, source_polydata)
    distance.Update()
 
    """
    计算Hausdorff距离后,内部会算出最大Hausdorff距离,各点的距离分布,统计信息存在FieldData中
    GetFieldData() 取场数据(FieldData,用来存全局统计结果)
    GetArray('HausdorffDistance') 获取名字叫 HausdorffDistance 的数组
    这个数组通常只有一个值:计算出的 Hausdorff 距离
    """
    distance_before_align = distance.GetOutput(0).GetFieldData().GetArray('HausdorffDistance').GetComponent(0, 0)
 
    # Get initial alignment using oriented bounding boxes.
    align_bounding_boxes(source_polydata, tpd.GetOutput())
 
    distance.SetInputData(0, tpd.GetOutput())
    distance.SetInputData(1, source_polydata)
    distance.Modified()
    distance.Update()
    distance_after_align = distance.GetOutput(0).GetFieldData().GetArray('HausdorffDistance').GetComponent(0, 0)
 
    best_distance = min(distance_before_align, distance_after_align)
 
    if distance_after_align > distance_before_align:
        source_polydata.DeepCopy(original_source_polydata)
 
    # Refine the alignment using IterativeClosestPoint.
    icp = vtkIterativeClosestPointTransform()
    icp.SetSource(source_polydata)
    icp.SetTarget(tpd.GetOutput())
    """
    SetModeToRigidBody 将你内部用来计算点对齐的变换模式设置为刚体变换
    """
    icp.GetLandmarkTransform().SetModeToRigidBody()
    icp.SetMaximumNumberOfLandmarks(100)
    """
    SetMaximumMeanDistance  设置平均距离的收敛值。
    如果源点与目标点之间的平均距离小于这个值,算法就会认为已经收敛,并提前停止迭代
    """
    icp.SetMaximumMeanDistance(.00001)
    icp.SetMaximumNumberOfIterations(500)
    """
    CheckMeanDistanceOn  启用平均距离检查。
    ICP 在每次迭代后都检查是否达到了上面设置的 MaximumMeanDistance 阈值
    """
    icp.CheckMeanDistanceOn()
    """
    StartByMatchingCentroidsOn 启用质心匹配作为初始步骤
    在开始迭代之前,ICP 会首先计算源模型和目标模型的质心,并将源模型的质心平移到目标模型的质心位置
    """
    icp.StartByMatchingCentroidsOn()
    icp.Update()
    icp_mean_distance = icp.GetMeanDistance()
 
    # print(icp)
 
    lm_transform = icp.GetLandmarkTransform()
    transform = vtkTransformPolyDataFilter()
    transform.SetInputData(source_polydata)
    """
    SetTransform(icp)  vtkIterativeClosestPointTransform 继承自 vtkTransform
    它可以直接作为 SetTransform 的参数。这告诉过滤器:"请将ICP 算法最终计算出的所有变换(包括旋转、平移等)应用到源模型上
    """
    transform.SetTransform(lm_transform)
    transform.SetTransform(icp)
    transform.Update()
 
    distance.SetInputData(0, tpd.GetOutput())
    distance.SetInputData(1, transform.GetOutput())
    distance.Update()
 
    # Note: If there is an error extracting eigenfunctions, then this will be zero.
    distance_after_icp = distance.GetOutput(0).GetFieldData().GetArray('HausdorffDistance').GetComponent(0, 0)
 
    # Check if ICP worked.
    if not (math.isnan(icp_mean_distance) or math.isinf(icp_mean_distance)):
        if distance_after_icp < best_distance:
            best_distance = distance_after_icp
 
    print('Distances:')
    print('  Before aligning:                        {:0.5f}'.format(distance_before_align))
    print('  Aligning using oriented bounding boxes: {:0.5f}'.format(distance_before_align))
    print('  Aligning using IterativeClosestPoint:   {:0.5f}'.format(distance_after_icp))
    print('  Best distance:                          {:0.5f}'.format(best_distance))
 
    # Select the source to use.
    source_mapper = vtkDataSetMapper()
    if best_distance == distance_before_align:
        source_mapper.SetInputData(original_source_polydata)
        print('Using original alignment')
    elif best_distance == distance_after_align:
        source_mapper.SetInputData(source_polydata)
        print('Using alignment by OBB')
    else:
        source_mapper.SetInputConnection(transform.GetOutputPort())
        print('Using alignment by ICP')
    source_mapper.ScalarVisibilityOff()
 
 
    writer = vtkPolyDataWriter()
    if best_distance == distance_before_align:
        writer.SetInputData(original_source_polydata)
    elif best_distance == distance_after_align:
        writer.SetInputData(source_polydata)
    else:
        writer.SetInputData(transform.GetOutput())
    writer.SetFileName('AlignedSource.vtk')
    writer.Write()
    writer.SetInputData(tpd.GetOutput())
    writer.SetFileName('Target.vtk')
    writer.Write()
 
    source_actor = vtkActor()
    source_actor.SetMapper(source_mapper)
    source_actor.GetProperty().SetOpacity(0.6)
    source_actor.GetProperty().SetDiffuseColor(
        colors.GetColor3d('White'))
    renderer.AddActor(source_actor)
 
    target_mapper = vtkDataSetMapper()
    target_mapper.SetInputData(tpd.GetOutput())
    target_mapper.ScalarVisibilityOff()
 
    target_actor = vtkActor()
    target_actor.SetMapper(target_mapper)
    target_actor.GetProperty().SetDiffuseColor(
        colors.GetColor3d('Tomato'))
    renderer.AddActor(target_actor)
 
    render_window.AddRenderer(renderer)
    renderer.SetBackground(colors.GetColor3d("sea_green_light"))
    renderer.UseHiddenLineRemovalOn()
 
    if vtk_version_ok(9, 0, 20210718):
        try:
            cam_orient_manipulator = vtkCameraOrientationWidget()
            cam_orient_manipulator.SetParentRenderer(renderer)
            # Enable the widget.
            cam_orient_manipulator.On()
        except AttributeError:
            pass
    else:
        axes = vtkAxesActor()
        widget = vtkOrientationMarkerWidget()
        rgba = [0.0, 0.0, 0.0, 0.0]
        colors.GetColor("Carrot", rgba)
        widget.SetOutlineColor(rgba[0], rgba[1], rgba[2])
        widget.SetOrientationMarker(axes)
        widget.SetInteractor(interactor)
        widget.SetViewport(0.0, 0.0, 0.2, 0.2)
        widget.EnabledOn()
        widget.InteractiveOn()
 
    render_window.SetSize(640, 480)
    render_window.Render()
    render_window.SetWindowName('AlignTwoPolyDatas')
 
    interactor.Start()
 
 
def vtk_version_ok(major, minor, build):
    """
    Check the VTK version.
    :param major: Major version.
    :param minor: Minor version.
    :param build: Build version.
    :return: True if the requested VTK version is greater or equal to the actual VTK version.
    """
    needed_version = 10000000000 * int(major) + 100000000 * int(minor) + int(build)
    try:
        vtk_version_number = VTK_VERSION_NUMBER
    except AttributeError:  # as error:
        ver = vtkVersion()
        vtk_version_number = 10000000000 * ver.GetVTKMajorVersion() + 100000000 * ver.GetVTKMinorVersion() \
                             + ver.GetVTKBuildVersion()
    if vtk_version_number >= needed_version:
        return True
    else:
        return False
 
 
def read_poly_data(file_name):
    import os
    path, extension = os.path.splitext(file_name)
    extension = extension.lower()
    if extension == ".ply":
        reader = vtkPLYReader()
        reader.SetFileName(file_name)
        reader.Update()
        poly_data = reader.GetOutput()
    elif extension == ".vtp":
        reader = vtkXMLPolyDataReader()
        reader.SetFileName(file_name)
        reader.Update()
        poly_data = reader.GetOutput()
    elif extension == ".obj":
        reader = vtkOBJReader()
        reader.SetFileName(file_name)
        reader.Update()
        poly_data = reader.GetOutput()
    elif extension == ".stl":
        reader = vtkSTLReader()
        reader.SetFileName(file_name)
        reader.Update()
        poly_data = reader.GetOutput()
    elif extension == ".vtk":
        reader = vtkPolyDataReader()
        reader.SetFileName(file_name)
        reader.Update()
        poly_data = reader.GetOutput()
    elif extension == ".g":
        reader = vtkBYUReader()
        reader.SetGeometryFileName(file_name)
        reader.Update()
        poly_data = reader.GetOutput()
    else:
        # Return a None if the extension is unknown.
        poly_data = None
    return poly_data
 
 
def align_bounding_boxes(source, target):
    # Use OBBTree to create an oriented bounding box for target and source
    """
    vtkOBBtree
    tkOBBTree 是一种 空间加速结构,它会递归地把几何体拆分成 包围盒(OBB)层次结构
    MaxLevel 控制这个树的 最大递归深度
    这里设置成1,意思是:
    只构建 根节点 + 一层子节点 的 OBB 树
    层数越小 → 树越"粗糙",速度快但精度低
    层数越大 → 树越"细致",适合做精确的碰撞检测或最近点查询
    """
    source_obb_tree = vtkOBBTree()
    source_obb_tree.SetDataSet(source)
    source_obb_tree.SetMaxLevel(1)
    source_obb_tree.BuildLocator()
 
    target_obb_tree = vtkOBBTree()
    target_obb_tree.SetDataSet(target)
    target_obb_tree.SetMaxLevel(1)
    target_obb_tree.BuildLocator()
    """
    BuildLocator  真正构建 OBB 树的层次结构
    在设置了数据集(SetDataSet)和参数(SetMaxLevel)之后,你必须调用它来"编译"出树,
    才能进行后续查询(比如射线相交、碰撞检测、最近点搜索)
    """
 
    source_landmarks = vtkPolyData()
    source_obb_tree.GenerateRepresentation(0, source_landmarks)
    """
    GenerateRepresentation  生成包围盒的几何形状
    0:这个参数指定了要生成的表示层级。vtkOBBTree 是一种分层的包围盒,0 表示最高层级,也就是整个模型的最外层包围盒
    source_landmarks:GenerateRepresentation 方法会把生成的8个顶点和12条边(一个六面体)的几何数据填充到 source_landmarks 这个 vtkPolyData 对象中
    """
 
    target_landmarks = vtkPolyData()
    target_obb_tree.GenerateRepresentation(0, target_landmarks)
 
    """
    vtkLandmarkTransform 是 VTK 里一个 点集配准(registration) 的基础类,
    它根据两组 对应点(landmarks) 来计算一个最佳的变换矩阵(旋转、平移、可选缩放)
    它需要两组对应点:源点集(Source Landmarks)和 目标点集(Target Landmarks)
    """
    lm_transform = vtkLandmarkTransform()
    lm_transform.SetModeToSimilarity()
    """
    LandmarkTransform 支持 3 种模式
    RigidBody(刚体变换) 只允许旋转 + 平移(保持形状和大小完全不变)
    Similarity(相似变换) 允许旋转 + 平移 + 缩放(各向同性缩放,保持比例)
    Affine(仿射变换) 允许旋转 + 平移 + 缩放 + 剪切(最灵活,但可能改变形状)
    这里用 Similarity 模式,说明希望计算的变换能 匹配位置和大小,但不允许形状变形
    """
    lm_transform.SetTargetLandmarks(target_landmarks.GetPoints())
    best_distance = VTK_DOUBLE_MAX
    best_points = vtkPoints()
    best_distance = best_bounding_box(
        "X",
        target,
        source,
        target_landmarks,
        source_landmarks,
        best_distance,
        best_points)
    best_distance = best_bounding_box(
        "Y",
        target,
        source,
        target_landmarks,
        source_landmarks,
        best_distance,
        best_points)
    best_distance = best_bounding_box(
        "Z",
        target,
        source,
        target_landmarks,
        source_landmarks,
        best_distance,
        best_points)
 
    lm_transform.SetSourceLandmarks(best_points)
    lm_transform.Modified()
 
    lm_transform_pd = vtkTransformPolyDataFilter()
    lm_transform_pd.SetInputData(source)
    lm_transform_pd.SetTransform(lm_transform)
    lm_transform_pd.Update()
 
    source.DeepCopy(lm_transform_pd.GetOutput())
 
    return
 
 
def best_bounding_box(axis, target, source, target_landmarks, source_landmarks, best_distance, best_points):
    """
    这个函数的目的在于在粗略对齐阶段,
    通过系统性地测试不同旋转角度,来找到源模型和目标模型之间最佳的初始对齐
    """
 
    distance = vtkHausdorffDistancePointSetFilter()
    test_transform = vtkTransform()
 
    """
    vtkTransformPolyDataFilter 是将一个 vtkPolyData 数据集应用几何变换,然后生成一个新的、已变换的 vtkPolyData 对象
    """
    test_transform_pd = vtkTransformPolyDataFilter()
    """
    vtkLandmarkTransform 用于两个点集的配准
    是vtkTransform的子类,可以根据源点集和目标点击来计算一个最佳的变换矩阵
    """
    lm_transform = vtkLandmarkTransform()
    lm_transform.SetModeToSimilarity()
    lm_transform.SetTargetLandmarks(target_landmarks.GetPoints())
 
    lm_transform_pd = vtkTransformPolyDataFilter()
 
    source_center = source_landmarks.GetCenter()
 
    delta = 90.0
    for i in range(0, 4):
        angle = delta * i
        # Rotate about center
        test_transform.Identity()
        test_transform.Translate(source_center[0], source_center[1], source_center[2])
        if axis == "X":
            test_transform.RotateX(angle)
        elif axis == "Y":
            test_transform.RotateY(angle)
        else:
            test_transform.RotateZ(angle)
        test_transform.Translate(-source_center[0], -source_center[1], -source_center[2])
 
        test_transform_pd.SetTransform(test_transform)
        test_transform_pd.SetInputData(source_landmarks)
        test_transform_pd.Update()
 
        lm_transform.SetSourceLandmarks(test_transform_pd.GetOutput().GetPoints())
        lm_transform.Modified()
 
        lm_transform_pd.SetInputData(source)
        lm_transform_pd.SetTransform(lm_transform)
        lm_transform_pd.Update()
 
        distance.SetInputData(0, target)
        distance.SetInputData(1, lm_transform_pd.GetOutput())
        distance.Update()
 
        test_distance = distance.GetOutput(0).GetFieldData().GetArray("HausdorffDistance").GetComponent(0, 0)
        if test_distance < best_distance:
            best_distance = test_distance
            best_points.DeepCopy(test_transform_pd.GetOutput().GetPoints())
 
    return best_distance
 
 
if __name__ == '__main__':
    main()
相关推荐
嗝o゚6 小时前
Flutter与开源鸿蒙:一场“应用定义权”的静默战争,与开发者的“范式跃迁”机会
python·flutter
一只会奔跑的小橙子6 小时前
pytest安装对应的库的方法
python
ohoy6 小时前
EasyPoi 数据脱敏
开发语言·python·excel
BoBoZz196 小时前
MarchingCubes 网格数据体素化并提取等值面
python·vtk·图形渲染·图形处理
ekprada7 小时前
DAY36 复习日
开发语言·python·机器学习
爱笑的眼睛117 小时前
强化学习组件:超越Hello World的架构级思考与实践
java·人工智能·python·ai
Boxsc_midnight7 小时前
【规范驱动的开发方式】之【spec-kit】 的安装入门指南
人工智能·python·深度学习·软件工程·设计规范
条件漫步7 小时前
Miniconda config channels的查看、删除、添加
python
爱笑的眼睛117 小时前
深入解析PyTorch nn模块:超越基础模型构建的高级技巧与实践
java·人工智能·python·ai