WannierTools输入文件wt.in一键批量生成脚本

做 Wannier 紧束缚计算的同学都懂:先用 Wannier90 拟合出 `wannier90_hr.dat`,再用 WannierTools 算拓扑性质。但这两个程序的输入文件格式差异较大,手动编写复杂繁琐.

本文代码则自动通过识别wannier90.win 从而一键生成wannier_tools的输入文件wt.in

使用方法

复制代码
# 一键生成全部 19 种任务的 wt.in 文件python win2wt.py wannier90.win --all
# 只生成能带和 DOSpython win2wt.py wannier90.win -t bands,dos
# 交互式选择(有菜单)python win2wt.py wannier90.win
# 列出所有可用任务python win2wt.py --list
# 指定输出目录python win2wt.py wannier90.win --all -o ./my_calc```
### 自旋极化体系
脚本会自动检测当前目录下是否有 `wannier90.up_hr.dat` 或 `wannier90.dn_hr.dat`:
```bash# 自动检测:有 up/dn 就生成两个独立目录python win2wt.py wannier90.win --all# → wt-up/  (Hrfile=wannier90.up_hr.dat, SOC=0)# → wt-dn/  (Hrfile=wannier90.dn_hr.dat, SOC=0)
# 手动指定通道python win2wt.py wannier90.win --all --up      # 只要 uppython win2wt.py wannier90.win --all --dn      # 只要 dnpython win2wt.py wannier90.win --all --nospin  # 强制标准模式```配套批量测试脚本:

bash auto_test.sh -t 50    # 逐个测试,超时 50 秒自动跳过

支持的计算任务

电子结构

bands 体态能带结构

bands_plane k 平面能带(可视化 Dirac 锥)

dos 态密度

fs 三维费米面

fs_plane k 平面费米面切片

拓扑性质

berry Berry 曲率分布

wcc Wannier 电荷中心(Wilson loop)

chirality Weyl 点手性(需手动填入坐标)

findnodes 自动搜索 Weyl 点 / 节点线

mirror_chern 镜面 Chern 数

表面/输运

slab_band 表面态能带

slab_ss 表面态自旋织构

ahc 反常霍尔电导率

ane 反常能斯特效应

shc 自旋霍尔电导率

ohe 轨道霍尔效应

landau Hofstadter 蝴蝶图谱

unfold 能带反折叠(超胞 → 原胞)

valley 谷自由度投影

注:本脚本仅辅助文件生成,任何计算任务执行前请自行详细检查输入文件确保文件路径及设置准确。

复制代码
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
win2wt: Wannier90 (.win) → WannierTools (wt.in) 自动转换脚本
================================================================================
输入:  wannier90.win
输出:  wt.in-{task_name}  (每个任务一个独立文件)
依赖:  Python 3.6+, numpy
用法:
    python win2wt.py wannier90.win          # 交互式选择任务
    python win2wt.py wannier90.win --all    # 生成所有任务的 wt.in 文件
    python win2wt.py wannier90.win --list   # 列出所有可用任务
    python win2wt.py wannier90.win -t bands,dos,ahc  # 生成指定任务
.win → wt.in 参数映射关系:
    .win begin unit_cell_cart  →  wt.in LATTICE
    .win begin atoms_cart      →  wt.in ATOM_POSITIONS (Cartesian → Direct)
    .win begin projections     →  wt.in PROJECTORS
    .win fermi_energy          →  wt.in E_FERMI
    .win spinors               →  wt.in SOC
    .win num_wann              →  wt.in NumOccupied
    .win begin kpoint_path     →  wt.in KPATH_BULK
    .win berry / kslice        →  wt.in AHC/BerryCurvature/KPLANE 参数
    .win mp_grid               →  默认 k 网格参考
    .win write_hr              →  wt.in Hrfile = 'wannier90_hr.dat'
