一:主要的知识点
1、说明
本文只是教程内容的一小段,因博客字数限制,故进行拆分。主教程链接:vtk教程------逐行解析官网所有Python示例-CSDN博客
2、知识点纪要
本段代码主要涉及:① ICP 模型配准;② 配准结果的检测(Hausdorff 距离);③ OBB 包围盒。
二:代码及注释(Python)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import math
from pathlib import Path
# noinspection PyUnresolvedReferences
import vtkmodules.vtkInteractionStyle
# noinspection PyUnresolvedReferences
import vtkmodules.vtkRenderingOpenGL2
from vtkmodules.vtkCommonColor import vtkNamedColors
from vtkmodules.vtkCommonCore import (
VTK_DOUBLE_MAX,
vtkPoints
)
from vtkmodules.vtkCommonCore import (
VTK_VERSION_NUMBER,
vtkVersion
)
from vtkmodules.vtkCommonDataModel import (
vtkIterativeClosestPointTransform,
vtkPolyData
)
from vtkmodules.vtkCommonTransforms import (
vtkLandmarkTransform,
vtkTransform
)
from vtkmodules.vtkFiltersGeneral import (
vtkOBBTree,
vtkTransformPolyDataFilter
)
from vtkmodules.vtkFiltersModeling import vtkHausdorffDistancePointSetFilter
from vtkmodules.vtkIOGeometry import (
vtkBYUReader,
vtkOBJReader,
vtkSTLReader
)
from vtkmodules.vtkIOLegacy import (
vtkPolyDataReader,
vtkPolyDataWriter
)
from vtkmodules.vtkIOPLY import vtkPLYReader
from vtkmodules.vtkIOXML import vtkXMLPolyDataReader
from vtkmodules.vtkInteractionWidgets import (
vtkCameraOrientationWidget,
vtkOrientationMarkerWidget
)
from vtkmodules.vtkRenderingAnnotation import vtkAxesActor
from vtkmodules.vtkRenderingCore import (
vtkActor,
vtkDataSetMapper,
vtkRenderWindow,
vtkRenderWindowInteractor,
vtkRenderer
)
def _has_points(poly_data):
    """Return True if ``poly_data`` is not None and contains at least one point."""
    return poly_data is not None and poly_data.GetNumberOfPoints() > 0


def _hausdorff_distance(poly_a, poly_b):
    """
    Return the Hausdorff distance between the point sets of two vtkPolyData.

    For every point of A, take the distance to its closest point in B and keep
    the maximum. Do the same from B to A. The Hausdorff distance is the larger
    of those two maxima. vtkHausdorffDistancePointSetFilter stores that value
    as component 0 of the 'HausdorffDistance' array in the field data of its
    first output.
    """
    distance = vtkHausdorffDistancePointSetFilter()
    distance.SetInputData(0, poly_a)
    distance.SetInputData(1, poly_b)
    distance.Update()
    return distance.GetOutput(0).GetFieldData().GetArray('HausdorffDistance').GetComponent(0, 0)


