Wirtinger Flow算法的matlab实现和python实现

文章目录

  • [1. 数学模型](#1. 数学模型)
  • [2. Wirtinger Flow 算法](#2. Wirtinger Flow 算法)
    • [2.1. 光谱初始化方法](#2.1. 光谱初始化方法)
    • [2.2. Wirtinger梯度下降](#2.2. Wirtinger梯度下降)
  • [3. 算法实现](#3. 算法实现)
    • [3.1. Matlab实现](#3.1. Matlab实现)
    • [3.2. Python实现](#3.2. Python实现)
  • 参考文献

1. 数学模型

观测数学模型可由下面公式给出
y = ∣ A x ∣ 2 y = |Ax|^2 y=∣Ax∣2

其中 x ∈ C n x\in\mathbb C^{n} x∈Cn, A ∈ C m × n A\in\mathbb C^{m\times n} A∈Cm×n, y ∈ R m y\in\mathbb R^{m} y∈Rm。

所以我们要求解的问题可归为如下非凸最小二乘问题
min ⁡ z ∈ C n    f ( z ) = 1 2 ∥   ∣ A z ∣ 2 − y ∥ 2 2 \min_{z\in\mathbb C^{n}} \;f(z)=\frac12\Bigl\|\,|Az|^{ 2}-y\Bigr\|_2^{2} z∈Cnminf(z)=21 ∣Az∣2−y 22

2. Wirtinger Flow 算法

该算法可以总结成两步:1. 光谱初始化 2. Wirtinger梯度下降

2.1. 光谱初始化方法

具体步骤如下:

  1. 能量标定系数
    λ 2    =    n   1 m  ⁣ ⊤ y ∥ A ∥ F 2 \lambda^{2} \;=\; n\,\frac{\mathbf 1_{m}^{\!\top}y}{\|A\|_{F}^{2}} λ2=n∥A∥F21m⊤y

  2. 构造自伴矩阵
    Y    =    1 m   A ∗ diag ⁡ ( y )   A Y \;=\; \frac1m\,A^{*}\operatorname{diag}(y)\,A Y=m1A∗diag(y)A

    求得其最大特征向量 v v v

  3. 缩放得到初始点
    z 0 = λ   v z_{0}=\lambda\,v z0=λv

2.2. Wirtinger梯度下降

更新公式如下:
z τ + 1 =    z τ    −    μ τ + 1 ∥ z 0 ∥ 2 2    ( 1 m   A ∗ [   ( ∣ A z τ ∣ 2 − y )    ⊙    ( A z τ ) ] ) ⏟ ∇ f ( z τ ) z_{\tau+1} =\;z_\tau\;-\;\frac{\mu_{\tau+1}}{\|z_0\|{2}^{2}}\; \underbrace{\Bigl(\frac1m\,A^{*}\bigl[\,(|A z\tau|^{2}-y)\;⊙\;(A z_\tau)\bigr]\Bigr)}{\nabla f(z\tau)} zτ+1=zτ−∥z0∥22μτ+1∇f(zτ) (m1A∗[(∣Azτ∣2−y)⊙(Azτ)])

公式中的 μ \mu μ更新根据经验公式
μ τ = min ⁡ ( 1 − exp ⁡ ( − τ / τ 0 ) ,   0.2 ) , τ 0 ≈ 330 \mu_\tau=\min(1-\exp(-\tau/\tau_0),\,0.2),\quad \tau_0≈330 μτ=min(1−exp(−τ/τ0),0.2),τ0≈330

3. 算法实现

3.1. Matlab实现

复制代码
clear; close all; clc
%% Measurement model 
% Signal length
n = 128; 
% Complex signal
x = randn(n,1) + 1i*randn(n,1);                     

% measurement number  
m = 5 * n; 
% Measurement matrix
A = 1/sqrt(2)*randn(m,n) + 1i/sqrt(2)*randn(m,n); 
% Measured values
y = abs(A*x).^2 ;                                   

%% Initialization
% power method to get the initial guess
npower_iter = 50;

% Scaled coefficient lambda
lam = sqrt(n * sum(y) / norm(A, 'fro')^2);

% Random input
z0 = randn(n,1); z0 = z0/norm(z0,'fro');
for tt = 1:npower_iter
    z0 = 1/m * A'*(y .* (A*z0));
    z0 = z0/norm(z0,'fro');
end

% Initialized ouput
z = lam * z0;

%% Gradient update
% Max number of iterations
max_iter = 2500;

% update mu
tau0 = 330;                         
mu = @(t) min(1-exp(-t/tau0), 0.2); 

% Store relative errors
relative_error = zeros(max_iter, 1);

for tt = 1:max_iter  
    Az = A*z;

    % Wirtinger gradient
    grad  = 1/m* A'*( ( abs(Az).^2-y ) .* Az ); 

    % ||z0||=lam
    z = z - mu(tt)/lam^2 * grad;            

    % Calculate relative error value
    relative_error_val = norm(x - exp(-1i*angle(trace(x'*z))) * z, 'fro')/norm(x,'fro');
    relative_error(tt) = relative_error_val;  
end
%%
figure,semilogy(relative_error,'LineWidth',1.8, 'Color',[0 0.4470 0.7410])
xlabel('Iteration','FontSize',16,'FontName','Times New Roman')
ylabel('Relative error','FontSize',16,'FontName','Times New Roman')
title('Wirtinger Flow Convergence','FontSize',16,'FontWeight','bold')

3.2. Python实现

复制代码
import numpy as np
import matplotlib.pyplot as plt

n = 128                                        # 信号长度
x = np.random.randn(n) + 1j * np.random.randn(n)   # 复值真信号

m = 5 * n                                      # 测量数
A = (np.random.randn(m, n) + 1j * np.random.randn(m, n)) / np.sqrt(2)

y = np.abs(A @ x) ** 2                         # 强度观测 |Ax|^2

# ---------- Initialization (power method) ------------------------------------
npower_iter = 50                               # 幂迭代次数
lam = np.sqrt(n * y.sum() / np.linalg.norm(A, "fro") ** 2)  # λ

z0 = np.random.randn(n) + 1j * np.random.randn(n)
z0 /= np.linalg.norm(z0)

for _ in range(npower_iter):
    z0 = (A.conj().T @ (y * (A @ z0))) / m
    z0 /= np.linalg.norm(z0)

z = lam * z0                                   # 初值

# ---------- Gradient update ---------------------------------------------------
max_iter = 2500
tau0 = 330.0
rel_err = np.zeros(max_iter)

for tt in range(max_iter):
    mu = min(1.0 - np.exp(-(tt + 1) / tau0), 0.2)   # 步长 μ_t

    Az = A @ z
    grad = (A.conj().T @ ((np.abs(Az) ** 2 - y) * Az)) / m
    z = z - (mu / lam ** 2) * grad

    # 相对误差
    theta = np.angle(np.vdot(x, z))           # vdot = x* · z
    rel_err[tt] = np.linalg.norm(x - np.exp(-1j * theta) * z) / np.linalg.norm(x)

# ---------- plot -------------------------------------------------------
plt.figure(figsize=(6.2, 4.2), facecolor="w")
plt.semilogy(rel_err, lw=1.8, color=(0.0, 0.447, 0.741))
plt.xlabel("Iteration", fontsize=13)
plt.ylabel("Relative error", fontsize=13)
plt.title("Wirtinger Flow Convergence (1-D)", fontsize=15, weight="bold")
plt.grid(ls="--", alpha=0.3)
plt.tight_layout()
plt.show()

参考文献

Candes E J, Li X, Soltanolkotabi M. Phase retrieval via Wirtinger flow: Theory and algorithms[J]. IEEE Transactions on Information Theory, 2015, 61(4): 1985-2007.

相关推荐
执携1 小时前
Python(6) -- 数据容器
开发语言·python
天才测试猿2 小时前
Jmeter+ant+jenkins接口自动化测试框架
自动化测试·软件测试·python·jmeter·职场和发展·jenkins·接口测试
努力冲冲3 小时前
常用排序算法
java·算法·排序算法
计算机毕业编程指导师5 小时前
毕业设计选题推荐之基于Spark的在线教育投融数据可视化分析系统 |爬虫|大数据|大屏|预测|深度学习|数据分析|数据挖掘
大数据·hadoop·python·数据挖掘·spark·毕业设计·在线教育投融
夜斗小神社5 小时前
【LeetCode 热题 100】(六)矩阵
算法·leetcode·矩阵
天地一流殇6 小时前
SimBA算法实现过程
深度学习·算法·对抗攻击·黑盒
花酒锄作田7 小时前
[Python][Go]比较两个JSON文件之间的差异
python·golang
2501_924730617 小时前
智慧城管复杂人流场景下识别准确率↑32%:陌讯多模态感知引擎实战解析
大数据·人工智能·算法·计算机视觉·目标跟踪·视觉检测·边缘计算
weixin_307779137 小时前
C++实现MATLAB矩阵计算程序
开发语言·c++·算法·matlab·矩阵
学不动CV了7 小时前
FreeRTOS入门知识(初识RTOS任务调度)(三)
c语言·arm开发·stm32·单片机·物联网·算法·51单片机