"
""
import os
import sys
import re
import argparse
import numpy as np
from copy import deepcopy
import textwrap
# ============================================================================
# WannierTools 中属于 &SYSTEM 的参数(不在 &PARAMETERS 中!)
# 这些参数如果放到 &PARAMETERS 会导致 "Invalid line in namelist PARAMETERS" 错误
# ============================================================================
SYSTEM_PARAMS = {
    "NSLAB", "NSLAB1", "NSLAB2", "NP",
    "Bmagnitude", "Btheta", "Bphi", "Bx", "By", "Bz",
    "surf_onsite",
}
# ============================================================================
# 计算任务定义
# 每个任务包含: 名称、描述、默认 CONTROL 开关、默认 PARAMETERS、是否需要特殊模块
# ============================================================================
TASK_DEFINITIONS = {
    "bands": {
        "name": "BulkBand - 体态能带结构",
        "description": "沿高对称 k-path 计算体态能带结构。最基本的计算任务。",
        "controls": {
            "BulkBand_calc": True,
            "BulkBand_plane_calc": False,
            "BulkBand_points_calc": False,
            "DOS_calc": False,
            "BulkFS_calc": False,
            "BulkFS_plane_calc": False,
            "FindNodes_calc": False,
            "SlabBand_calc": False,
            "SlabSS_calc": False,
            "SlabArc_calc": False,
            "Wanniercenter_calc": False,
            "BerryCurvature_calc": False,
            "BerryPhase_calc": False,
            "AHC_calc": False,
            "ANE_calc": False,
            "SHC_calc": False,
            "Boltz_OHE_calc": False,
            "LandauLevel_B_calc": False,
            "LandauLevel_k_calc": False,
            "MirrorChern_calc": False,
            "WeylChirality_calc": False,
            "BulkBand_Unfold_line_calc": False,
            "BulkBand_Unfold_plane_calc": False,
            "valley_projection_calc": False,
        },
        "params": {"Nk1": 61},
        "needs": ["KPATH_BULK"],
    },
    "bands_plane": {
        "name": "BulkBand_plane - k 平面能带",
        "description": "在 k 空间平面内计算能带,用于可视化 Dirac 锥等色散特征。",
        "controls": {
            "BulkBand_calc": False,
            "BulkBand_plane_calc": True,
            "BulkBand_points_calc": False,
            "DOS_calc": False,
            "BulkFS_calc": False,
            "BulkFS_plane_calc": False,
            "FindNodes_calc": False,
            "SlabBand_calc": False,
            "SlabSS_calc": False,
            "SlabArc_calc": False,
            "Wanniercenter_calc": False,
            "BerryCurvature_calc": False,
            "BerryPhase_calc": False,
            "AHC_calc": False,
            "ANE_calc": False,
            "SHC_calc": False,
            "Boltz_OHE_calc": False,
            "LandauLevel_B_calc": False,
            "LandauLevel_k_calc": False,
            "MirrorChern_calc": False,
            "WeylChirality_calc": False,
            "BulkBand_Unfold_line_calc": False,
            "BulkBand_Unfold_plane_calc": False,
            "valley_projection_calc": False,
        },
        "params": {"Nk1": 201, "Nk2": 201},
        "needs": ["KPLANE_BULK"],
    },
    "dos": {
        "name": "DOS - 态密度",
        "description": "计算体态密度(Density of States),用于分析能隙、范霍夫奇点等。",
        "controls": {
            "DOS_calc": True,
            "BulkBand_calc": False,
            "BulkBand_plane_calc": False,
            "BulkFS_calc": False,
            "BulkFS_plane_calc": False,
            "FindNodes_calc": False,
            "SlabBand_calc": False,
            "SlabSS_calc": False,
            "SlabArc_calc": False,
            "Wanniercenter_calc": False,
            "BerryCurvature_calc": False,
            "AHC_calc": False,
            "ANE_calc": False,
            "SHC_calc": False,
            "Boltz_OHE_calc": False,
            "LandauLevel_B_calc": False,
            "LandauLevel_k_calc": False,
            "MirrorChern_calc": False,
            "WeylChirality_calc": False,
            "BulkBand_Unfold_line_calc": False,
            "BulkBand_Unfold_plane_calc": False,
            "valley_projection_calc": False,
        },
        "params": {
            "Eta_Arc": 0.05,
            "OmegaNum": 2001,
            "OmegaMin": -10.0,
            "OmegaMax": 10.0,
            "Nk1": 101,
            "Nk2": 101,
            "Nk3": 1,
        },
        "needs": ["KCUBE_BULK"],
    },
    "fs": {
        "name": "BulkFS - 3D 费米面",
        "description": "计算三维费米面,用于分析金属/半金属的费米面拓扑。",
        "controls": {
            "BulkFS_calc": True,
            "BulkBand_calc": False,
            "BulkBand_plane_calc": False,
            "DOS_calc": False,
            "BulkFS_plane_calc": False,
            "FindNodes_calc": False,
            "SlabBand_calc": False,
            "SlabSS_calc": False,
            "SlabArc_calc": False,
            "Wanniercenter_calc": False,
            "BerryCurvature_calc": False,
            "AHC_calc": False,
            "ANE_calc": False,
            "SHC_calc": False,
            "Boltz_OHE_calc": False,
            "LandauLevel_B_calc": False,
            "LandauLevel_k_calc": False,
            "MirrorChern_calc": False,
            "WeylChirality_calc": False,
            "BulkBand_Unfold_line_calc": False,
            "BulkBand_Unfold_plane_calc": False,
            "valley_projection_calc": False,
        },
        "params": {"Nk1": 101, "Nk2": 101, "Nk3": 41},
        "needs": ["KCUBE_BULK"],
    },
    "fs_plane": {
        "name": "BulkFS_plane - 2D 费米面截面",
        "description": "计算 k 空间平面内的费米面截面(等高线图)。",
        "controls": {
            "BulkFS_plane_calc": True,
            "BulkBand_calc": False,
            "BulkBand_plane_calc": False,
            "DOS_calc": False,
            "BulkFS_calc": False,
            "FindNodes_calc": False,
            "SlabBand_calc": False,
            "SlabSS_calc": False,
            "SlabArc_calc": False,
            "Wanniercenter_calc": False,
            "BerryCurvature_calc": False,
            "AHC_calc": False,
            "ANE_calc": False,
            "SHC_calc": False,
            "Boltz_OHE_calc": False,
            "LandauLevel_B_calc": False,
            "LandauLevel_k_calc": False,
            "MirrorChern_calc": False,
            "WeylChirality_calc": False,
            "BulkBand_Unfold_line_calc": False,
            "BulkBand_Unfold_plane_calc": False,
            "valley_projection_calc": False,
        },
        "params": {"Eta_Arc": 0.05, "E_arc": 0.0, "Nk1": 101, "Nk2": 101},
        "needs": ["KPLANE_BULK"],
    },
    "findnodes": {
        "name": "FindNodes - Weyl/Dirac 点搜索",
        "description": "在 3D 布里渊区中搜索能带交叉点(Weyl 点或 Dirac 点)。",
        "controls": {
            "FindNodes_calc": True,
            "BulkBand_calc": False,
            "BulkBand_plane_calc": False,
            "DOS_calc": False,
            "BulkFS_calc": False,
            "BulkFS_plane_calc": False,
            "SlabBand_calc": False,
            "SlabSS_calc": False,
            "SlabArc_calc": False,
            "Wanniercenter_calc": False,
            "BerryCurvature_calc": False,
            "AHC_calc": False,
            "ANE_calc": False,
            "SHC_calc": False,
            "Boltz_OHE_calc": False,
            "LandauLevel_B_calc": False,
            "LandauLevel_k_calc": False,
            "MirrorChern_calc": False,
            "WeylChirality_calc": False,
            "BulkBand_Unfold_line_calc": False,
            "BulkBand_Unfold_plane_calc": False,
            "valley_projection_calc": False,
        },
        "params": {"Nk1": 6, "Nk2": 6, "Nk3": 6, "Gap_threshold": 0.0001},
        "needs": ["KCUBE_BULK"],
    },
    "chirality": {
        "name": "WeylChirality - Weyl 点手性计算",
        "description": "计算每个 Weyl 点的手性(Chern 数 ±1),需要先运行 FindNodes。",
        "controls": {
            "WeylChirality_calc": True,
            "BulkBand_calc": False,
            "BulkBand_plane_calc": False,
            "DOS_calc": False,
            "BulkFS_calc": False,
            "BulkFS_plane_calc": False,
            "FindNodes_calc": False,
            "SlabBand_calc": False,
            "SlabSS_calc": False,
            "SlabArc_calc": False,
            "Wanniercenter_calc": False,
            "BerryCurvature_calc": False,
            "AHC_calc": False,
            "ANE_calc": False,
            "SHC_calc": False,
            "Boltz_OHE_calc": False,
            "LandauLevel_B_calc": False,
            "LandauLevel_k_calc": False,
            "MirrorChern_calc": False,
            "BulkBand_Unfold_line_calc": False,
            "BulkBand_Unfold_plane_calc": False,
            "valley_projection_calc": False,
        },
        "params": {"Nk1": 60, "Nk2": 60},
        "needs": ["WEYL_CHIRALITY_PLACEHOLDER"],
    },
    "slab_band": {
        "name": "SlabBand - 表面能带结构",
        "description": "计算半无限 slab 体系的表面能带结构。",
        "controls": {
            "SlabBand_calc": True,
            "BulkBand_calc": False,
            "BulkBand_plane_calc": False,
            "DOS_calc": False,
            "BulkFS_calc": False,
            "BulkFS_plane_calc": False,
            "FindNodes_calc": False,
            "SlabSS_calc": False,
            "SlabArc_calc": False,
            "Wanniercenter_calc": False,
            "BerryCurvature_calc": False,
            "AHC_calc": False,
            "ANE_calc": False,
            "SHC_calc": False,
            "Boltz_OHE_calc": False,
            "LandauLevel_B_calc": False,
            "LandauLevel_k_calc": False,
            "MirrorChern_calc": False,
            "WeylChirality_calc": False,
            "BulkBand_Unfold_line_calc": False,
            "BulkBand_Unfold_plane_calc": False,
            "valley_projection_calc": False,
        },
        "params": {"NSLAB": 40, "NP": 1, "Nk1": 101},
        "needs": ["KPATH_SLAB"],
    },
    "slab_ss": {
        "name": "SlabSS - 表面态谱函数",
        "description": "计算半无限 slab 的表面态谱函数,可视化拓扑表面态和费米弧。",
        "controls": {
            "SlabSS_calc": True,
            "SlabArc_calc": True,
            "BulkBand_calc": False,
            "BulkBand_plane_calc": False,
            "DOS_calc": False,
            "BulkFS_calc": False,
            "BulkFS_plane_calc": False,
            "FindNodes_calc": False,
            "SlabBand_calc": False,
            "Wanniercenter_calc": False,
            "BerryCurvature_calc": False,
            "AHC_calc": False,
            "ANE_calc": False,
            "SHC_calc": False,
            "Boltz_OHE_calc": False,
            "LandauLevel_B_calc": False,
            "LandauLevel_k_calc": False,
            "MirrorChern_calc": False,
            "WeylChirality_calc": False,
            "BulkBand_Unfold_line_calc": False,
            "BulkBand_Unfold_plane_calc": False,
            "valley_projection_calc": False,
        },
        "params": {
            "Eta_Arc": 0.001,
            "E_arc": 0.0,
            "OmegaNum": 400,
            "OmegaMin": -1.0,
            "OmegaMax": 1.0,
            "Nk1": 201,
            "Nk2": 201,
            "NP": 2,
        },
        "needs": ["KPLANE_SLAB", "KPATH_SLAB"],
    },
    "wcc": {
        "name": "Wanniercenter - Wilson Loop / WCC",
        "description": "计算 Wannier 电荷中心演化(Wilson loop),用于获取 Chern 数或 Z2 不变量。",
        "controls": {
            "Wanniercenter_calc": True,
            "BulkBand_calc": False,
            "BulkBand_plane_calc": False,
            "DOS_calc": False,
            "BulkFS_calc": False,
            "BulkFS_plane_calc": False,
            "FindNodes_calc": False,
            "SlabBand_calc": False,
            "SlabSS_calc": False,
            "SlabArc_calc": False,
            "BerryCurvature_calc": False,
            "AHC_calc": False,
            "ANE_calc": False,
            "SHC_calc": False,
            "Boltz_OHE_calc": False,
            "LandauLevel_B_calc": False,
            "LandauLevel_k_calc": False,
            "MirrorChern_calc": False,
            "WeylChirality_calc": False,
            "BulkBand_Unfold_line_calc": False,
            "BulkBand_Unfold_plane_calc": False,
            "valley_projection_calc": False,
        },
        "params": {"Nk1": 60, "Nk2": 60},
        "needs": ["KPLANE_BULK"],
    },
    "berry": {
        "name": "BerryCurvature - Berry 曲率分布",
        "description": "在 k 平面内计算 Berry 曲率分布,用于分析动量空间的拓扑性质。",
        "controls": {
            "BerryCurvature_calc": True,
            "BulkBand_calc": False,
            "BulkBand_plane_calc": False,
            "DOS_calc": False,
            "BulkFS_calc": False,
            "BulkFS_plane_calc": False,
            "FindNodes_calc": False,
            "SlabBand_calc": False,
            "SlabSS_calc": False,
            "SlabArc_calc": False,
            "Wanniercenter_calc": False,
            "AHC_calc": False,
            "ANE_calc": False,
            "SHC_calc": False,
            "Boltz_OHE_calc": False,
            "LandauLevel_B_calc": False,
            "LandauLevel_k_calc": False,
            "MirrorChern_calc": False,
            "WeylChirality_calc": False,
            "BulkBand_Unfold_line_calc": False,
            "BulkBand_Unfold_plane_calc": False,
            "valley_projection_calc": False,
        },
        "params": {"Nk1": 101, "Nk2": 101},
        "needs": ["KPLANE_BULK"],
    },
    "ahc": {
        "name": "AHC - 反常霍尔电导率",
        "description": "计算反常霍尔电导率 σ_xy(E),需要高 k 点密度。仅适用于磁性/SOC 体系。",
        "controls": {
            "AHC_calc": True,
            "BulkBand_calc": False,
            "BulkBand_plane_calc": False,
            "DOS_calc": False,
            "BulkFS_calc": False,
            "BulkFS_plane_calc": False,
            "FindNodes_calc": False,
            "SlabBand_calc": False,
            "SlabSS_calc": False,
            "SlabArc_calc": False,
            "Wanniercenter_calc": False,
            "BerryCurvature_calc": False,
            "ANE_calc": False,
            "SHC_calc": False,
            "Boltz_OHE_calc": False,
            "LandauLevel_B_calc": False,
            "LandauLevel_k_calc": False,
            "MirrorChern_calc": False,
            "WeylChirality_calc": False,
            "BulkBand_Unfold_line_calc": False,
            "BulkBand_Unfold_plane_calc": False,
            "valley_projection_calc": False,
        },
        "params": {
            "Eta_Arc": 0.01,
            "OmegaNum": 401,
            "OmegaMin": -1.0,
            "OmegaMax": 1.0,
            "Nk1": 101,
            "Nk2": 101,
            "Nk3": 101,
        },
        "needs": ["KCUBE_BULK"],
    },
    "ane": {
        "name": "ANE - 反常能斯特效应",
        "description": "计算反常能斯特电导率随温度的变化,需要高 k 点密度。",
        "controls": {
            "ANE_calc": True,
            "BulkBand_calc": False,
            "BulkBand_plane_calc": False,
            "DOS_calc": False,
            "BulkFS_calc": False,
            "BulkFS_plane_calc": False,
            "FindNodes_calc": False,
            "SlabBand_calc": False,
            "SlabSS_calc": False,
            "SlabArc_calc": False,
            "Wanniercenter_calc": False,
            "BerryCurvature_calc": False,
            "AHC_calc": False,
            "SHC_calc": False,
            "Boltz_OHE_calc": False,
            "LandauLevel_B_calc": False,
            "LandauLevel_k_calc": False,
            "MirrorChern_calc": False,
            "WeylChirality_calc": False,
            "BulkBand_Unfold_line_calc": False,
            "BulkBand_Unfold_plane_calc": False,
            "valley_projection_calc": False,
        },
        "params": {
            "Eta_Arc": 0.01,
            "OmegaNum": 401,
            "OmegaMin": -1.0,
            "OmegaMax": 1.0,
            "Nk1": 101,
            "Nk2": 101,
            "Nk3": 101,
            "Tmin": 10,
            "Tmax": 310,
            "NumT": 31,
            "Bmagnitude": 1.0,
            "Btheta": 0.0,
            "Bphi": 0.0,
        },
        "needs": ["KCUBE_BULK"],
    },
    "shc": {
        "name": "SHC - 自旋霍尔电导率",
        "description": "计算自旋霍尔电导率,需要 SOC 和高 k 点密度。",
        "controls": {
            "SHC_calc": True,
            "BulkBand_calc": False,
            "BulkBand_plane_calc": False,
            "DOS_calc": False,
            "BulkFS_calc": False,
            "BulkFS_plane_calc": False,
            "FindNodes_calc": False,
            "SlabBand_calc": False,
            "SlabSS_calc": False,
            "SlabArc_calc": False,
            "Wanniercenter_calc": False,
            "BerryCurvature_calc": False,
            "AHC_calc": False,
            "ANE_calc": False,
            "Boltz_OHE_calc": False,
            "LandauLevel_B_calc": False,
            "LandauLevel_k_calc": False,
            "MirrorChern_calc": False,
            "WeylChirality_calc": False,
            "BulkBand_Unfold_line_calc": False,
            "BulkBand_Unfold_plane_calc": False,
            "valley_projection_calc": False,
        },
        "params": {
            "Eta_Arc": 0.05,
            "OmegaNum": 1001,
            "OmegaMin": -10.0,
            "OmegaMax": 10.0,
            "Nk1": 101,
            "Nk2": 101,
            "Nk3": 101,
        },
        "needs": ["KCUBE_BULK"],
    },
    "ohe": {
        "name": "Boltz_OHE - 轨道霍尔效应",
        "description": "计算轨道霍尔效应和磁阻,支持不同磁场方向。",
        "controls": {
            "Boltz_OHE_calc": True,
            "Symmetry_Import_calc": True,
            "BulkBand_calc": False,
            "BulkBand_plane_calc": False,
            "DOS_calc": False,
            "BulkFS_calc": False,
            "BulkFS_plane_calc": False,
            "FindNodes_calc": False,
            "SlabBand_calc": False,
            "SlabSS_calc": False,
            "SlabArc_calc": False,
            "Wanniercenter_calc": False,
            "BerryCurvature_calc": False,
            "AHC_calc": False,
            "ANE_calc": False,
            "SHC_calc": False,
            "LandauLevel_B_calc": False,
            "LandauLevel_k_calc": False,
            "MirrorChern_calc": False,
            "WeylChirality_calc": False,
            "BulkBand_Unfold_line_calc": False,
            "BulkBand_Unfold_plane_calc": False,
            "valley_projection_calc": False,
        },
        "params": {
            "OmegaNum": 3,
            "OmegaMin": -0.01,
            "OmegaMax": 0.01,
            "EF_broadening": 0.06,
            "Nk1": 41,
            "Nk2": 41,
            "Nk3": 41,
            "BTauNum": 100,
            "BTauMax": 40.0,
            "Tmin": 30,
            "Tmax": 330,
            "NumT": 11,
            "Nslice_BTau_Max": 20000,
        },
        "needs": ["KCUBE_BULK", "SELECTEDBANDS"],
    },
    "landau": {
        "name": "LandauLevel - 朗道能级 / Hofstadter 蝴蝶",
        "description": "在磁场下计算朗道能级谱(Hofstadter 蝴蝶),需要定义磁超胞。",
        "controls": {
            "LandauLevel_B_calc": True,
            "BulkBand_calc": False,
            "BulkBand_plane_calc": False,
            "DOS_calc": False,
            "BulkFS_calc": False,
            "BulkFS_plane_calc": False,
            "FindNodes_calc": False,
            "SlabBand_calc": False,
            "SlabSS_calc": False,
            "SlabArc_calc": False,
            "Wanniercenter_calc": False,
            "BerryCurvature_calc": False,
            "AHC_calc": False,
            "ANE_calc": False,
            "SHC_calc": False,
            "Boltz_OHE_calc": False,
            "LandauLevel_k_calc": False,
            "MirrorChern_calc": False,
            "WeylChirality_calc": False,
            "BulkBand_Unfold_line_calc": False,
            "BulkBand_Unfold_plane_calc": False,
            "valley_projection_calc": False,
        },
        "params": {
            "NSLAB": 400,
            "Eta_Arc": 0.1,
            "OmegaNum": 1001,
            "OmegaMin": -8.0,
            "OmegaMax": 12.0,
            "Nk1": 11,
            "Magp": 100,
            "NumRandomConfs": 10,
            "Bmagnitude": 10.0,
            "Btheta": 0.0,
            "Bphi": 0.0,
        },
        "needs": ["KPATH_BULK"],
    },
    "mirror_chern": {
        "name": "MirrorChern - 镜面 Chern 数",
        "description": "计算镜面对称保护的 Chern 数,用于镜面拓扑绝缘体。",
        "controls": {
            "MirrorChern_calc": True,
            "BulkBand_calc": True,
            "BulkBand_plane_calc": False,
            "DOS_calc": False,
            "BulkFS_calc": False,
            "BulkFS_plane_calc": False,
            "FindNodes_calc": False,
            "SlabBand_calc": False,
            "SlabSS_calc": False,
            "SlabArc_calc": False,
            "Wanniercenter_calc": False,
            "BerryCurvature_calc": False,
            "AHC_calc": False,
            "ANE_calc": False,
            "SHC_calc": False,
            "Boltz_OHE_calc": False,
            "LandauLevel_B_calc": False,
            "LandauLevel_k_calc": False,
            "WeylChirality_calc": False,
            "BulkBand_Unfold_line_calc": False,
            "BulkBand_Unfold_plane_calc": False,
            "valley_projection_calc": False,
        },
        "params": {"Nk1": 81, "Nk2": 201},
        "needs": ["KPATH_BULK", "KPLANE_BULK"],
    },
    "unfold": {
        "name": "BulkBand_Unfold - 能带展开",
        "description": "将超胞能带展开回原胞布里渊区,需要定义 LATTICE_UNFOLD 等模块。",
        "controls": {
            "BulkBand_Unfold_line_calc": True,
            "BulkBand_calc": False,
            "BulkBand_plane_calc": False,
            "DOS_calc": False,
            "BulkFS_calc": False,
            "BulkFS_plane_calc": False,
            "FindNodes_calc": False,
            "SlabBand_calc": False,
            "SlabSS_calc": False,
            "SlabArc_calc": False,
            "Wanniercenter_calc": False,
            "BerryCurvature_calc": False,
            "AHC_calc": False,
            "ANE_calc": False,
            "SHC_calc": False,
            "Boltz_OHE_calc": False,
            "LandauLevel_B_calc": False,
            "LandauLevel_k_calc": False,
            "MirrorChern_calc": False,
            "WeylChirality_calc": False,
            "BulkBand_Unfold_plane_calc": False,
            "valley_projection_calc": False,
        },
        "params": {
            "Eta_Arc": 0.01,
            "OmegaNum_unfold": 600,
            "OmegaMin": -10.0,
            "OmegaMax": 10.0,
            "Nk1": 101,
        },
        "needs": ["KPATH_BULK", "UNFOLD_BLOCKS_PLACEHOLDER"],
    },
    "valley": {
        "name": "Valley - 谷投影能带",
        "description": "计算谷投影能带,区分 K/K' 谷的贡献。适用于六方 2D 材料。",
        "controls": {
            "valley_projection_calc": True,
            "BulkBand_calc": True,
            "BulkBand_plane_calc": False,
            "DOS_calc": False,
            "BulkFS_calc": False,
            "BulkFS_plane_calc": False,
            "FindNodes_calc": False,
            "SlabBand_calc": False,
            "SlabSS_calc": False,
            "SlabArc_calc": False,
            "Wanniercenter_calc": False,
            "BerryCurvature_calc": False,
            "AHC_calc": False,
            "ANE_calc": False,
            "SHC_calc": False,
            "Boltz_OHE_calc": False,
            "LandauLevel_B_calc": False,
            "LandauLevel_k_calc": False,
            "MirrorChern_calc": False,
            "WeylChirality_calc": False,
            "BulkBand_Unfold_line_calc": False,
            "BulkBand_Unfold_plane_calc": False,
        },
        "params": {"Nk1": 101},
        "needs": ["KPATH_BULK"],
    },
}
# 任务分组(用于菜单显示)
TASK_GROUPS = {
    "能带与态密度": ["bands", "bands_plane", "dos", "valley"],
    "费米面": ["fs", "fs_plane"],
    "拓扑不变量": ["wcc", "berry", "mirror_chern"],
    "Weyl 半金属": ["findnodes", "chirality"],
    "表面态": ["slab_band", "slab_ss"],
    "输运性质": ["ahc", "ane", "shc", "ohe"],
    "磁场效应": ["landau"],
    "超胞分析": ["unfold"],
}
# ============================================================================
# .win 文件解析器
# ============================================================================
class WinParser:
    """解析 Wannier90 (.win) 输入文件"""
    def __init__(self, filepath):
        self.filepath = filepath
        self.raw_content = ""
        self.data = {}
        self._parse()
    def _parse(self):
        """主解析入口"""
        with open(self.filepath, "r") as f:
            self.raw_content = f.read()
        lines = self.raw_content.split("\n")
        # 解析单行键值对
        self._parse_simple_vars(lines)
        # 解析 begin/end 块
        self._parse_blocks(self.raw_content)
        # 后处理
        self._post_process()
    def _parse_simple_vars(self, lines):
        """解析简单变量赋值(非 begin/end 块)"""
        simple_vars = {
            "num_wann": int,
            "num_bands": int,
            "fermi_energy": float,
            "spinors": lambda x: x.lower().strip() in [".true.", "true", "t"],
            "write_hr": lambda x: x.lower().strip() in [".true.", "true", "t"],
            "mp_grid": str,
            "berry": lambda x: x.lower().strip() in [".true.", "true", "t"],
            "berry_task": str,
            "berry_kmesh": str,
            "kpath": lambda x: x.lower().strip() in [".true.", "true", "t"],
            "kpath_task": str,
            "kpath_num_points": int,
            "kslice": lambda x: x.lower().strip() in [".true.", "true", "t"],
            "kslice_task": str,
            "kslice_2dkmesh": int,
            "kslice_corner": str,
            "kslice_b1": str,
            "kslice_b2": str,
        }
        in_block = False
        for line in lines:
            stripped = line.split("!")[0].split("#")[0].strip()
            if not stripped:
                continue
            if stripped.lower().startswith("begin"):
                in_block = True
            if stripped.lower().startswith("end"):
                in_block = False
                continue
            if in_block:
                continue
            # 尝试匹配 key = value
            for key, converter in simple_vars.items():
                pattern = rf"^\s*{key}\s*[=:]\s*(.+)$"
                match = re.match(pattern, stripped, re.IGNORECASE)
                if match:
                    val_str = match.group(1).strip()
                    try:
                        self.data[key] = converter(val_str)
                    except (ValueError, TypeError):
                        self.data[key] = val_str
                    break
    def _parse_blocks(self, content):
        """解析 begin/end 块"""
        blocks = {
            "unit_cell_cart": "lattice_vectors",
            "atoms_cart": "atom_positions",
            "projections": "projections",
            "kpoint_path": "kpoint_path",
            "kpoints": "kpoints",
        }
        for block_name, data_key in blocks.items():
            pattern = rf"begin\s+{block_name}\s*\n(.*?)\n\s*end\s+{block_name}"
            match = re.search(pattern, content, re.IGNORECASE | re.DOTALL)
            if match:
                self.data[data_key] = match.group(1).strip()
    def _post_process(self):
        """后处理:补全缺失值,计算衍生参数"""
        # 默认值
        self.data.setdefault("spinors", False)
        self.data.setdefault("write_hr", True)
        self.data.setdefault("fermi_energy", 0.0)
        self.data.setdefault("num_wann", 1)
        self.data.setdefault("num_bands", 2)
    def get_lattice_vectors(self):
        """获取晶格矢量(Å),返回 3×3 矩阵"""
        raw = self.data.get("lattice_vectors", "")
        if not raw:
            return None
        # 处理 Bohr 单位
        lines = raw.strip().split("\n")
        unit = "Angstrom"
        vec_lines = []
        for line in lines:
            s = line.split("!")[0].split("#")[0].strip().lower()
            if s == "bohr":
                unit = "Bohr"
            elif s == "angstrom" or s == "ang":
                unit = "Angstrom"
            elif s:
                vec_lines.append(line)
        vectors = []
        for line in vec_lines[:3]:
            nums = re.findall(r"[-+]?\d*\.?\d+(?:[eEdD][-+]?\d+)?", line)
            if len(nums) >= 3:
                v = [float(n.replace("d", "e").replace("D", "e")) for n in nums[:3]]
                vectors.append(v)
        if len(vectors) != 3:
            return None
        # Bohr → Angstrom 转换
        if unit.lower() == "bohr":
            vectors = [[x * 0.529177210903 for x in v] for v in vectors]
        return vectors
    def get_atom_positions_cart(self):
        """获取 Cartesian 原子坐标,返回 [(symbol, x, y, z), ...]"""
        raw = self.data.get("atom_positions", "")
        if not raw:
            return []
        lines = raw.strip().split("\n")
        unit = "Angstrom"
        atom_lines = []
        for line in lines:
            s = line.split("!")[0].split("#")[0].strip().lower()
            if s == "bohr":
                unit = "Bohr"
            elif s == "angstrom" or s == "ang":
                unit = "Angstrom"
            elif s:
                atom_lines.append(line)
        atoms = []
        for line in atom_lines:
            parts = line.split()
            if len(parts) >= 4:
                # 格式: Symbol x y z
                symbol = parts[0]
                nums = re.findall(r"[-+]?\d*\.?\d+(?:[eEdD][-+]?\d+)?", line)
                if len(nums) >= 3:
                    x = float(nums[0].replace("d", "e").replace("D", "e"))
                    y = float(nums[1].replace("d", "e").replace("D", "e"))
                    z = float(nums[2].replace("d", "e").replace("D", "e"))
                    if unit.lower() == "bohr":
                        x, y, z = x * 0.529177210903, y * 0.529177210903, z * 0.529177210903
                    atoms.append((symbol, x, y, z))
        return atoms
    def get_projections(self):
        """解析投影,返回投影列表。
        对于中心投影(c=),自动匹配最近原子坐标确定元素符号。
        "
""
        raw = self.data.get("projections", "")
        if not raw:
            return []
        lines = raw.strip().split("\n")
        projections = []
        atoms_cart = self.get_atom_positions_cart()  # [(symbol, x, y, z), ...]
        for line in lines:
            s = line.split("!")[0].split("#")[0].strip()
            comment = ""
            # 提取注释,用于识别原子类型(如 "! V1", "! Nb3", "! S9")
            if "!" in line:
                comment = line.split("!", 1)[1].strip()
            elif "#" in line:
                comment = line.split("#", 1)[1].strip()
            if not s:
                continue
            # 格式1: "c=x,y,z:l=0;l=2" (中心投影)
            # 格式2: "Symbol : orbital1; orbital2" (原子投影)
            if s.startswith("c="):
                # 中心投影:解析坐标和轨道
                # 格式: c=x,y,z:l=0;l=2
                c_match = re.match(r"c\s*=\s*([-+]?\d*\.?\d+(?:[eEdD][-+]?\d+)?)\s*,\s*([-+]?\d*\.?\d+(?:[eEdD][-+]?\d+)?)\s*,\s*([-+]?\d*\.?\d+(?:[eEdD][-+]?\d+)?)\s*:\s*(.+)", s)
                if not c_match:
                    # 尝试无空格版本: c=x,y,z : l=0;l=2
                    c_match = re.match(r"c\s*=\s*(.+?)\s*:\s*(.+)", s)
                    if c_match:
                        coords_str = c_match.group(1)
                        orb_str = c_match.group(2)
                        coord_parts = [c.strip() for c in coords_str.split(",")]
                        if len(coord_parts) != 3:
                            projections.append({"type": "center", "raw": s, "orbitals": [], "symbol": "?"})
                            continue
                        cx = float(coord_parts[0].replace("d", "e").replace("D", "e"))
                        cy = float(coord_parts[1].replace("d", "e").replace("D", "e"))
                        cz = float(coord_parts[2].replace("d", "e").replace("D", "e"))
                    else:
                        projections.append({"type": "center", "raw": s, "orbitals": [], "symbol": "?"})
                        continue
                else:
                    cx = float(c_match.group(1).replace("d", "e").replace("D", "e"))
                    cy = float(c_match.group(2).replace("d", "e").replace("D", "e"))
                    cz = float(c_match.group(3).replace("d", "e").replace("D", "e"))
                    orb_str = c_match.group(4)
                # 解析轨道
                orbs = []
                if "l=" in orb_str.lower():
                    l_vals = re.findall(r"l=(\d+)", orb_str.lower())
                    for lv in l_vals:
                        orbs.extend(self._l_to_orbitals(int(lv)))
                else:
                    raw_orbs = [o.strip() for o in orb_str.split(";")]
                    for ro in raw_orbs:
                        if ro:
                            orbs.extend(self._expand_orbital(ro.lower()))
                # 匹配最近原子坐标确定元素符号
                symbol = "?"
                if atoms_cart:
                    min_dist = float("inf")
                    for atom_sym, ax, ay, az in atoms_cart:
                        d = (cx - ax)**2 + (cy - ay)**2 + (cz - az)**2
                        if d < min_dist:
                            min_dist = d
                            symbol = atom_sym
                    # 如果距离 > 0.5 Å 则可能匹配失败,保留 ?
                    if min_dist > 0.25:
                        symbol = "?"
                # 如果注释中包含元素符号,优先使用注释
                if comment:
                    # 注释格式如 "V1", "Nb3", "S9" --- 提取字母部分
                    atom_symbol_from_comment = re.match(r"([A-Za-z]+)", comment)
                    if atom_symbol_from_comment:
                        symbol = atom_symbol_from_comment.group(1)
                projections.append({
                    "type": "center",
                    "raw": s,
                    "orbitals": orbs,
                    "symbol": symbol,
                    "comment": comment,
                })
            else:
                # 原子投影
                parts = s.split(":")
                if len(parts) >= 2:
                    symbol = parts[0].strip()
                    orb_str = parts[1].strip()
                    orbs = []
                    if "l=" in orb_str.lower():
                        l_vals = re.findall(r"l=(\d+)", orb_str.lower())
                        for lv in l_vals:
                            orbs.extend(self._l_to_orbitals(int(lv)))
                    else:
                        raw_orbs = [o.strip() for o in orb_str.split(";")]
                        for ro in raw_orbs:
                            if ro:
                                orbs.extend(self._expand_orbital(ro.lower()))
                    projections.append({"type": "atom", "symbol": symbol, "orbitals": orbs})
        return projections
    def _l_to_orbitals(self, l_val):
        """将角量子数 l 转换为轨道名称"""
        l_map = {
            0: ["s"],
            1: ["pz", "px", "py"],
            2: ["dz2", "dxz", "dyz", "dx2-y2", "dxy"],
            3: ["fz3", "fxz2", "fyz2", "fxyz", "fz(x2-y2)", "fx(x2-3y2)", "fy(3x2-y2)"],
        }
        return l_map.get(l_val, [])
    def _expand_orbital(self, orb_name):
        """展开简写轨道名"""
        if orb_name == "s":
            return ["s"]
        elif orb_name == "p":
            return ["pz", "px", "py"]
        elif orb_name == "d":
            return ["dz2", "dxz", "dyz", "dx2-y2", "dxy"]
        elif orb_name == "f":
            return ["fz3", "fxz2", "fyz2", "fxyz", "fz(x2-y2)", "fx(x2-3y2)", "fy(3x2-y2)"]
        elif orb_name == "sp":
            return ["s", "pz", "px", "py"]
        elif orb_name == "spd":
            return ["s", "pz", "px", "py", "dz2", "dxz", "dyz", "dx2-y2", "dxy"]
        else:
            return [orb_name]
    def get_kpoint_path(self):
        """获取 k-path 定义,返回 [(label1, k1, label2, k2), ...]"""
        raw = self.data.get("kpoint_path", "")
        if not raw:
            return []
        lines = raw.strip().split("\n")
        paths = []
        for line in lines:
            s = line.split("!")[0].split("#")[0].strip()
            if not s:
                continue
            # 格式: "Label1 kx1 ky1 kz1  Label2 kx2 ky2 kz2"
            parts = s.split()
            if len(parts) >= 8:
                label1 = parts[0]
                k1 = [float(x) for x in parts[1:4]]
                label2 = parts[4]
                k2 = [float(x) for x in parts[5:8]]
                paths.append((label1, k1, label2, k2))
        return paths
    def get_mp_grid(self):
        """获取 Monkhorst-Pack k 网格"""
        raw = self.data.get("mp_grid", "")
        if not raw:
            return (1, 1, 1)
        parts = raw.split()
        if len(parts) >= 3:
            return (int(parts[0]), int(parts[1]), int(parts[2]))
        return (1, 1, 1)
    def get_fermi_energy(self):
        """获取费米能级"""
        return self.data.get("fermi_energy", 0.0)
    def get_soc(self):
        """获取 SOC 标志"""
        return 1 if self.data.get("spinors", False) else 0
    def get_num_wann(self):
        """获取 Wannier 函数数量"""
        return self.data.get("num_wann", 1)
    def get_berry_info(self):
        """获取 Berry 曲率相关参数"""
        return {
            "berry": self.data.get("berry", False),
            "berry_task": self.data.get("berry_task", ""),
            "berry_kmesh": self.data.get("berry_kmesh", ""),
        }
    def get_kslice_info(self):
        """获取 kslice 参数"""
        return {
            "kslice": self.data.get("kslice", False),
            "kslice_task": self.data.get("kslice_task", ""),
            "kslice_2dkmesh": self.data.get("kslice_2dkmesh", 50),
            "kslice_corner": self.data.get("kslice_corner", "0.0 0.0 0.0"),
            "kslice_b1": self.data.get("kslice_b1", "1.0 0.0 0.0"),
            "kslice_b2": self.data.get("kslice_b2", "0.0 1.0 0.0"),
        }
# ============================================================================
# wt.in 生成器
# ============================================================================
class WtInGenerator:
    """根据解析的 .win 数据生成 wt.in 文件
    支持自旋极化体系:若检测到 wannier90.up/dn_hr.dat,自动设置
    spin_channel='up' 或 'dn',并调整 Hrfile、SOC、NumOccupied。
    "