def main(src_fn="Data/thingiverse/Grey_Nurse_Shark.stl", tgt_fn="Data/greatWhite.stl"):
    """
    Align a source mesh to a target mesh, then save and display the best result.

    Three candidates are scored by their Hausdorff distance to the target:
      1. the unmodified source,
      2. the source after an oriented-bounding-box (OBB) alignment,
      3. the source after ICP (IterativeClosestPoint) refinement.
    The best candidate is written to 'AlignedSource.vtk', and the (possibly
    rotated) target to 'Target.vtk'. Both are then rendered together.

    :param src_fn: Path of the source mesh (the one that is moved).
    :param tgt_fn: Path of the target mesh (kept fixed).
    """
    colors = vtkNamedColors()

    print('Loading source:', src_fn)
    source_polydata = read_poly_data(src_fn)
    if not _has_points(source_polydata):
        print('Unable to read any points from the source:', src_fn)
        return
    # Save the source polydata in case the alignment process does not improve
    # the fit.
    original_source_polydata = vtkPolyData()
    original_source_polydata.DeepCopy(source_polydata)

    print('Loading target:', tgt_fn)
    target_polydata = read_poly_data(tgt_fn)
    if not _has_points(target_polydata):
        print('Unable to read any points from the target:', tgt_fn)
        return

    # If the target orientation is markedly different, you may need to apply a
    # transform to orient the target with the source.
    # For example, when using Grey_Nurse_Shark.stl as the source and
    # greatWhite.stl as the target, you need to transform the target.
    trnf = vtkTransform()
    if Path(src_fn).name == 'Grey_Nurse_Shark.stl' and Path(tgt_fn).name == 'greatWhite.stl':
        trnf.RotateY(90)
    tpd = vtkTransformPolyDataFilter()
    tpd.SetTransform(trnf)
    tpd.SetInputData(target_polydata)
    tpd.Update()
    target = tpd.GetOutput()

    distance_before_align = _hausdorff_distance(target, source_polydata)

    # Coarse initial alignment using oriented bounding boxes.
    # align_bounding_boxes overwrites source_polydata in place.
    align_bounding_boxes(source_polydata, target)
    distance_after_align = _hausdorff_distance(target, source_polydata)

    # Track the winner by name rather than by comparing floats for equality,
    # which misbehaves when a distance is NaN.
    if distance_after_align < distance_before_align:
        best_name, best_distance = 'obb', distance_after_align
    else:
        best_name, best_distance = 'original', distance_before_align
        # The OBB alignment did not improve the fit: start ICP from the
        # original pose instead.
        source_polydata.DeepCopy(original_source_polydata)

    # Refine the alignment using IterativeClosestPoint.
    icp = vtkIterativeClosestPointTransform()
    icp.SetSource(source_polydata)
    icp.SetTarget(target)
    # The internal landmark transform solves for rotation + translation only.
    icp.GetLandmarkTransform().SetModeToRigidBody()
    icp.SetMaximumNumberOfLandmarks(100)
    # Convergence threshold: iteration stops early once the mean distance
    # between source and target points falls below this value.
    icp.SetMaximumMeanDistance(.00001)
    icp.SetMaximumNumberOfIterations(500)
    # Test the mean-distance threshold after every iteration.
    icp.CheckMeanDistanceOn()
    # Before iterating, translate the source centroid onto the target centroid.
    icp.StartByMatchingCentroidsOn()
    icp.Update()
    icp_mean_distance = icp.GetMeanDistance()

    # vtkIterativeClosestPointTransform is itself a linear transform, so the
    # complete ICP result can be passed directly to SetTransform.
    transform = vtkTransformPolyDataFilter()
    transform.SetInputData(source_polydata)
    transform.SetTransform(icp)
    transform.Update()
    icp_polydata = transform.GetOutput()

    # Note: If there is an error extracting eigenfunctions, then this will be zero.
    distance_after_icp = _hausdorff_distance(target, icp_polydata)

    # Accept ICP only if it produced a finite mean distance and improved the fit.
    if math.isfinite(icp_mean_distance) and distance_after_icp < best_distance:
        best_name, best_distance = 'icp', distance_after_icp

    print('Distances:')
    print(' Before aligning: {:0.5f}'.format(distance_before_align))
    print(' Aligning using oriented bounding boxes: {:0.5f}'.format(distance_after_align))
    print(' Aligning using IterativeClosestPoint: {:0.5f}'.format(distance_after_icp))
    print(' Best distance: {:0.5f}'.format(best_distance))

    # Select the source to display and save.
    source_mapper = vtkDataSetMapper()
    writer = vtkPolyDataWriter()
    if best_name == 'original':
        source_mapper.SetInputData(original_source_polydata)
        writer.SetInputData(original_source_polydata)
        print('Using original alignment')
    elif best_name == 'obb':
        source_mapper.SetInputData(source_polydata)
        writer.SetInputData(source_polydata)
        print('Using alignment by OBB')
    else:
        source_mapper.SetInputConnection(transform.GetOutputPort())
        writer.SetInputData(icp_polydata)
        print('Using alignment by ICP')
    source_mapper.ScalarVisibilityOff()

    writer.SetFileName('AlignedSource.vtk')
    writer.Write()
    writer.SetInputData(target)
    writer.SetFileName('Target.vtk')
    writer.Write()

    renderer = vtkRenderer()
    render_window = vtkRenderWindow()
    render_window.AddRenderer(renderer)
    interactor = vtkRenderWindowInteractor()
    interactor.SetRenderWindow(render_window)

    source_actor = vtkActor()
    source_actor.SetMapper(source_mapper)
    source_actor.GetProperty().SetOpacity(0.6)
    source_actor.GetProperty().SetDiffuseColor(colors.GetColor3d('White'))
    renderer.AddActor(source_actor)

    target_mapper = vtkDataSetMapper()
    target_mapper.SetInputData(target)
    target_mapper.ScalarVisibilityOff()
    target_actor = vtkActor()
    target_actor.SetMapper(target_mapper)
    target_actor.GetProperty().SetDiffuseColor(colors.GetColor3d('Tomato'))
    renderer.AddActor(target_actor)

    renderer.SetBackground(colors.GetColor3d("sea_green_light"))
    renderer.UseHiddenLineRemovalOn()

    # Orientation decoration. The widget objects are held in locals so they
    # stay alive while interactor.Start() runs.
    if vtk_version_ok(9, 0, 20210718):
        try:
            cam_orient_manipulator = vtkCameraOrientationWidget()
            cam_orient_manipulator.SetParentRenderer(renderer)
            # Enable the widget.
            cam_orient_manipulator.On()
        except AttributeError:
            # Optional decoration only; skip it if this VTK build lacks the API.
            pass
    else:
        axes = vtkAxesActor()
        widget = vtkOrientationMarkerWidget()
        rgba = [0.0, 0.0, 0.0, 0.0]
        colors.GetColor("Carrot", rgba)
        widget.SetOutlineColor(rgba[0], rgba[1], rgba[2])
        widget.SetOrientationMarker(axes)
        widget.SetInteractor(interactor)
        widget.SetViewport(0.0, 0.0, 0.2, 0.2)
        widget.EnabledOn()
        widget.InteractiveOn()

    render_window.SetSize(640, 480)
    # Set the title before the first Render() so the window opens with it.
    render_window.SetWindowName('AlignTwoPolyDatas')
    render_window.Render()
    interactor.Start()
