梯度下降法解决2D映射3D

本人只是业余人士,无意间发现的方法,发出来共同学习

1. 数据准备

这部分不是文章重点,就写随意点了

这块不详细说,总之现在手上有相机内外参数、一个折线在三维空间的坐标、该折线在2张2D图中的坐标(测试数据是由3D到2D映射得到)。

内外参先进行合并,得到点云坐标系到像素坐标系的4*4仿射变换矩阵

python 复制代码
transform_matrix_list = []
for c in camera_config:
    # 外参,4*4矩阵
    c_ext = np.array(c['camera_external']).reshape(4, 4)
    
    # 内参,也写成4*4矩阵和外参对齐
    temp_int = c["camera_internal"]
    c_int = np.array(
        [
            [temp_int["fx"], 0, temp_int["cx"], 0],
            [0, temp_int["fy"], temp_int["cy"], 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1]
        ]
    )
    transform_matrix_list.append(c_int @ c_ext)

准备好的数据,一共2条2D线,即total变量里有2份数据,每份数据包含4*4矩阵和折线的坐标,我准备的测试折线画了4个点,所以array是4行。

python 复制代码
for mat, line in total:
    print(mat, '\n', line)
"""
[[ 1.93561303e+03 -1.84477437e+03 -8.10791540e+00  3.72401360e+03]
 [ 1.07129123e+03  6.77998530e+00 -1.86140364e+03  4.47658540e+03]
 [ 9.99972000e-01  6.28300000e-03 -4.18900000e-03  1.95100000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]] 
 [[1101.8387 1573.1091]
 [1154.8978 1321.6198]
 [ 775.565  1264.7661]
 [ 315.3221 1350.395 ]]
[[ 1.47402852e+03 -2.74023717e+01  1.87097460e+01  2.45415994e+03]
 [ 4.40832799e+02  4.77015959e+02 -1.10192568e+03  1.40179194e+03]
 [ 6.71733000e-01  7.40679000e-01  1.30110000e-02  3.79660000e-01]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]] 
 [[1706.4499  953.3121]
 [1604.5224  789.089 ]
 [1343.5587  739.4855]
 [1157.7488  761.7332]]
"""

2. 函数准备

思路很简单,将3D坐标依次投影到所有图像里,得到所有投影2D点。选欧氏距离作为损失函数的值,依次计算投影2D点与真实2D点的欧氏距离,再求和。

注意这里使用autograd库得到梯度函数,可以提点速。

python 复制代码
from itertools import product

import pandas as pd
import autograd.numpy as np
from autograd import grad
from scipy.optimize import minimize

def proj2d(
    transform_matrix,
    pts
):
    """
    3D投影到2D
    """
    pts4d = np.vstack([pts.T, np.ones(shape=(1, len(pts)))])
    temp = transform_matrix @ pts4d
    
    return (temp / temp[-2])[:-2, :].T

def total_distance(
    x0,
    info2d=total
):
    """
    损失函数,所有投影2D点与真实2D点欧氏距离的和
    """
    pts3d = x0.reshape(-1, 3)
    total_dist = 0
    for trans_mat, pts2d in info2d:
        total_dist += np.linalg.norm(proj2d(trans_mat, pts3d) - pts2d)

    return total_dist

grad_total_distance = grad(total_distance)

3. 开始优化

首先,局部优化算法都容易陷入局部最小值,所以人为选一些初始点。比如本次测试就是先全部取0,再取-10,再取10,再取-20,再取20,以此类推。

关于x0的形状,这里图方便直接写3*4了,意思是优化4个点的xyz坐标,共12个自变量的值。其实应该把4改成每次投影的点数。

实测jac参数使用autograd库计算出的一阶梯度会比scipy用数值近似快一些。

options参数比较关键,这里只填了xrtol,该参数默认值0,意思是每次迭代自变量的变化达到xrtol就停止迭代。如果该参数设置过小、或是别的任何tol参数设置过小,minimize参函数都很可能报这个错:Desired error not necessarily achieved due to precision loss。可能是因为它担心在有数值精度损失的情况下,怎么优化都达不到预期的tol值。

python 复制代码
for scale, sign in product(range(0, 101, 10), [-1, 1]):
    result = minimize(
        fun=total_distance,
        x0=np.ones(3 * 4) * scale * sign, 
        args=total, 
        method="BFGS",
        jac=grad_total_distance,
        options={"xrtol": 1e-4}
    )
    if result.fun < 1:
        break
result
"""
  message: Optimization terminated successfully.
  success: True
   status: 0
      fun: 0.20395684683401033
        x: [ 2.793e+00  2.108e+00 -2.658e-04  7.559e+00  3.981e+00
            -4.512e-05  1.034e+01  7.683e+00 -6.337e-04  6.555e+00
             7.434e+00  3.803e-06]
      nit: 118
      jac: [ 6.805e+01 -2.646e+02 -1.614e+02  2.468e+01 -8.057e+01
            -4.063e+01 -6.598e+01  9.998e+01 -6.633e+01  1.596e+02
            -1.906e+02 -2.491e+01]
 hess_inv: [[ 1.480e-04  5.076e-05 ...  4.420e-05 -3.787e-06]
            [ 5.076e-05  1.931e-05 ...  2.220e-05 -3.856e-06]
            ...
            [ 4.420e-05  2.220e-05 ...  2.278e-03 -3.827e-04]
            [-3.787e-06 -3.856e-06 ... -3.827e-04  7.910e-05]]
     nfev: 155
     njev: 155
"""
相关推荐
happybasic1 小时前
通过纯文字引导DeepSeek编写一个简单的中国象棋引擎~
人工智能·python·中国象棋·deepseek
夜幕龙1 小时前
Dexcap复现代码数据预处理全流程(四)——demo_clipping_3d.py
人工智能·python·机器人
Tomorrow'sThinker5 小时前
25年1月更新。Windows 上搭建 Python 开发环境:PyCharm 安装全攻略(文中有安装包不用官网下载)
ide·python·pycharm
noravinsc5 小时前
requests请求带cookie
开发语言·python·pycharm
风_流沙7 小时前
python pandas 对mysql 一些常见操作
python·mysql·pandas
qq_273900238 小时前
pytorch torch.scatter_reduce函数介绍
人工智能·pytorch·python
ouyang_ouba8 小时前
pygame飞机大战
开发语言·python·pygame
小码贾8 小时前
OpenCV-Python实战(15)——像素直方图均衡画
人工智能·python·opencv
chusheng18409 小时前
基于 Python Django 的社区爱心养老系统
开发语言·python·django·社区爱心养老系统·python 爱心养老系统·python 社区养老
Better Bench10 小时前
【Python实现连续学习算法】复现2018年ECCV经典算法RWalk
python·连续学习·路径优化·增量学习·路径积分·重要性矩阵·记忆保持