""
    def __init__(self, win_parser, spin_channel=None, hrfile_path=None):
        """
        Parameters
        ----------
        win_parser : WinParser
            已解析的 .win 文件解析器
        spin_channel : str or None
            None:  非自旋极化(默认),SOC 从 .win 读取
            'up':  自旋上通道,SOC = 0,NumOccupied = num_wann
            'dn':  自旋下通道,SOC = 0,NumOccupied = num_wann
        hrfile_path : str or None
            Hrfile 的完整路径。若不提供,自动推导为 wannier90.{spin_channel}_hr.dat
            或 wannier90_hr.dat。提供路径可确保 wt.in 从任意工作目录都能找到 hr.dat。
        "
""
        self.win = win_parser
        self.spin_channel = spin_channel
        self._hrfile_path = hrfile_path  # 保存原始路径,用于 _gen_tb_file
        self.lattice = win_parser.get_lattice_vectors()
        self.atoms_cart = win_parser.get_atom_positions_cart()
        self.projections = win_parser.get_projections()
        self.kpath = win_parser.get_kpoint_path()
        self.fermi = win_parser.get_fermi_energy()
        self.num_wann = win_parser.get_num_wann()
        self.mp_grid = win_parser.get_mp_grid()
        self.berry = win_parser.get_berry_info()
        self.kslice = win_parser.get_kslice_info()
        # SOC 检测:自旋极化体系(up/dn 分开)SOC=0,否则从 .win 读取
        if spin_channel:
            self.soc = 0
        else:
            self.soc = win_parser.get_soc()
        # 计算原子 Direct 坐标
        self.atoms_direct = self._cart_to_direct()
        # 计算默认 NumOccupied
        # 自旋极化体系(up/dn):每个自旋通道独立,NumOccupied = num_wann
        # SOC 体系(spinors=true):num_wann = 2 * 轨道数,NumOccupied = num_wann // 2
        # 非 SOC 非极化体系:NumOccupied = num_wann
        self.num_occupied = self._estimate_num_occupied()
    def _cart_to_direct(self):
        """将 Cartesian 坐标转换为 Direct(分数)坐标"""
        if not self.lattice or not self.atoms_cart:
            return []
        lat = np.array(self.lattice, dtype=float)
        try:
            inv_lat = np.linalg.inv(lat.T)  # 注意:晶格矢量是行向量
        except np.linalg.LinAlgError:
            return []
        result = []
        for symbol, x, y, z in self.atoms_cart:
            cart = np.array([x, y, z], dtype=float)
            direct = inv_lat @ cart
            result.append((symbol, direct[0], direct[1], direct[2]))
        return result
    def _estimate_num_occupied(self):
        """估算占据能带数。
        物理依据:
        - 自旋极化体系(spin_channel='up'/'dn'): 每个自旋通道的 Wannier
          函数全部对应该通道的能带,NumOccupied = num_wann(全部占据或
          由 .win 中的 num_bands 和 dis_froz_max 决定)
        - SOC 体系(spinors=true): Wannier 函数的自旋自由度已编码在轨道中,
          num_wann = 2 * 轨道数,占据数约为 num_wann // 2
        - 非 SOC 非极化体系: NumOccupied = num_wann
        注意:此估算对含半金属/磁性金属可能是近似值,建议用户根据 DFT
        价带数手动核对。
        "