def vtk_version_ok(major, minor, build):
    """
    Check whether the running VTK is at least the requested version.

    Versions are compared as the single integer
    ``major * 10**10 + minor * 10**8 + build``.

    :param major: Major version.
    :param minor: Minor version.
    :param build: Build version.
    :return: True if the actual VTK version is greater than or equal to the
             requested version.
    """
    needed_version = 10000000000 * int(major) + 100000000 * int(minor) + int(build)
    try:
        # A name missing from the module raises ImportError here. A plain
        # reference to a module-level global would raise NameError, which an
        # AttributeError handler never catches.
        from vtkmodules.vtkCommonCore import VTK_VERSION_NUMBER as vtk_version_number
    except ImportError:
        # NOTE(review): the module-level import of VTK_VERSION_NUMBER must also
        # be made optional for this fallback to be reachable on such builds.
        ver = vtkVersion()
        vtk_version_number = (10000000000 * ver.GetVTKMajorVersion()
                              + 100000000 * ver.GetVTKMinorVersion()
                              + ver.GetVTKBuildVersion())
    return vtk_version_number >= needed_version
def read_poly_data(file_name):
    """
    Read a polygonal mesh, picking the VTK reader from the file extension.

    Supported extensions (case-insensitive): .ply, .vtp, .obj, .stl,
    .vtk (legacy polydata) and .g (BYU geometry).

    :param file_name: Path of the file to read.
    :return: The reader's vtkPolyData output, or None when the extension is
             not one of the supported ones.
    """
    import os

    suffix = os.path.splitext(file_name)[1].lower()
    if suffix not in ('.ply', '.vtp', '.obj', '.stl', '.vtk', '.g'):
        # Return a None if the extension is unknown.
        return None

    if suffix == '.g':
        # The BYU reader takes its path through SetGeometryFileName.
        reader = vtkBYUReader()
        reader.SetGeometryFileName(file_name)
    else:
        reader_type = {
            '.ply': vtkPLYReader,
            '.vtp': vtkXMLPolyDataReader,
            '.obj': vtkOBJReader,
            '.stl': vtkSTLReader,
            '.vtk': vtkPolyDataReader,
        }[suffix]
        reader = reader_type()
        reader.SetFileName(file_name)
    reader.Update()
    return reader.GetOutput()