""
        if self.spin_channel:
            # 自旋极化体系:所有 num_wann 个 Wannier 函数都对应自旋通道的能带
            return self.num_wann
        elif self.soc == 1:
            # SOC 情况下,num_wann 包含自旋,占据数约为 num_wann/2
            return self.num_wann // 2
        else:
            return self.num_wann
    def _get_default_surface(self):
        """获取默认 SURFACE 矩阵"""
        return " 1  0  0\n 0  1  0\n 0  0  1"
    def _get_default_kplane(self):
        """获取默认 KPLANE_BULK"""
        kslice = self.win.get_kslice_info()
        if kslice["kslice"]:
            corner = kslice["kslice_corner"]
            b1 = kslice["kslice_b1"]
            b2 = kslice["kslice_b2"]
            return (
                f"  {corner}   ! Original point for 3D k plane\n"
                f"  {b1}   ! The first vector to define 3d k space plane\n"
                f"  {b2}   ! The second vector to define 3d k space plane"
            )
        return (
            " 0.00  0.00  0.00   ! Original point for 3D k plane\n"
            " 1.00  0.00  0.00   ! The first vector to define 3d k space plane\n"
            " 0.00  1.00  0.00   ! The second vector to define 3d k space plane"
        )
    def _get_default_kcube(self):
        """获取默认 KCUBE_BULK"""
        return (
            " 0.00  0.00  0.00   ! Original point for 3D k plane\n"
            " 1.00  0.00  0.00   ! The first vector to define 3d k space plane\n"
            " 0.00  1.00  0.00   ! The second vector to define 3d k space plane\n"
            " 0.00  0.00  1.00   ! The third vector to define 3d k cube"
        )
    def generate(self, task_key, output_path=None, overrides=None):
        """生成指定任务的 wt.in 文件。
        自动将 NSLAB, NP, Bmagnitude 等属于 &SYSTEM 的参数
        分离到 &SYSTEM namelist,避免 "
Invalid line in namelist PARAMETERS
" 错误。
        "
""
        if task_key not in TASK_DEFINITIONS:
            raise ValueError(f"未知任务: {task_key}。可用任务: {list(TASK_DEFINITIONS.keys())}")
        # 记住输出目录,供 _gen_tb_file 计算 Hrfile 相对路径
        if output_path:
            self._output_dir = os.path.dirname(os.path.abspath(output_path)) or "."
        task = TASK_DEFINITIONS[task_key]
        controls = deepcopy(task["controls"])
        all_params = deepcopy(task["params"])
        # 应用覆盖参数
        if overrides:
            all_params.update(overrides)
        # 分离系统级参数和数值参数
        # 系统级参数属于 &SYSTEM(NSLAB, NP, Bmagnitude 等)
        # 数值参数属于 &PARAMETERS(Nk1, Nk2, Eta_Arc 等)
        system_params = {}
        params = {}
        for key, val in all_params.items():
            if key in SYSTEM_PARAMS:
                system_params[key] = val
            else:
                params[key] = val
        lines = []
        lines.append(self._gen_tb_file())
        lines.append("")
        lines.append(self._gen_control(controls))
        lines.append("")
        lines.append(self._gen_system(system_params))
        lines.append("")
        lines.append(self._gen_parameters(params))
        lines.append("")
        lines.append(self._gen_lattice())
        lines.append("")
        lines.append(self._gen_atom_positions())
        lines.append("")
        lines.append(self._gen_projectors())
        lines.append("")
        lines.append(self._gen_surface())
        lines.append("")
        # 按需添加模块
        for need in task.get("needs", []):
            module = self._gen_module(need)
            if module:
                lines.append(module)
                lines.append("")
        content = "\n".join(lines)
        if output_path:
            with open(output_path, "w") as f:
                f.write(content)
            print(f"  ✓ 已生成: {output_path}")
        return content
    def _gen_tb_file(self):
        """生成 &TB_FILE namelist。
        Hrfile 使用从输出目录到 hr.dat 的相对路径(如 '../wannier90_hr.dat'),
        避免硬编码绝对路径导致换机器运行时报 "