def align_bounding_boxes(source, target):
    """
    Coarsely align ``source`` to ``target`` in place using oriented bounding boxes.

    The corner points of the root oriented bounding box (OBB) of each mesh
    are used as landmarks. The source corners are tried at 0/90/180/270 degree
    rotations about the X, Y and Z axes (see best_bounding_box). The rotation
    with the smallest Hausdorff distance becomes a similarity landmark
    transform (rotation, translation, isotropic scale), which is applied to
    ``source``.

    :param source: vtkPolyData to align. Its contents are replaced (DeepCopy)
        with the transformed geometry.
    :param target: vtkPolyData that stays fixed.
    """
    # vtkOBBTree is a spatial locator that recursively splits a data set into a
    # hierarchy of oriented bounding boxes. MaxLevel caps the recursion depth:
    # smaller is coarser and faster, larger suits precise collision or
    # closest-point queries. Only the root box is used here, so 1 is enough.
    # BuildLocator() must be called after SetDataSet/SetMaxLevel to actually
    # build the tree.
    source_obb_tree = vtkOBBTree()
    source_obb_tree.SetDataSet(source)
    source_obb_tree.SetMaxLevel(1)
    source_obb_tree.BuildLocator()

    target_obb_tree = vtkOBBTree()
    target_obb_tree.SetDataSet(target)
    target_obb_tree.SetMaxLevel(1)
    target_obb_tree.BuildLocator()

    # GenerateRepresentation(0, pd) writes the geometry of the level-0 (root,
    # outermost) box into pd. Its 8 corner points serve as the landmarks.
    source_landmarks = vtkPolyData()
    source_obb_tree.GenerateRepresentation(0, source_landmarks)
    target_landmarks = vtkPolyData()
    target_obb_tree.GenerateRepresentation(0, target_landmarks)

    # Search for the rotation of the source box that best matches the target.
    # best_bounding_box overwrites best_points only when it finds a strictly
    # smaller distance.
    best_distance = VTK_DOUBLE_MAX
    best_points = vtkPoints()
    for axis in ("X", "Y", "Z"):
        best_distance = best_bounding_box(
            axis,
            target,
            source,
            target_landmarks,
            source_landmarks,
            best_distance,
            best_points)

    if best_points.GetNumberOfPoints() == 0:
        # No rotation was ever accepted (e.g. every distance compared as NaN).
        # An empty source landmark set would not match the target's point
        # count, so leave ``source`` untouched.
        return

    # vtkLandmarkTransform computes the transform that best maps the source
    # landmarks onto the corresponding target landmarks. Its modes are:
    # RigidBody (rotation + translation), Similarity (adds isotropic scaling)
    # and Affine (adds shear). Similarity matches position and size without
    # deforming the shape.
    lm_transform = vtkLandmarkTransform()
    lm_transform.SetModeToSimilarity()
    lm_transform.SetTargetLandmarks(target_landmarks.GetPoints())
    lm_transform.SetSourceLandmarks(best_points)
    lm_transform.Modified()

    lm_transform_pd = vtkTransformPolyDataFilter()
    lm_transform_pd.SetInputData(source)
    lm_transform_pd.SetTransform(lm_transform)
    lm_transform_pd.Update()
    source.DeepCopy(lm_transform_pd.GetOutput())
def best_bounding_box(axis, target, source, target_landmarks, source_landmarks, best_distance, best_points):
    """
    Try four rotations of the source OBB corners about one axis and keep the best.

    This is the coarse-alignment search: the source landmarks are rotated by
    0, 90, 180 and 270 degrees about ``axis`` through their own center. For
    each rotation, a similarity landmark transform from the rotated corners
    to ``target_landmarks`` is applied to ``source``. The result is scored by
    its Hausdorff distance to ``target``.

    :param axis: "X" or "Y"; any other value rotates about Z.
    :param target: Fixed vtkPolyData used for scoring.
    :param source: vtkPolyData that is transformed for scoring.
    :param target_landmarks: vtkPolyData whose points are the target landmarks.
    :param source_landmarks: vtkPolyData whose points are the source landmarks.
    :param best_distance: Best distance found so far.
    :param best_points: vtkPoints overwritten (DeepCopy) with the rotated
        source landmarks whenever a strictly smaller distance is found.
    :return: The smallest of ``best_distance`` and the distances tried here.
    """
    scorer = vtkHausdorffDistancePointSetFilter()
    rotation = vtkTransform()
    # Applies ``rotation`` to the landmark polydata, producing a new polydata.
    rotate_corners = vtkTransformPolyDataFilter()

    # Similarity mode: rotation + translation + isotropic scaling, no shear.
    fit = vtkLandmarkTransform()
    fit.SetModeToSimilarity()
    fit.SetTargetLandmarks(target_landmarks.GetPoints())
    apply_fit = vtkTransformPolyDataFilter()

    cx, cy, cz = source_landmarks.GetCenter()
    rotate_about_axis = {"X": rotation.RotateX, "Y": rotation.RotateY}.get(axis, rotation.RotateZ)

    for quarter_turns in range(4):
        # Rotate about the landmark center rather than the origin.
        rotation.Identity()
        rotation.Translate(cx, cy, cz)
        rotate_about_axis(90.0 * quarter_turns)
        rotation.Translate(-cx, -cy, -cz)

        rotate_corners.SetTransform(rotation)
        rotate_corners.SetInputData(source_landmarks)
        rotate_corners.Update()
        rotated = rotate_corners.GetOutput()

        fit.SetSourceLandmarks(rotated.GetPoints())
        fit.Modified()
        apply_fit.SetInputData(source)
        apply_fit.SetTransform(fit)
        apply_fit.Update()

        scorer.SetInputData(0, target)
        scorer.SetInputData(1, apply_fit.GetOutput())
        scorer.Update()
        score = scorer.GetOutput(0).GetFieldData().GetArray("HausdorffDistance").GetComponent(0, 0)

        if score < best_distance:
            best_distance = score
            best_points.DeepCopy(rotated.GetPoints())

    return best_distance
if __name__ == '__main__':
    # Run the alignment demo only when executed as a script, not on import.
    main()