no HmnR input
"。
        "
""
        if self._hrfile_path:
            hr_abs = os.path.abspath(self._hrfile_path)
            out_dir = getattr(self, '_output_dir', os.getcwd())
            hrfile = os.path.relpath(hr_abs, out_dir)
        elif self.spin_channel:
            hrfile = f"wannier90.{self.spin_channel}_hr.dat"
        else:
            hrfile = "wannier90_hr.dat"
        return f"&TB_FILE\nHrfile = '{hrfile}'\n/"
    def _gen_control(self, controls):
        lines = ["&CONTROL"]
        for key, val in controls.items():
            type_char = "T" if val else "F"
            lines.append(f"  {key:<30s} = {type_char}")
        lines.append("/")
        return "\n".join(lines)
    def _gen_system(self, system_params=None):
        """生成 &SYSTEM namelist。
        SOC, E_FERMI, NumOccupied 为必填项。
        system_params 包含属于 &SYSTEM 的可选参数(NSLAB, NP, Bmagnitude 等)。
        "
""
        lines = ["&SYSTEM"]
        lines.append(f"  SOC = {self.soc}")
        lines.append(f"  E_FERMI = {self.fermi}")
        lines.append(f"  NumOccupied = {self.num_occupied}")
        if system_params:
            for key in sorted(system_params.keys()):
                val = system_params[key]
                if isinstance(val, float):
                    lines.append(f"  {key} = {val}")
                elif isinstance(val, bool):
                    lines.append(f"  {key} = {'T' if val else 'F'}")
                else:
                    lines.append(f"  {key} = {val}")
        lines.append("/")
        return "\n".join(lines)
    def _gen_parameters(self, params):
        lines = ["&PARAMETERS"]
        for key, val in params.items():
            if isinstance(val, float):
                lines.append(f"  {key} = {val}")
            elif isinstance(val, bool):
                lines.append(f"  {key} = {'T' if val else 'F'}")
            else:
                lines.append(f"  {key} = {val}")
        lines.append("/")
        return "\n".join(lines)
    def _gen_lattice(self):
        if not self.lattice:
            return "LATTICE\nAngstrom\n# [请手动填写晶格矢量]\n1.0 0.0 0.0\n0.0 1.0 0.0\n0.0 0.0 1.0"
        lines = ["LATTICE", "Angstrom"]
        for v in self.lattice:
            lines.append(f"  {v[0]:12.6f} {v[1]:12.6f} {v[2]:12.6f}")
        return "\n".join(lines)
    def _gen_atom_positions(self):
        if not self.atoms_direct:
            if not self.atoms_cart:
                return "ATOM_POSITIONS\n1\nDirect\n# [请手动填写原子坐标]\nX 0.0 0.0 0.0"
            # 有 Cartesian 但没有 Direct(转换失败),使用 Cartesian
            lines = ["ATOM_POSITIONS", f"{len(self.atoms_cart)}", "Cartesian"]
            for symbol, x, y, z in self.atoms_cart:
                lines.append(f"  {symbol:<4s} {x:12.6f} {y:12.6f} {z:12.6f}  0.0  0.0  0.0")
            return "\n".join(lines)
        lines = ["ATOM_POSITIONS", f"{len(self.atoms_direct)}", "Direct"]
        for symbol, x, y, z in self.atoms_direct:
            lines.append(f"  {symbol:<4s} {x:12.6f} {y:12.6f} {z:12.6f}  0.0  0.0  0.0")
        return "\n".join(lines)
    def _gen_projectors(self):
        """生成 PROJECTORS 块。
        自动处理中心投影(c=)和原子投影(Symbol:orbitals)两种格式。
        中心投影:每个 c= 已对应一个原子,直接使用。
        原子投影:按符号类型定义,需展开为每个原子的投影。
        "
""
        if not self.projections:
            return "PROJECTORS\n1  ! number of projectors\nX s"
        proj_counts = []
        proj_lines = []
        # 判断是否为原子类型投影(需要展开到每个原子)
        is_atom_type = all(p.get("type") == "atom" for p in self.projections)
        if is_atom_type:
            # 原子类型投影:symbol → orbitals 映射
            proj_map = {}
            for p in self.projections:
                proj_map[p["symbol"]] = p.get("orbitals", [])
            # 遍历原子列表,为每个原子匹配投影
            for symbol, x, y, z in self.atoms_cart:
                if symbol in proj_map:
                    orbs = proj_map[symbol]
                else:
                    # 尝试大小写不敏感匹配
                    matched = False
                    for key in proj_map:
                        if key.lower() == symbol.lower():
                            orbs = proj_map[key]
                            matched = True
                            break
                    if not matched:
                        orbs = []
                norbs = len(orbs)
                proj_counts.append(norbs)
                if norbs > 0:
                    proj_lines.append(f"  {symbol:<4s} {' '.join(orbs)}")
                else:
                    proj_lines.append(f"# {symbol}  ! [未找到投影定义]")
        else:
            # 中心投影:每个 c= 已对应一个原子
            for p in self.projections:
                norbs = len(p.get("orbitals", []))
                proj_counts.append(norbs)
                symbol = p.get("symbol", "?")
                if norbs > 0:
                    proj_lines.append(f"  {symbol:<4s} {' '.join(p['orbitals'])}")
                else:
                    proj_lines.append(f"# {p.get('raw', '?')}  ! [未解析到轨道]")
        count_str = " ".join(str(c) for c in proj_counts)
        lines = [f"PROJECTORS", f"  {count_str}  ! number of projectors per atom"]
        lines.extend(proj_lines)
        return "\n".join(lines)
    def _gen_surface(self):
        return "SURFACE            ! See doc for details\n" + self._get_default_surface()
    def _gen_module(self, need):
        """生成特定模块"""
        if need == "KPATH_BULK":
            return self._gen_kpath_bulk()
        elif need == "KPLANE_BULK":
            return self._gen_kplane_bulk()
        elif need == "KCUBE_BULK":
            return self._gen_kcube_bulk()
        elif need == "KPATH_SLAB":
            return self._gen_kpath_slab()
        elif need == "KPLANE_SLAB":
            return self._gen_kplane_slab()
        elif need == "SELECTEDBANDS":
            return self._gen_selectedbands()
        elif need == "WEYL_CHIRALITY_PLACEHOLDER":
            return self._gen_weyl_chirality_placeholder()
        elif need == "UNFOLD_BLOCKS_PLACEHOLDER":
            return self._gen_unfold_placeholder()
        return ""
    def _gen_kpath_bulk(self):
        """从 .win 的 kpoint_path 生成 KPATH_BULK"""
        if not self.kpath:
            return (
                "KPATH_BULK            ! k point path\n"
                "1              ! number of k line\n"
                "  G 0.0 0.0 0.0  X 0.5 0.0 0.0  ! [自动生成,请手动检查]"
            )
        lines = ["KPATH_BULK            ! k point path"]
        lines.append(f"  {len(self.kpath)}              ! number of k line only for bulk band")
        for label1, k1, label2, k2 in self.kpath:
            lines.append(f"  {label1:<4s} {k1[0]:9.5f} {k1[1]:9.5f} {k1[2]:9.5f}  "
                         f"{label2:<4s} {k2[0]:9.5f} {k2[1]:9.5f} {k2[2]:9.5f}")
        return "\n".join(lines)
    def _gen_kplane_bulk(self):
        return "KPLANE_BULK\n" + self._get_default_kplane()
    def _gen_kcube_bulk(self):
        return "KCUBE_BULK\n" + self._get_default_kcube()
    def _gen_kpath_slab(self):
        """从 bulk kpath 推导 slab kpath(投影到 2D)"""
        if self.kpath:
            # 从 bulk kpath 提取 2D 投影(取前两个分量)
            lines = ["KPATH_SLAB"]
            lines.append(f"  {len(self.kpath)}        ! number of k line for 2D case")
            for label1, k1, label2, k2 in self.kpath:
                lines.append(f"  {label1:<4s} {k1[0]:9.5f} {k1[1]:9.5f}  "
                             f"{label2:<4s} {k2[0]:9.5f} {k2[1]:9.5f}")
            return "\n".join(lines)
        return (
            "KPATH_SLAB\n"
            "2        ! number of k line for 2D case\n"
            "  X -0.5 0.0  G 0.0 0.0  ! k path for 2D case\n"
            "  G 0.0 0.0  X 0.5 0.0   ! [自动生成,请手动检查]"
        )
    def _gen_kplane_slab(self):
        """从 kslice 或 bulk kpath 推导 slab k-plane"""
        # 优先使用 kslice 信息
        kslice = self.win.get_kslice_info()
        if kslice["kslice"]:
            corner = kslice["kslice_corner"]
            b1 = kslice["kslice_b1"]
            b2 = kslice["kslice_b2"]
            # 取前两个分量作为 2D plane
            corner_parts = corner.split()
            b1_parts = b1.split()
            b2_parts = b2.split()
            if len(corner_parts) >= 2 and len(b1_parts) >= 2 and len(b2_parts) >= 2:
                return (
                    "KPLANE_SLAB\n"
                    f"  {corner_parts[0]:>6s} {corner_parts[1]:>6s}      ! Original point for 2D k plane\n"
                    f"  {b1_parts[0]:>6s} {b1_parts[1]:>6s}      ! The first vector to define 2D k plane\n"
                    f"  {b2_parts[0]:>6s} {b2_parts[1]:>6s}      ! The second vector to define 2D k plane"
                )
        return (
            "KPLANE_SLAB\n"
            "  -0.5 -0.5      ! Original point for 2D k plane\n"
            "   1.0  0.0      ! The first vector to define 2D k plane\n"
            "   0.0  1.0      ! The second vector to define 2D k plane"
        )
    def _gen_selectedbands(self):
        return (
            "SELECTEDBANDS\n"
            f"  1\n"
            f"  {self.num_occupied}"
        )
    def _gen_weyl_chirality_placeholder(self):
        return (
            "WEYL_CHIRALITY\n"
            "0            ! Num_Weyls (请从 FindNodes 输出中获取并修改)\n"
            "Cartesian    ! Direct or Cartesian coordinate\n"
            "0.004        ! Radius of the ball surround a Weyl point\n"
            "# [提示] 先运行 findnodes 获取 Weyl 点坐标,然后修改 Num_Weyls\n"
            "# 并在此处逐个添加 Weyl 点坐标行,格式: <wx>  <wy>  <wz>"
        )
    def _gen_unfold_placeholder(self):
        return (
            "# ============================================================\n"
            "# [能带展开] 需要手动定义以下模块(替换注释为实际数据):\n"
            "# ============================================================\n"
            "# LATTICE_UNFOLD         - 展开目标晶格(原胞晶格矢量)\n"
            "# ATOM_POSITIONS_UNFOLD  - 展开目标原子坐标\n"
            "# PROJECTORS_UNFOLD      - 展开目标投影轨道\n"
            "# SELECTED_ATOMS         - 超胞中选中的原子索引\n"
            "# ============================================================\n"
            "# 示例(请根据实际超胞修改):\n"
            "#\n"
            "# LATTICE_UNFOLD\n"
            "# Angstrom\n"
            "#   <a1x> <a1y> <a1z>\n"
            "#   <a2x> <a2y> <a2z>\n"
            "#   <a3x> <a3y> <a3z>\n"
            "#\n"
            "# ATOM_POSITIONS_UNFOLD\n"
            "#   <num_atoms>      ! 原胞原子数\n"
            "# Direct\n"
            "#   <symbol> <ax> <ay> <az>\n"
            "#   ...\n"
            "#\n"
            "# PROJECTORS_UNFOLD\n"
            "#   <num_proj> <num_proj> ...  ! 每个原子的投影轨道数\n"
            "#   <symbol> <orb1> <orb2> ...\n"
            "#   ...\n"
            "#\n"
            "# SELECTED_ATOMS\n"
            "#   1               ! 组数\n"
            "#   <num_selected>  ! 选中的原子数\n"
            "#   <idx1> <idx2> ...\n"
            "# ============================================================"
        )
# ============================================================================
# 自旋极化检测
# ============================================================================
def _prompt_hr_path(hr_type, search_dirs, expect_name):
    """当自动搜索找不到 hr.dat 时,提示用户手动输入路径。
    Parameters
    ----------
    hr_type : str
        描述性文字,如 "
spin-up
", "
spin-down
", "
标准(无自旋)
"
    expect_name : str
        期望文件名,如 "
wannier90.up_hr.dat
"
    search_dirs : list
        已搜索过的目录列表
    Returns
    -------
    str or None : 用户指定的有效路径,或 None(用户放弃)
    "
""
    print(f"\n  {hr_type} 的 hr.dat 文件未自动找到。")
    print(f"  期望文件名: {expect_name}")
    print(f"  已搜索目录:")
    for d in search_dirs:
        print(f"    {d}")
    print(f"\n  请手动输入 {expect_name} 的完整路径,")
    print(f"  或输入 'q' / 回车跳过。")
    for _ in range(5):
        try:
            path = input(f"  请输入路径: ").strip()
        except (EOFError, KeyboardInterrupt):
            return None
        if not path or path.lower() == 'q':
            return None
        if os.path.isfile(path):
            return os.path.abspath(path)
        print(f"  文件不存在: {path},请重新输入。")
    print(f"  达到最大尝试次数,已跳过。")
    return None
def _find_required_hr(win_dir, cwd, channel):
    """搜索 hr.dat 文件并返回绝对路径。未找到返回 None。
    Parameters
    ----------
    win_dir : str
        .win 文件所在目录
    cwd : str
        当前工作目录
    channel : str
        "
up
" → wannier90.up_hr.dat, "
dn
" → wannier90.dn_hr.dat, "
" → wannier90_hr.dat
    "
""
    fname = f"wannier90.{channel}_hr.dat" if channel else "wannier90_hr.dat"
    # 去重搜索路径(win_dir 优先,cwd 其次)
    search_dirs = []
    for d in [win_dir, cwd]:
        d_abs = os.path.abspath(d)
        if d_abs not in search_dirs:
            search_dirs.append(d_abs)
    for d in search_dirs:
        fpath = os.path.join(d, fname)
        if os.path.isfile(fpath):
            return os.path.abspath(fpath)
    return None
def detect_spin_channel(win_dir, cwd=None):
    """检测是否存在自旋极化 hr.dat 文件,返回实际文件路径。
    搜索顺序: win_dir (与 .win 同目录) → cwd (当前工作目录)
    物理背景:
    - 共线磁性 DFT 计算中 Wannier90 独立输出 wannier90.up_hr.dat / .dn_hr.dat
    - 两个自旋通道的紧束缚哈密顿量不完全相同(交换劈裂),需分别计算
    - 通道内 SOC=0,Hrfile 指向对应文件,NumOccupied=num_wann
    Parameters
    ----------
    win_dir : str
        .win 文件所在目录(hr.dat 通常在同一目录)
    cwd : str or None
        当前工作目录(备选搜索路径)
    Returns
    -------
    dict: {
        "
up
": "
/path/to/wannier90.up_hr.dat
" 或 None,
        "
dn
": "
/path/to/wannier90.dn_hr.dat
" 或 None,
        "
has_spin
": bool,
        "
available
": ["
up
", "
dn
"]  # 实际找到的通道列表
    }
    "
""
    result = {"up": None, "dn": None, "has_spin": False, "available": []}
    # 搜索目录列表:优先 .win 所在目录,其次当前工作目录
    search_dirs = [os.path.abspath(win_dir)]
    if cwd:
        cwd_abs = os.path.abspath(cwd)
        if cwd_abs not in search_dirs:
            search_dirs.append(cwd_abs)
    for search_dir in search_dirs:
        for ch in ["up", "dn"]:
            if result[ch] is not None:
                continue  # 已找到,不覆盖(win_dir 优先)
            fname = f"wannier90.{ch}_hr.dat"
            fpath = os.path.join(search_dir, fname)
            if os.path.isfile(fpath):
                result[ch] = fpath
                result[ch + "_found_dir"] = search_dir
                result["available"].append(ch)
    # 也检测 Wannier90 其他命名: *_{up,down}_hr.dat (排除标准 wannier90_hr.dat)
    if not result["available"]:
        for search_dir in search_dirs:
            try:
                for f in os.listdir(search_dir):
                    for suffix, ch in [("_up_hr.dat", "up"), ("_down_hr.dat", "dn")]:
                        if f.endswith(suffix) and f != "wannier90_hr.dat":
                            if result[ch] is None:
                                fpath = os.path.join(search_dir, f)
                                result[ch] = fpath
                                result[ch + "_found_dir"] = search_dir
                                if ch not in result["available"]:
                                    result["available"].append(ch)
            except (FileNotFoundError, PermissionError):
                continue
    result["has_spin"] = len(result["available"]) > 0
    # 检测标准(非自旋)hr.dat
    result["non_spin_hr"] = None
    for search_dir in search_dirs:
        std_path = os.path.join(search_dir, "wannier90_hr.dat")
        if os.path.isfile(std_path):
            result["non_spin_hr"] = std_path
            result["non_spin_found_dir"] = search_dir
            break
    return result
def get_spin_selection(spin_info, mode="interactive"):
    """获取用户自旋通道选择。
    Parameters
    ----------
    spin_info : dict
        detect_spin_channel() 的返回值
    mode : str
        "
interactive
": 交互式选择
        "
auto
": 自动选择(有 up/dn 就全选,包括 None(标准通道)如果无 hr 文件)
    Returns
    -------
    list of (str, str) tuples: [("
up
", hrfile_path), ("
dn
", hrfile_path)] 或
                              [(None, hrfile_path)] 或 []
    "
""
    available = spin_info.get("available", [])
    if not available:
        # 无自旋极化,检查标准 wannier90_hr.dat 是否存在
        std_hr = spin_info.get("non_spin_hr")
        if std_hr:
            return [(None, std_hr)]
        else:
            return []  # 无任何 hr.dat,交由调用方报错
    if mode == "auto":
        # 自动模式:有 up 就选 up,有 dn 就选 dn
        result = []
        for ch in available:
            hr_path = spin_info.get(ch)
            if hr_path:
                result.append((ch, hr_path))
        return result
    # 交互式选择
    print(f"\n{'=' * 72}")
    print("  检测到自旋极化计算!")
    for ch in available:
        print(f"  发现文件: {spin_info.get(ch)}")
    print(f"{'=' * 72}")
    print(f"\n  请选择要生成的自旋通道:")
    print(f"    [1] spin-up   (Hrfile -> wannier90.up_hr.dat, SOC = 0, NumOccupied = num_wann)")
    if "dn" in available:
        print(f"    [2] spin-down (Hrfile -> wannier90.dn_hr.dat, SOC = 0, NumOccupied = num_wann)")
        print(f"    [3] 两个通道都生成 (独立 wt-up/ 和 wt-dn/ 目录)")
        print(f"    [4] 跳过(不生成自旋极化文件)")
    else:
        print(f"    [2] 跳过(不生成)")
    while True:
        try:
            choice = input("\n  请选择: ").strip()
            if choice == "1":
                return [(ch, spin_info[ch]) for ch in available if spin_info[ch]][:1]
            elif choice == "2" and "dn" in available:
                return [(ch, spin_info[ch]) for ch in ["dn"] if spin_info[ch]]
            elif choice == "3" and "dn" in available:
                return [(ch, spin_info[ch]) for ch in available if spin_info[ch]]
            elif choice in ("2", "4"):
                return []
            else:
                print("  无效选择,请重新输入")
        except (EOFError, KeyboardInterrupt):
            return []
# ============================================================================
# 交互式菜单
# ============================================================================
def print_banner():
    """打印横幅"""
    print("=" * 72)
    print("   win2wt: Wannier90 (.win) → WannierTools (wt.in) 转换工具")
    print(f"   版本 {__version__} ({__date__})")
    print("=" * 72)
def print_menu():
    """打印任务选择菜单"""
    print("\n可用计算任务(输入编号或任务代码选择):\n")
    task_list = list(TASK_DEFINITIONS.keys())
    idx = 1
    task_index = {}
    for group_name, group_keys in TASK_GROUPS.items():
        print(f"  ┌─ {group_name} ──────────────────────────────────────────────┐")
        for key in group_keys:
            task = TASK_DEFINITIONS[key]
            print(f"  │ [{idx:2d}] {key:<18s} {task['name']:<42s} │")
            task_index[idx] = key
            idx += 1
        print(f"  └{'─' * 62}┘")
    print(f"\n  [{idx:2d}] 生成全部 19 个任务")
    task_index[idx] = "ALL"
    idx += 1
    print(f"  [{idx:2d}] 退出")
    return task_index
def get_user_selection(task_index):
    """获取用户选择"""
    while True:
        try:
            choice = input("\n请选择 [输入编号或任务代码,多个用逗号分隔]: ").strip()
            if not choice:
                continue
            selected = set()
            # 支持逗号分隔的多个选择
            parts = [p.strip() for p in choice.split(",")]
            for part in parts:
                # 尝试数字
                try:
                    num = int(part)
                    if num in task_index:
                        val = task_index[num]
                        if val == "ALL":
                            return list(TASK_DEFINITIONS.keys())
                        selected.add(val)
                    else:
                        print(f"  无效编号: {num}")
                        continue
                except ValueError:
                    # 尝试任务代码
                    if part.lower() in TASK_DEFINITIONS:
                        selected.add(part.lower())
                    elif part.lower() == "all":
                        return list(TASK_DEFINITIONS.keys())
                    elif part.lower() in ["q", "quit", "exit"]:
                        return []
                    else:
                        print(f"  无效任务代码: {part}")
                        continue
            if selected:
                return list(selected)
            else:
                print("  未选择有效任务,请重新输入。")
        except (EOFError, KeyboardInterrupt):
            print("\n")
            return []
def interactive_mode(win_path, output_dir=None, spin_channel=None):
    """交互式模式。
    自动检测自旋极化体系(搜索 .win 目录 + 当前工作目录),
    提示用户选择自旋通道和计算任务,生成对应的 wt.in 文件。
    "
""
    print_banner()
    # 解析 .win 文件
    print(f"\n正在解析: {win_path}")
    try:
        parser = WinParser(win_path)
    except Exception as e:
        print(f"  ✗ 解析失败: {e}")
        sys.exit(1)
    # 检测自旋极化(搜索 win_dir + cwd)
    win_dir = os.path.dirname(os.path.abspath(win_path)) or "."
    cwd = os.getcwd()
    spin_info = detect_spin_channel(win_dir, cwd)
    # 显示解析摘要
    print(f"  ✓ 晶格: {parser.get_lattice_vectors() is not None}")
    print(f"  ✓ 原子数: {len(parser.get_atom_positions_cart())}")
    print(f"  ✓ 投影数: {len(parser.get_projections())}")
    print(f"  ✓ k-path: {len(parser.get_kpoint_path())} 条")
    print(f"  ✓ 费米能级: {parser.get_fermi_energy()} eV")
    print(f"  ✓ SOC: {'是' if parser.get_soc() else '否'}")
    print(f"  ✓ Wannier 函数数: {parser.get_num_wann()}")
    num_occ = parser.get_num_wann() // 2 if parser.get_soc() else parser.get_num_wann()
    print(f"  ✓ 估算占据数: {num_occ}")
    if spin_info["has_spin"]:
        print(f"  ✓ 检测到自旋极化 HR 文件:")
        for ch in spin_info["available"]:
            print(f"      {ch.upper()}: {spin_info[ch]}")
    elif spin_info.get("non_spin_hr"):
        print(f"  ✓ 标准 hr.dat: {spin_info['non_spin_hr']}")
    else:
        print(f"  ⚠ 未检测到 hr.dat 文件(将使用默认路径 wannier90_hr.dat)")
    # 获取自旋通道选择
    if spin_channel is not None:
        # 用户显式指定了 --spin up|dn|both
        if spin_channel == "both":
            spin_channels = [(ch, spin_info[ch]) for ch in spin_info["available"] if spin_info[ch]]
        else:
            hr_path = os.path.join(win_dir, f"wannier90.{spin_channel}_hr.dat")
            if not os.path.isfile(hr_path):
                hr_path = f"wannier90.{spin_channel}_hr.dat"
            spin_channels = [(spin_channel, hr_path)]
    else:
        spin_channels = get_spin_selection(spin_info, mode="interactive")
    if not spin_channels:
        print("退出。")
        return
    # 显示菜单
    task_index = print_menu()
    # 获取选择
    selected = get_user_selection(task_index)
    if not selected:
        print("退出。")
        return
    # 生成文件
    if output_dir is None:
        output_dir = win_dir
    os.makedirs(output_dir, exist_ok=True)
    print(f"\n正在生成 wt.in 文件到: {output_dir}/")
    print("-" * 72)
    for ch, hr_path in spin_channels:
        generator = WtInGenerator(parser, spin_channel=ch, hrfile_path=hr_path)
        if ch:
            spin_dir = os.path.join(output_dir, f"wt-{ch}")
            os.makedirs(spin_dir, exist_ok=True)
            print(f"\n  [自旋通道: {ch.upper()}] Hrfile: {hr_path}")
            print(f"  [输出到: {spin_dir}/]")
        else:
            spin_dir = output_dir
        for task_key in selected:
            task = TASK_DEFINITIONS[task_key]
            output_path = os.path.join(spin_dir, f"wt.in-{task_key}")
            try:
                generator.generate(task_key, output_path)
            except Exception as e:
                print(f"  ✗ {task_key} 生成失败: {e}")
    print("-" * 72)
    total = len(selected) * len(spin_channels)
    print(f"\n完成! 共生成 {total} 个 wt.in 文件。")
    print(f"使用方法: cp wt.in-<task> wt.in && mpirun -np 4 wt.x")
def batch_mode_forced(win_path, tasks, output_dir=None):
    """强制标准模式 --- 不检测 up/dn,直接使用 wannier90_hr.dat。
    用于用户显式指定 --nospin 的场景。
    "
""
    parser = WinParser(win_path)
    if output_dir is None:
        output_dir = os.path.dirname(os.path.abspath(win_path)) or "."
    os.makedirs(output_dir, exist_ok=True)
    win_dir = os.path.dirname(os.path.abspath(win_path)) or "."
    cwd = os.getcwd()
    search_dirs = [win_dir, cwd]
    hr_path = _find_required_hr(win_dir, cwd, "")
    if hr_path is None:
        hr_path = _prompt_hr_path("标准(无自旋)", search_dirs, "wannier90_hr.dat")
        if hr_path is None:
            print("  已跳过,未生成任何文件。")
            return
    generator = WtInGenerator(parser, spin_channel=None, hrfile_path=hr_path)
    print(f"\n  模式: 强制标准")
    print(f"  Hrfile: {hr_path}")
    print(f"  输出到: {output_dir}/")
    for task_key in tasks:
        if task_key not in TASK_DEFINITIONS:
            print(f"  ✗ 未知任务: {task_key},跳过。")
            continue
        output_path = os.path.join(output_dir, f"wt.in-{task_key}")
        generator.generate(task_key, output_path)
def batch_mode(win_path, tasks, output_dir=None, spin_channel=None):
    """批量生成模式。
    自动检测当前文件夹和 .win 目录的 up/dn hr.dat 文件,
    有则自动分别生成对应通道的 wt.in 文件。
    "
""
    parser = WinParser(win_path)
    if output_dir is None:
        output_dir = os.path.dirname(os.path.abspath(win_path)) or "."
    win_dir = os.path.dirname(os.path.abspath(win_path)) or "."
    cwd = os.getcwd()
    if spin_channel is not None:
        # 用户显式指定了 --up / --dn → 必须找到对应文件
        ch = spin_channel
        hr_path = _find_required_hr(win_dir, cwd, ch)
        if hr_path is None:
            hr_path = _prompt_hr_path(
                f"spin-{ch}", [win_dir, cwd], f"wannier90.{ch}_hr.dat"
            )
            if hr_path is None:
                print("  已跳过,未生成任何文件。")
                return False
        spin_channels = [(ch, hr_path)]
    else:
        # 自动检测:有 up/dn 就都用,没有就回退到标准 hr.dat
        spin_info = detect_spin_channel(win_dir, cwd)
        spin_channels = get_spin_selection(spin_info, mode="auto")
    if not spin_channels:
        same_dirs = []
        for d in [win_dir, cwd]:
            d_abs = os.path.abspath(d)
            if d_abs not in same_dirs:
                same_dirs.append(d_abs)
        hr_path = _prompt_hr_path("通用", same_dirs, "wannier90_hr.dat")
        if hr_path is None:
            print("  未检测到任何 hr.dat 文件,退出。")
            print("  期望文件: wannier90_hr.dat, wannier90.up_hr.dat 或 wannier90.dn_hr.dat")
            return False
        spin_channels = [(None, hr_path)]
    for ch, hr_path in spin_channels:
        generator = WtInGenerator(parser, spin_channel=ch, hrfile_path=hr_path)
        if ch:
            spin_dir = os.path.join(output_dir, f"wt-{ch}")
            os.makedirs(spin_dir, exist_ok=True)
            print(f"\n{'=' * 60}")
            print(f"  自旋通道: {ch.upper()}")
            print(f"  Hrfile:   {hr_path}")
            print(f"  SOC:      0 (自旋通道内无 SOC)")
            print(f"  NumOccupied: {generator.num_occupied} (= num_wann)")
            print(f"  输出目录: {spin_dir}/")
            print(f"{'=' * 60}")
        else:
            spin_dir = output_dir
            os.makedirs(spin_dir, exist_ok=True)
            print(f"\n  标准模式(无自旋极化)")
            print(f"  Hrfile: {hr_path}")
            print(f"  输出到: {spin_dir}/")
        for task_key in tasks:
            if task_key not in TASK_DEFINITIONS:
                print(f"  ✗ 未知任务: {task_key},跳过。")
                continue
            output_path = os.path.join(spin_dir, f"wt.in-{task_key}")
            generator.generate(task_key, output_path)
    return True
# ============================================================================
# 主入口
# ============================================================================
def main():
    parser = argparse.ArgumentParser(
        description="win2wt: Wannier90 (.win) → WannierTools (wt.in) 自动转换工具",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=textwrap.dedent("""\
            示例:
              %(prog)s wannier90.win                  # 交互式选择任务和自旋通道
              %(prog)s wannier90.win --all            # 生成所有任务(自动检测 up/dn)
              %(prog)s wannier90.win -t bands,dos     # 生成指定任务
              %(prog)s wannier90.win --all --up       # 所有任务, 只生成 spin-up
              %(prog)s wannier90.win --all --dn       # 所有任务, 只生成 spin-down
              %(prog)s wannier90.win --all -o ./out   # 指定输出目录
              %(prog)s wannier90.win --list           # 列出任务
            提示: --all 模式下默认自动检测 up/dn 并同时生成,
                  无需手动指定 --spin。
        "
""
),
    )
    parser.add_argument("win_file", nargs="?", help="Wannier90 .win 输入文件路径")
    parser.add_argument("--all", "-a", action="store_true",
                        help="生成所有 19 个任务的 wt.in 文件")
    parser.add_argument("--list", "-l", action="store_true",
                        help="列出所有可用任务")
    parser.add_argument("--tasks", "-t", type=str,
                        help="逗号分隔的任务代码列表 (如 bands,dos,ahc)")
    parser.add_argument("--output", "-o", type=str, default=None,
                        help="输出目录(默认与 .win 同目录)")
    parser.add_argument("--up", action="store_true",
                        help="只生成 spin-up 通道(跳过交互选择)")
    parser.add_argument("--dn", action="store_true",
                        help="只生成 spin-down 通道(跳过交互选择)")
    parser.add_argument("--nospin", action="store_true",
                        help="强制使用标准模式(不检测 up/dn,即使用 wannier90_hr.dat)")
    args = parser.parse_args()
    # --list 模式
    if args.list:
        print_banner()
        print("\n可用任务列表:\n")
        for group_name, group_keys in TASK_GROUPS.items():
            print(f"  [{group_name}]")
            for key in group_keys:
                task = TASK_DEFINITIONS[key]
                print(f"    {key:<18s} - {task['name']}")
            print()
        return
    # 检查互斥选项
    if args.up and args.dn:
        print("错误: --up 和 --dn 不能同时使用。要生成两个通道,直接不加即可(默认自动检测)。")
        sys.exit(1)
    if args.up and args.nospin:
        print("错误: --up 和 --nospin 不能同时使用。")
        sys.exit(1)
    if args.dn and args.nospin:
        print("错误: --dn 和 --nospin 不能同时使用。")
        sys.exit(1)
    # 需要 .win 文件
    if not args.win_file:
        parser.print_help()
        sys.exit(1)
    if not os.path.exists(args.win_file):
        print(f"错误: 文件不存在: {args.win_file}")
        sys.exit(1)
    # 确定输出目录
    output_dir = args.output
    if output_dir is None:
        output_dir = os.path.dirname(os.path.abspath(args.win_file)) or "."
    # 确定自旋通道
    if args.nospin:
        spin_ch = "nospin"
    elif args.up:
        spin_ch = "up"
    elif args.dn:
        spin_ch = "dn"
    else:
        spin_ch = None  # auto-detect
    if args.all:
        print_banner()
        print(f"\n正在生成所有 19 个任务的 wt.in 文件...")
        if spin_ch == "nospin":
            batch_mode_forced(args.win_file, list(TASK_DEFINITIONS.keys()), output_dir)
            print(f"\n完成! 文件输出到: {output_dir}/")
        else:
            ok = batch_mode(args.win_file, list(TASK_DEFINITIONS.keys()), output_dir, spin_ch)
            if ok:
                print(f"\n完成! 文件输出到: {output_dir}/")
    elif args.tasks:
        tasks = [t.strip() for t in args.tasks.split(",")]
        if spin_ch == "nospin":
            batch_mode_forced(args.win_file, tasks, output_dir)
            print(f"\n完成! 文件输出到: {output_dir}/")
        else:
            ok = batch_mode(args.win_file, tasks, output_dir, spin_ch)
            if ok:
                print(f"\n完成! 文件输出到: {output_dir}/")
    else:
        interactive_mode(args.win_file, output_dir, spin_ch)
if __name__ == "__main__":
    main()

auto_test.sh 测试脚本

复制代码
#!/bin/bash
# ============================================================================
# auto_test.sh --- WannierTools wt.in 批量自动测试脚本
# 版本: 1.0 | 日期: 2026-06-06
#
# 功能:
#   1. 批量测试所有 wt.in-* 文件(cp 为 wt.in → wt.x → 保存 WT.out)
#   2. 超时控制(防止某任务卡死)
#   3. 自动检测 hr.dat 文件(支持 up/dn 自旋极化)
#   4. 汇总报告(通过/超时/失败/错误统计)
#   5. 支持并行模式(利用 GNU parallel 或背景进程)
#
# 用法:
#   ./auto_test.sh                             # 在当前目录测试所有 wt.in-*
#   ./auto_test.sh -d /path/to/wt-files         # 指定目录
#   ./auto_test.sh -t 120                       # 超时 120 秒
#   ./auto_test.sh -p 4                         # 4 核并行
#   ./auto_test.sh -s "bands dos ahc"           # 只测试指定任务
#   ./auto_test.sh --only-check                 # 只检查 in 文件,不运行
# ============================================================================
set -euo pipefail
# --- 默认参数 ---
WT_EXEC="${WT_EXEC:-wt.x}"
TIMEOUT=50
PARALLEL=1
WORK_DIR="."
SELECTED_TASKS=""
ONLY_CHECK=false
VERBOSE=false
# --- 颜色 ---
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
CYAN='\033[0;36m'
NC='\033[0m' # No Color
# --- 帮助 ---
usage() {
    cat <<EOF
用法: $0 [选项]
选项:
  -d DIR       工作目录 (默认: .)
  -t SECONDS   单任务超时秒数 (默认: 50)
  -p N         并行核数 (默认: 1, 串行)
  -s "task1 task2 ..."  只测试指定任务
  --only-check 只检查 in 文件语法,不运行 wt.x
  -v           详细输出
  -h           显示帮助
示例:
  $0                          # 测试当前目录所有 wt.in-*
  $0 -d ./wt-up -t 120        # 测试 wt-up 目录,120s 超时
  $0 -s "bands dos ahc"       # 只测试 bands, dos, ahc
  $0 -p 4                     # 4 核并行测试
EOF
    exit 0
}
# --- 解析参数 ---
while [[ $# -gt 0 ]]; do
    case "$1" in
        -d) WORK_DIR="$2"; shift 2 ;;
        -t) TIMEOUT="$2"; shift 2 ;;
        -p) PARALLEL="$2"; shift 2 ;;
        -s) SELECTED_TASKS="$2"; shift 2 ;;
        --only-check) ONLY_CHECK=true; shift ;;
        -v) VERBOSE=true; shift ;;
        -h|--help) usage ;;
        *) echo "未知选项: $1"; usage ;;
    esac
done
cd "$WORK_DIR" || { echo -e "${RED}错误: 无法进入目录 $WORK_DIR${NC}"; exit 1; }
# --- 检测 hr.dat ---
detect_hr() {
    if [ -f "wannier90.up_hr.dat" ]; then
        echo "up"
    elif [ -f "wannier90.dn_hr.dat" ]; then
        echo "dn"
    elif [ -f "wannier90_hr.dat" ]; then
        echo "standard"
    else
        echo ""
    fi
}
# --- 检查单个 wt.in 是否有致命参数错误 ---
check_wt_in() {
    local f="$1"
    local errors=0
    # 检查 NSLAB 是否在 &PARAMETERS(应在 &SYSTEM)
    if grep -q "^&PARAMETERS" "$f" 2>/dev/null; then
        if awk '/^&PARAMETERS/,/^\//' "$f" | grep -q "NSLAB\|NSLAB1\|NSLAB2\|^[[:space:]]*NP[[:space:]]*=" 2>/dev/null; then
            echo -e "  ${RED}✗ NSLAB/NP 在 &PARAMETERS 中(应在 &SYSTEM)${NC}"
            errors=$((errors + 1))
        fi
        if awk '/^&PARAMETERS/,/^\//' "$f" | grep -q "Bmagnitude\|Btheta\|Bphi\|^[[:space:]]*Bx\|^[[:space:]]*By\|^[[:space:]]*Bz" 2>/dev/null; then
            echo -e "  ${RED}✗ Bmagnitude/Btheta/Bphi 在 &PARAMETERS 中(应在 &SYSTEM)${NC}"
            errors=$((errors + 1))
        fi
    fi
    # 检查 Hrfile 是否与目录中的 hr.dat 匹配
    local hr=$(detect_hr)
    if [ -n "$hr" ] && [ "$hr" != "standard" ]; then
        if ! grep -q "wannier90.${hr}_hr.dat" "$f" 2>/dev/null; then
            echo -e "  ${YELLOW}⚠ Hrfile 可能不匹配(期望 wannier90.${hr}_hr.dat)${NC}"
        fi
    fi
    return $errors
}
# --- 运行单个测试 ---
run_one() {
    local f="$1"
    local task="${f#wt.in-}"
    local out_file="WT-${f}.out"
    local start_time=$(date +%s)
    if $VERBOSE; then
        printf "${BLUE}[测试]${NC} %-18s ... " "$task"
    fi
    # 清理旧文件
    rm -f WT.out "$out_file"
    # 复制为 wt.in(WannierTools 只读 ./wt.in)
    cp "$f" wt.in
    # 运行(超时控制)
    if timeout "$TIMEOUT" "$WT_EXEC" > /dev/null 2>&1; then
        local elapsed=$(($(date +%s) - start_time))
        if [ -f WT.out ]; then
            cp WT.out "$out_file"
        fi
        if $VERBOSE; then
            echo -e "${GREEN}OK${NC} (${elapsed}s)"
        fi
        echo "PASS ${task} ${elapsed}" >> "$SUMMARY_FILE"
    else
        local rc=$?
        if [ -f WT.out ]; then
            cp WT.out "$out_file"
        fi
        if [ $rc -eq 124 ]; then
            local elapsed="$TIMEOUT"
            if $VERBOSE; then
                echo -e "${YELLOW}TIMEOUT${NC} (>${TIMEOUT}s)"
            fi
            echo "TIMEOUT ${task} ${elapsed}" >> "$SUMMARY_FILE"
        else
            local nerr=$(grep -c "Error\|ERROR" "$out_file" 2>/dev/null || echo 0)
            if $VERBOSE; then
                echo -e "${RED}FAIL${NC} (exit=$rc, errors=$nerr)"
            fi
            echo "FAIL ${task} ${rc} ${nerr}" >> "$SUMMARY_FILE"
        fi
    fi
    # 清理临时 wt.in
    rm -f wt.in
}
# --- 主流程 ---
main() {
    echo -e "${CYAN}========================================${NC}"
    echo -e "${CYAN}  WannierTools 批量测试工具${NC}"
    echo -e "${CYAN}========================================${NC}"
    echo "  工作目录: $(pwd)"
    echo "  超时设置: ${TIMEOUT}s"
    echo "  并行核数: ${PARALLEL}"
    echo "  wt.x路径: ${WT_EXEC}"
    echo "  时间: $(date '+%Y-%m-%d %H:%M:%S')"
    # 检测 hr.dat
    local hr=$(detect_hr)
    if [ -z "$hr" ]; then
        echo -e "${RED}错误: 未找到 hr.dat 文件!${NC}"
        echo "  期望: wannier90_hr.dat, wannier90.up_hr.dat 或 wannier90.dn_hr.dat"
        exit 1
    fi
    echo -e "  hr.dat:  ${GREEN}$hr${NC}"
    # 收集任务列表
    local files=()
    if [ -n "$SELECTED_TASKS" ]; then
        for task in $SELECTED_TASKS; do
            if [ -f "wt.in-${task}" ]; then
                files+=("wt.in-${task}")
            else
                echo -e "  ${YELLOW}警告: wt.in-${task} 不存在,跳过${NC}"
            fi
        done
    else
        for f in wt.in-*; do
            [ -f "$f" ] && files+=("$f")
        done
    fi
    local total=${#files[@]}
    if [ $total -eq 0 ]; then
        echo -e "${RED}错误: 未找到任何 wt.in-* 文件!${NC}"
        exit 1
    fi
    echo -e "  任务数:  ${total}"
    # 语法检查模式
    if $ONLY_CHECK; then
        echo ""
        echo -e "${CYAN}--- 语法检查 ---${NC}"
        local check_errors=0
        for f in "${files[@]}"; do
            local task="${f#wt.in-}"
            echo -e "  ${BLUE}$task${NC}"
            if ! check_wt_in "$f"; then
                check_errors=$((check_errors + 1))
            fi
        done
        echo ""
        if [ $check_errors -eq 0 ]; then
            echo -e "${GREEN}全部通过!${NC}"
        else
            echo -e "${RED}$check_errors 个文件存在问题${NC}"
        fi
        exit $check_errors
    fi
    # 运行测试
    echo ""
    echo -e "${CYAN}--- 运行测试 ---${NC}"
    echo ""
    SUMMARY_FILE=$(mktemp)
    if [ "$PARALLEL" -gt 1 ]; then
        echo "  并行模式: $PARALLEL 核"
        export TIMEOUT WT_EXEC SUMMARY_FILE VERBOSE RED GREEN YELLOW BLUE CYAN NC
        export -f run_one detect_hr
        printf '%s\n' "${files[@]}" | xargs -P "$PARALLEL" -I {} bash -c 'run_one "$@"' _ {}
        wait
    else
        local count=0
        for f in "${files[@]}"; do
            count=$((count + 1))
            printf "[%2d/%2d] " $count $total
            run_one "$f"
        done
    fi
    # 汇总
    echo ""
    echo -e "${CYAN}========================================${NC}"
    echo -e "${CYAN}  测试汇总${NC}"
    echo -e "${CYAN}========================================${NC}"
    local n_pass=$(grep -c "^PASS" "$SUMMARY_FILE" 2>/dev/null || echo 0)
    local n_timeout=$(grep -c "^TIMEOUT" "$SUMMARY_FILE" 2>/dev/null || echo 0)
    local n_fail=$(grep -c "^FAIL" "$SUMMARY_FILE" 2>/dev/null || echo 0)
    echo ""
    echo -e "  ${GREEN}通过:   $n_pass${NC}"
    echo -e "  ${YELLOW}超时:   $n_timeout${NC}"
    echo -e "  ${RED}失败:   $n_fail${NC}"
    echo -e "  ${CYAN}总计:   $total${NC}"
    echo ""
    # 列出失败详情
    if [ "$n_fail" -gt 0 ]; then
        echo -e "${RED}失败详情:${NC}"
        grep "^FAIL" "$SUMMARY_FILE" | while read -r line; do
            echo "  $line"
        done
    fi
    # 列出超时详情
    if [ "$n_timeout" -gt 0 ]; then
        echo -e "${YELLOW}超时任务 (仍在运行):${NC}"
        grep "^TIMEOUT" "$SUMMARY_FILE" | while read -r line; do
            echo "  $line"
        done
    fi
    # 输出文件位置
    echo ""
    echo "WT.out 文件已保存至: $(pwd)/WT-wt.in-*.out"
    rm -f "$SUMMARY_FILE"
    # 返回码
    if [ "$n_fail" -gt 0 ]; then
        exit 1
    else
        exit 0
    fi
}
main
相关推荐
weixin_468466852 小时前
网络数据采集新手入门指南
python·网络爬虫·conda·编程
大神15732 小时前
Cordova Android 签名三种方式详解:证书生成、命令行直接签名与配置文件自动签名
android·java
武子康2 小时前
调查研究-170 Vert.x 是什么?它和 Netty 到底是什么关系?一张图讲清 Java 异步技术栈选型
java·后端
叫我:松哥2 小时前
基于卷积神经网络的人脸情绪识别算法,引入残差连接与SE注意力模块
人工智能·深度学习·神经网络·算法·cnn·迁移学习·图像识别
m沐沐2 小时前
【计算机视觉】OpenCV 模板匹配银行卡数字识别---上
人工智能·后端·python·opencv·计算机视觉·pycharm·numpy
deephub2 小时前
2026 年开源 Agent 工具包选型指南:延迟、审计、可移植性与语言栈
人工智能·python·大语言模型·多智能体
ellenwan20262 小时前
期货量化尾盘没清仓:天勤 trading_time 过滤与收盘前平仓
python·区块链
rising start2 小时前
九、vue3 组件通信:全场景详解
前端·vue.js·typescript
VOLUN2 小时前
告别 AI 乱码!Vue3+TS 项目的 AI 编码助手规范实践
前端·ai编程