希尔排序(Shell Sort)详解
------原理、步长选择与高性能 C++ 实现
一、什么是希尔排序?
希尔排序(Shell Sort)是插入排序的改进版本 ,由 Donald L. Shell 于 1959 年提出。
它通过分组插入排序的方式,先对相距较远的元素进行粗略排序,再逐步缩小距离,最终完成精细排序。
✅ 核心思想
"先让数组大致有序,再用插入排序收尾。"
- 将数组按步长(gap) 分成若干子序列
- 对每个子序列执行插入排序
- 逐步减小 `gap`,直到 `gap = 1`(即标准插入排序)
📈 优势
- 比普通插入排序快得多(尤其在中等规模数据上)
- 原地排序(O(1) 额外空间)
- 无需递归,适合嵌入式或受限环境
⚠️ 劣势
- 不稳定(相等元素的相对顺序可能被改变)
- 性能高度依赖 gap 序列的选择
cpp 实现
cpp
// Sorts arr[0..size) into ascending order in place using Shell sort, dividing
// the gap by `stride` after every pass.
//
// Gap sequence: size/stride, size/stride^2, ... and always a final gap of 1.
// The final gap-1 pass is forced because integer division can jump over 1
// (size = 6, stride = 3 yields only gap 2), which would leave the array
// 2-sorted but not sorted.
//
// arr    - array to sort; a null pointer is ignored.
// size   - number of elements; sizes <= 1 need no work.
// stride - divisor applied to the gap. Values below 2 cannot shrink the gap
//          (1 never terminates, 0 divides by zero), so they are treated as 2.
void ShellSort(int arr[], int size, int stride) {
    if (arr == nullptr || size <= 1) return;
    if (stride < 2) stride = 2;

    int gap = size / stride;
    if (gap < 1) gap = 1;  // stride > size: fall straight through to insertion sort

    for (;;) {
        // Gapped insertion sort: every gap-th subsequence becomes sorted.
        for (int i = gap; i < size; ++i) {
            const int temp = arr[i];
            int j = i - gap;
            while (j >= 0 && arr[j] > temp) {
                arr[j + gap] = arr[j];
                j -= gap;
            }
            arr[j + gap] = temp;
        }
        if (gap == 1) break;  // the gap-1 pass has just finished the sort
        gap /= stride;
        if (gap < 1) gap = 1;
    }
}
二、为什么 gap 序列如此重要?
希尔排序的性能完全由 gap 序列决定:
| 序列类型 | 最坏时间复杂度 | 是否可能退化 | 实际表现 |
|---|---|---|---|
| 原始 Shell(gap = n/2, n/4, ...) | $O(n^2)$ | ✅ 是(如 n=2^k) | 差,已淘汰 |
| Knuth 序列 | $O(n^{3/2})$ | ❌ 否 | 优秀,理论保障 |
| Ciura 序列 | 未知(实测最优) | ❌ 否 | 当前最快 |
💡 关键原则:
- 最后一步必须是 `gap = 1`
- 相邻 gap 应尽量互质(否则某些位置上的元素要到最后的 `gap = 1` 才会相互比较)
- gap 应快速减小但不过于稀疏
三、推荐 gap 序列(只记这两个!)
1. Knuth 序列(理论优雅,教学首选)
- 公式:$h_0 = 1,\quad h_k = 3 \cdot h_{k-1} + 1$
- 序列:`1, 4, 13, 40, 121, 364, 1093, ...`
- 使用方式:从序列中第一个不小于 ⌊n/3⌋ 的 gap 开始(即下文代码 `while (gap < n / 3) gap = gap * 3 + 1;` 得到的值),倒序使用
- 优点:公式简单,无需预存,理论有保障
🗣️ 口诀:"乘3加1,倒着排"
2. Ciura 序列(实测最快,工业首选)
- 来源:Marcin Ciura 通过百万级实验得出
- 基础序列(必须记住):`1, 4, 10, 23, 57, 132, 301, 701`
- 扩展规则(n > 1577 时):$\text{next\_gap} = \lfloor 2.25 \times \text{prev\_gap} \rfloor$
- 优点:实测比 Knuth 快 20%~30%(注:原文称其"被 GNU qsort 等采用",但 glibc 的 qsort 并非基于希尔排序,此说法有待核实)
🗣️ 口诀:"1-4-10-23-57-132-301-701,倒着用!"
四、C++ 高性能实现
以下两个实现均包含:
- ✅ 小数组优化(N ≤ 32 时回退到插入排序)
- ✅ 边界安全
- ✅ 工业级可用
cpp
#include <iostream>
// 插入排序(用于小数组)
// Plain insertion sort: rearranges arr[0..n) into ascending order in place.
// Elements that compare equal are never moved past each other.
//
// arr - array to sort.
// n   - number of elements; nothing happens when n <= 1.
void insert_sort(int arr[], int n) {
    for (int next = 1; next < n; ++next) {
        const int value = arr[next];
        int slot = next;
        // Slide larger elements one place right until value's slot opens up.
        for (; slot > 0 && arr[slot - 1] > value; --slot) {
            arr[slot] = arr[slot - 1];
        }
        arr[slot] = value;
    }
}
// ────────────────────────────────
// 希尔排序:Knuth 序列
// ────────────────────────────────
// Shell sort of arr[0..n) in ascending order using Knuth's gaps
// (1, 4, 13, 40, ...; each is three times the previous plus one).
//
// Arrays of at most 32 elements are handed to insert_sort instead.
//
// arr - array to sort in place.
// n   - number of elements.
void shell_sort_knuth(int arr[], int n) {
    if (n <= 32) {
        insert_sort(arr, n);
        return;
    }

    // Grow h to the first Knuth gap that is not below n / 3.
    int h = 1;
    while (h < n / 3) {
        h = 3 * h + 1;
    }

    // Integer division by 3 steps back down the same sequence:
    // (3h + 1) / 3 == h, and 1 / 3 == 0 ends the loop.
    for (; h >= 1; h /= 3) {
        for (int cur = h; cur < n; ++cur) {
            const int held = arr[cur];
            int dst = cur;
            while (dst >= h && arr[dst - h] > held) {
                arr[dst] = arr[dst - h];
                dst -= h;
            }
            arr[dst] = held;
        }
    }
}
// ────────────────────────────────
// 希尔排序:Ciura 序列(实测最快)
// ────────────────────────────────
// One gapped insertion-sort pass: afterwards every subsequence
// arr[k], arr[k + gap], arr[k + 2*gap], ... is in ascending order.
static void ciura_gap_pass(int arr[], int n, int gap) {
    for (int i = gap; i < n; ++i) {
        const int temp = arr[i];
        int j = i - gap;
        while (j >= 0 && arr[j] > temp) {
            arr[j + gap] = arr[j];
            j -= gap;
        }
        arr[j + gap] = temp;
    }
}

// Shell sort of arr[0..n) in ascending order using Ciura's gap sequence
// 1, 4, 10, 23, 57, 132, 301, 701, extended upward by floor(2.25 * previous)
// while the next gap is still smaller than n.
//
// All gaps, including the extended ones, are applied from largest to
// smallest so that the big gaps do their coarse ordering before the gap-1
// pass. Running them after gap 1 would only walk an already-sorted array.
//
// arr - array to sort in place.
// n   - number of elements; arrays of at most 32 elements get a single
//       gap-1 pass, which is plain insertion sort.
void shell_sort_ciura(int arr[], int n) {
    if (n <= 1) return;
    if (n <= 32) {
        ciura_gap_pass(arr, n, 1);
        return;
    }

    static const int base_gaps[] = {1, 4, 10, 23, 57, 132, 301, 701};
    const int base_count = sizeof(base_gaps) / sizeof(base_gaps[0]);

    // 8 base gaps plus at most 18 extended gaps fit below INT_MAX.
    const int max_gaps = 32;
    int gaps[max_gaps];
    int count = 0;
    for (; count < base_count; ++count) {
        gaps[count] = base_gaps[count];
    }
    while (count < max_gaps) {
        // Compare in long long before narrowing: 2.25 * gap can exceed
        // INT_MAX when n is close to it.
        const long long next = static_cast<long long>(2.25 * gaps[count - 1]);
        if (next >= n) break;
        gaps[count++] = static_cast<int>(next);
    }

    for (int k = count - 1; k >= 0; --k) {
        if (gaps[k] >= n) continue;  // a gap this wide has nothing to compare
        ciura_gap_pass(arr, n, gaps[k]);
    }
}
五、使用建议
| 场景 | 推荐方案 |
|---|---|
| 学习/教学/面试 | shell_sort_knuth(原理清晰) |
| 竞赛/工程/高性能 | shell_sort_ciura(速度最快) |
| 通用实践 | 两者都加 小数组优化(N ≤ 32) |
| 替代方案 | 大多数情况直接用 std::sort |
六、常见误区
❌ 误区1 :用 gap = n/2, n/4, ...
→ 这是原始 Shell 序列,已被淘汰,可能退化为 $O(n^2)$
❌ 误区2 :希尔排序总是比插入排序快
→ 在已排序数组 或极小数组上,它反而更慢(负优化)
✅ 正确做法:
- 小数组(N ≤ 32)→ 插入排序
- 中等数组 → Knuth 或 Ciura
- 大数组 → 优先考虑 `std::sort`
七、总结
| 特性 | Knuth | Ciura |
|---|---|---|
| 理论依据 | ✅ 强 | ❌(实验得出) |
| 实现难度 | ⭐⭐ | ⭐⭐ |
| 实际速度 | 快 | 最快 |
| 适用场景 | 教学、通用 | 竞赛、工程 |
🎯 记住两个序列就够了:
- Knuth:`1, 4, 13, 40...`(乘3加1)
- Ciura:`1, 4, 10, 23, 57, 132, 301, 701`(背下来!)
希尔排序虽非现代主流,但在无 STL、内存受限、或需要确定性性能的场景中,仍是利器!
测试代码
cpp
#include <iostream>
#include <vector>
#include <chrono>
#include <random>
#include <algorithm>
#include <string>
#include <iomanip>
// ========== 1. 你的原始版本(stride 控制) ==========
// Shell sort of arr[0..size) in ascending order, in place, with the gap
// divided by `stride` after every pass (stride = 2 gives Shell's original
// n/2, n/4, ... sequence).
//
// A final gap-1 pass is always performed: integer division can jump over 1
// (size = 6, stride = 3 yields only gap 2), and without that pass the array
// would be left unsorted.
//
// arr    - array to sort; a null pointer is ignored.
// size   - number of elements; sizes <= 1 need no work.
// stride - gap divisor, default 2. Values below 2 cannot shrink the gap
//          (1 never terminates, 0 divides by zero), so they are treated as 2.
void ShellSort_Stride(int arr[], int size, int stride = 2) {
    if (arr == nullptr || size <= 1) return;
    if (stride < 2) stride = 2;

    int gap = size / stride;
    if (gap < 1) gap = 1;  // stride > size: only the insertion-sort pass remains

    for (;;) {
        for (int i = gap; i < size; ++i) {
            const int temp = arr[i];
            int j = i - gap;
            while (j >= 0 && arr[j] > temp) {
                arr[j + gap] = arr[j];
                j -= gap;
            }
            arr[j + gap] = temp;
        }
        if (gap == 1) break;  // the gap-1 pass has just finished the sort
        gap /= stride;
        if (gap < 1) gap = 1;
    }
}
// ========== 2. 插入排序(用于小数组) ==========
// Insertion sort: puts arr[0..n) into ascending order in place.
// Equal elements keep their relative order.
//
// arr - array to sort.
// n   - element count; n <= 1 leaves the array untouched.
void InsertSort(int arr[], int n) {
    int pos = 1;
    while (pos < n) {
        const int pending = arr[pos];
        int hole = pos;
        // Walk the hole left past every element greater than pending.
        while (hole >= 1) {
            if (arr[hole - 1] <= pending) break;
            arr[hole] = arr[hole - 1];
            --hole;
        }
        arr[hole] = pending;
        ++pos;
    }
}
// ========== 3. Knuth 序列版本 ==========
// Shell sort of arr[0..n) in ascending order with Knuth's gap sequence
// (1, 4, 13, 40, ...). Arrays of at most 32 elements go to InsertSort.
//
// arr - array to sort in place.
// n   - number of elements.
void ShellSort_Knuth(int arr[], int n) {
    if (n <= 32) {
        InsertSort(arr, n);
        return;
    }

    // First gap of the sequence that is not below n / 3.
    const int limit = n / 3;
    int step = 1;
    while (step < limit) step = step * 3 + 1;

    // step starts at >= 1, and dividing by 3 retraces the sequence downward
    // until 1 / 3 == 0 stops the loop.
    do {
        for (int i = step; i < n; ++i) {
            const int key = arr[i];
            int k = i - step;
            for (; k >= 0 && key < arr[k]; k -= step) {
                arr[k + step] = arr[k];
            }
            arr[k + step] = key;
        }
        step /= 3;
    } while (step >= 1);
}
// ========== 4. Ciura 序列版本 ==========
// One gapped insertion-sort pass over arr[0..n): afterwards each
// subsequence of elements `gap` apart is in ascending order.
static void CiuraGapPass(int arr[], int n, int gap) {
    for (int i = gap; i < n; ++i) {
        const int temp = arr[i];
        int j = i - gap;
        while (j >= 0 && arr[j] > temp) {
            arr[j + gap] = arr[j];
            j -= gap;
        }
        arr[j + gap] = temp;
    }
}

// Shell sort of arr[0..n) in ascending order with Ciura's gaps
// 1, 4, 10, 23, 57, 132, 301, 701, extended by floor(2.25 * previous) while
// the next gap stays below n.
//
// The whole sequence, extended gaps included, is applied largest-first.
// Applying the extended gaps after the gap-1 pass would make them no-ops on
// an already-sorted array and leave large inputs starting at gap 701.
//
// arr - array to sort in place.
// n   - number of elements; arrays of at most 32 elements get one gap-1
//       pass, i.e. plain insertion sort.
void ShellSort_Ciura(int arr[], int n) {
    if (n <= 1) return;
    if (n <= 32) {
        CiuraGapPass(arr, n, 1);
        return;
    }

    static const int kBaseGaps[] = { 1, 4, 10, 23, 57, 132, 301, 701 };
    const int kBaseCount = sizeof(kBaseGaps) / sizeof(kBaseGaps[0]);

    // 8 base gaps plus at most 18 extended gaps fit below INT_MAX.
    const int kMaxGaps = 32;
    int gaps[kMaxGaps];
    int count = 0;
    for (; count < kBaseCount; ++count) {
        gaps[count] = kBaseGaps[count];
    }
    while (count < kMaxGaps) {
        // Widen before comparing: 2.25 * gap may not fit in int when n is
        // close to INT_MAX.
        const long long next = static_cast<long long>(2.25 * gaps[count - 1]);
        if (next >= n) break;
        gaps[count++] = static_cast<int>(next);
    }

    for (int k = count - 1; k >= 0; --k) {
        if (gaps[k] >= n) continue;  // nothing to compare at this distance
        CiuraGapPass(arr, n, gaps[k]);
    }
}
// ========== 辅助函数 ==========
// Reports whether arr[0..n) is in non-decreasing order.
// Arrays with fewer than two elements count as sorted.
bool is_sorted(const int arr[], int n) {
    int i = 1;
    // Advance while each neighbouring pair is in order; stopping early
    // means a descent was found.
    while (i < n && arr[i - 1] <= arr[i]) {
        ++i;
    }
    return i >= n;
}
// Times one run of sort_func on a private copy of `original`.
//
// sort_func - callable invoked as sort_func(int* data, int count).
// original  - input values; left unmodified because the sort runs on a copy.
//
// Returns the elapsed wall-clock time in microseconds. If the copy is not in
// non-decreasing order afterwards, prints an error to std::cerr and
// terminates the program with exit status 1.
template<typename Func>
long long time_sort(Func sort_func, const std::vector<int>& original) {
    std::vector<int> work(original);

    const auto t0 = std::chrono::high_resolution_clock::now();
    sort_func(work.data(), static_cast<int>(work.size()));
    const auto t1 = std::chrono::high_resolution_clock::now();

    const auto elapsed_us =
        std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count();

    // Verification happens outside the timed region.
    if (is_sorted(work.data(), work.size())) {
        return elapsed_us;
    }
    std::cerr << "ERROR: Sorting failed!\n";
    exit(1);
}
// ========== 主函数 ==========
// Benchmark driver.
// Part 1: for each size in `sizes`, fills a vector with random integers and
// times ShellSort_Stride (stride 2), ShellSort_Knuth and ShellSort_Ciura on
// copies of the same data, printing each time and its speed-up relative to
// the stride-2 run.
// Part 2: times InsertSort and the three Shell sort variants on an
// already-sorted array of 10000 elements.
// Every measurement goes through time_sort, which also checks the result.
int main() {
std::vector<int> sizes = { 10000, 50000, 100000, 200000 };
std::mt19937 rng(2026); // fixed seed so the generated data is reproducible
std::cout << std::fixed << std::setprecision(1);
std::cout << "=== 希尔排序性能对比 ===\n";
std::cout << "随机整数数组 | 正确性已验证\n\n";
for (int N : sizes) {
// Generate N random values in [0, 10 * N]
std::vector<int> original(N);
std::uniform_int_distribution<int> dist(0, N * 10);
for (int& x : original) x = dist(rng);
// Time the three implementations on identical input
long long t_stride = time_sort([](int* a, int n) { ShellSort_Stride(a, n, 2); }, original);
long long t_knuth = time_sort(ShellSort_Knuth, original);
long long t_ciura = time_sort(ShellSort_Ciura, original);
// Speed-up ratios, using the stride-2 time as the baseline
double speedup_knuth = static_cast<double>(t_stride) / t_knuth;
double speedup_ciura = static_cast<double>(t_stride) / t_ciura;
std::cout << "N = " << std::setw(6) << N << ":\n";
std::cout << " Stride=2 (原始): " << std::setw(7) << t_stride << " μs\n";
std::cout << " Knuth 序列 : " << std::setw(7) << t_knuth << " μs (" << speedup_knuth << "x faster)\n";
std::cout << " Ciura 序列 : " << std::setw(7) << t_ciura << " μs (" << speedup_ciura << "x faster)\n";
std::cout << "\n";
}
// Extra: already-sorted input, where plain insertion sort is the baseline
std::cout << "=== 有序数组测试(N=10000)===\n";
std::vector<int> sorted_arr(10000);
for (int i = 0; i < 10000; ++i) sorted_arr[i] = i;
long long t_ins = time_sort(InsertSort, sorted_arr);
long long t_str = time_sort([](int* a, int n) { ShellSort_Stride(a, n, 2); }, sorted_arr);
long long t_knu = time_sort(ShellSort_Knuth, sorted_arr);
long long t_ciu = time_sort(ShellSort_Ciura, sorted_arr);
std::cout << "InsertSort : " << std::setw(7) << t_ins << " μs\n";
std::cout << "Stride=2 : " << std::setw(7) << t_str << " μs (" << static_cast<double>(t_str) / t_ins << "x slower!)\n";
std::cout << "Knuth : " << std::setw(7) << t_knu << " μs\n";
std::cout << "Ciura : " << std::setw(7) << t_ciu << " μs\n";
return 0;
}
python
python
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import random
# Use the SimHei font for sans-serif text so the Chinese labels and titles
# used by the plots below can be rendered.
plt.rcParams['font.sans-serif'] = ['SimHei']
# Draw minus signs with the ASCII hyphen instead of the Unicode minus
# (presumably because SimHei lacks that glyph; confirm on the target system).
plt.rcParams['axes.unicode_minus'] = False
class ShellSortTracer:
    """Runs a Shell sort variant on a private copy of the input and records every step.

    Each recorded step is a full snapshot of the working list plus the indices
    to highlight, so an animation can replay the sort frame by frame. Every
    ``sort_*`` method discards the previous recording before it starts, so the
    counters and steps always describe the most recent run only.
    """

    # Ciura's base gap sequence, ascending; extended by int(2.25 * previous)
    # for inputs longer than the next extended gap.
    _CIURA_BASE_GAPS = (1, 4, 10, 23, 57, 132, 301, 701)
    # The Knuth and Ciura variants use plain insertion sort at or below this length.
    _SMALL_LIST_CUTOFF = 32

    def __init__(self, name, arr):
        """Store a display name and a copy of the list to sort.

        Args:
            name: Label for this tracer.
            arr: Values to sort; copied, never modified.
        """
        self.name = name
        self.original = arr[:]
        self.steps = []            # list snapshot for every recorded step
        self.highlights = []       # indices to highlight at each step, e.g. [j, j + gap]
        self.sorted_mask = []      # per-step flags; all False until the final step
        self.total_comparisons = 0
        self.total_moves = 0       # element shifts (insertion-style moves, not swaps)
        self.finished = False      # True once a sort_* method has completed

    def _record(self, arr, highlight_indices):
        """Append a snapshot of ``arr`` together with the indices to highlight."""
        self.steps.append(arr[:])
        self.highlights.append(highlight_indices[:])
        # No "already sorted" region is tracked while the sort is running.
        self.sorted_mask.append([False] * len(arr))

    def _reset(self):
        """Drop any earlier recording so a new run starts from a clean state."""
        self.steps.clear()
        self.highlights.clear()
        self.sorted_mask.clear()
        self.total_comparisons = 0
        self.total_moves = 0
        self.finished = False

    def _finish(self, arr):
        """Record the final state: no highlights, every position marked sorted."""
        self.steps.append(arr[:])
        self.highlights.append([])
        self.sorted_mask.append([True] * len(arr))
        self.finished = True

    def _gapped_pass(self, arr, gap):
        """Run one traced gapped insertion pass over ``arr``.

        For every inserted element this records the element being picked up,
        every comparison (counted in ``total_comparisons``), and the final
        placement. Each shift of an element by ``gap`` counts as one move.
        """
        for i in range(gap, len(arr)):
            temp = arr[i]
            j = i - gap
            self._record(arr, [i])
            while j >= 0:
                self.total_comparisons += 1
                self._record(arr, [j, j + gap])
                if arr[j] > temp:
                    arr[j + gap] = arr[j]
                    self.total_moves += 1
                    j -= gap
                else:
                    break
            arr[j + gap] = temp
            self._record(arr, [j + gap])

    def sort_stride(self, stride=2):
        """Trace a Shell sort whose gap is divided by ``stride`` after each pass.

        The gap starts at ``n // stride`` and a final gap-1 pass is always
        run, because integer division can otherwise jump from a gap above 1
        straight to 0 and leave the list unsorted.

        Args:
            stride: Gap divisor, at least 2.

        Raises:
            ValueError: If ``stride`` is below 2 (the gap could never shrink).
        """
        if stride < 2:
            raise ValueError("stride must be at least 2")
        arr = self.original[:]
        n = len(arr)
        self._reset()
        gap = n // stride
        if n > 1 and gap < 1:
            gap = 1  # stride larger than the list: only the insertion pass remains
        while gap > 0:
            self._gapped_pass(arr, gap)
            if gap == 1:
                break
            gap = max(gap // stride, 1)
        self._finish(arr)

    def sort_knuth(self):
        """Trace a Shell sort using Knuth's gaps (1, 4, 13, 40, ...).

        Lists of at most 32 elements are traced as plain insertion sort.
        """
        arr = self.original[:]
        n = len(arr)
        self._reset()
        if n <= self._SMALL_LIST_CUTOFF:
            self._insert_sort_for_small(arr)
            return
        # First Knuth gap that is not below n // 3.
        gap = 1
        while gap < n // 3:
            gap = gap * 3 + 1
        # Floor division by 3 retraces the sequence: (3h + 1) // 3 == h.
        while gap >= 1:
            self._gapped_pass(arr, gap)
            gap //= 3
        self._finish(arr)

    def sort_ciura(self):
        """Trace a Shell sort using Ciura's gaps, extended for long lists.

        The extended gaps (int(2.25 * previous) while still below the list
        length) are applied first, largest to smallest, followed by the base
        gaps down to 1. Lists of at most 32 elements are traced as plain
        insertion sort.
        """
        arr = self.original[:]
        n = len(arr)
        self._reset()
        if n <= self._SMALL_LIST_CUTOFF:
            self._insert_sort_for_small(arr)
            return
        gaps = list(self._CIURA_BASE_GAPS)
        while True:
            next_gap = int(2.25 * gaps[-1])
            if next_gap >= n:
                break
            gaps.append(next_gap)
        for gap in reversed(gaps):
            if gap >= n:
                continue  # nothing to compare at this distance
            self._gapped_pass(arr, gap)
        self._finish(arr)

    def _insert_sort_for_small(self, arr):
        """Trace plain insertion sort (a single gap-1 pass) and record the final state."""
        self._gapped_pass(arr, 1)
        self._finish(arr)
def compare_shell_sorts(data):
    """Animate three Shell sort variants side by side on the same input.

    Builds one ShellSortTracer per variant (stride 2, Knuth, Ciura), runs all
    three sorts up front, replays the recorded steps as a bar-chart animation
    with one panel per variant, and prints each variant's totals after the
    window is closed.

    Args:
        data: List of numbers to sort; each tracer works on its own copy.
    """
    stride_run = ShellSortTracer("Shell (Stride=2)", data)
    knuth_run = ShellSortTracer("Shell (Knuth)", data)
    ciura_run = ShellSortTracer("Shell (Ciura)", data)
    tracers = [stride_run, knuth_run, ciura_run]

    # All sorting happens here; the animation only replays the recordings.
    stride_run.sort_stride(stride=2)
    knuth_run.sort_knuth()
    ciura_run.sort_ciura()

    frame_count = max(len(run.steps) for run in tracers)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    bars = []
    for panel, run in zip(axes, tracers):
        bars.append(panel.bar(range(len(data)), run.steps[0], color='skyblue'))

    for panel in axes:
        panel.set_ylim(0, max(data) * 1.15)
        panel.set_xlabel('索引')
        panel.set_ylabel('值')

    info_text = fig.text(0.5, 0.02, '', ha='center', fontsize=11, family='monospace')

    def update(frame):
        """Redraw every panel for animation frame number ``frame``."""
        summary_parts = []
        touched = []
        for run, rects, panel in zip(tracers, bars, axes):
            last = len(run.steps) - 1
            done = frame >= last
            if frame <= last:
                snapshot = run.steps[frame]
                marked = set(run.highlights[frame])
            else:
                # This tracer has no more steps: hold its final state.
                snapshot = run.steps[-1]
                marked = set()

            for pos, rect in enumerate(rects):
                rect.set_height(snapshot[pos])
                if pos in marked:
                    colour = 'red'
                else:
                    colour = 'lightgreen' if done else 'skyblue'
                rect.set_color(colour)
                touched.append(rect)

            status = "✓ 完成" if done else f"第 {frame + 1} 步"
            panel.set_title(f"{run.name}\n{status}", fontsize=11)

            # The totals shown are the final ones, not running counts.
            summary_parts.append(
                f"{run.name}: "
                f"比较={run.total_comparisons}, "
                f"移动={run.total_moves}, "
                f"步数={len(run.steps)-1}"
            )
        info_text.set_text(" | ".join(summary_parts))
        return touched + [info_text]

    # Keep a reference to the animation object until the window is shown.
    anim = animation.FuncAnimation(
        fig, update, frames=frame_count,
        repeat=False, interval=200, blit=False
    )
    plt.tight_layout(rect=[0, 0.05, 1, 0.95])
    plt.show()

    print("\n📊 最终统计:")
    for run in tracers:
        n_cmp = run.total_comparisons
        n_mov = run.total_moves
        n_steps = len(run.steps) - 1
        print(f"{run.name:20} | 比较: {n_cmp:5d} | 移动: {n_mov:5d} | 步数: {n_steps:5d}")
# ========================
# 测试
# ========================
if __name__ == "__main__":
    # Fixed seed so the demo input is the same on every run.
    random.seed(42)
    # Start from the ascending list 10, 13, 16, ..., 49.
    values = list(range(10, 50, 3))
    # Perturb it with four random pairwise swaps.
    for _ in range(4):
        a, b = random.sample(range(len(values)), 2)
        values[a], values[b] = values[b], values[a]
    print("原始数组:", values)
    compare_shell_sorts(values)
python
import time
import random
import matplotlib.pyplot as plt
from typing import List, Callable
# ========== 1. 插入排序 ==========
def insert_sort(arr: List[int]) -> None:
    """Sort ``arr`` in place into ascending order with insertion sort.

    Equal elements keep their relative order.
    """
    for end in range(1, len(arr)):
        value = arr[end]
        hole = end
        # Shift larger elements right until the slot for ``value`` opens up.
        while hole > 0 and arr[hole - 1] > value:
            arr[hole] = arr[hole - 1]
            hole -= 1
        arr[hole] = value
# ========== 2. 原始 Stride 版本 ==========
def shell_sort_stride(arr: List[int], stride: int = 2) -> None:
    """Sort ``arr`` in place with Shell sort, dividing the gap by ``stride`` each pass.

    The gap starts at ``len(arr) // stride`` and a final gap-1 pass is always
    run. Without it, integer division can jump from a gap above 1 straight to
    0 (6 elements with stride 3 run only gap 2) and leave the list unsorted.

    Args:
        arr: List to sort in ascending order.
        stride: Gap divisor, at least 2. The default of 2 gives Shell's
            original n/2, n/4, ... sequence.

    Raises:
        ValueError: If ``stride`` is below 2 and there is something to sort
            (the gap could never shrink).
    """
    n = len(arr)
    if n <= 1:
        return
    if stride < 2:
        raise ValueError("stride must be at least 2")

    # A stride larger than the list gives 0 here; fall back to gap 1.
    gap = max(n // stride, 1)
    while True:
        for i in range(gap, n):
            temp = arr[i]
            j = i - gap
            while j >= 0 and arr[j] > temp:
                arr[j + gap] = arr[j]
                j -= gap
            arr[j + gap] = temp
        if gap == 1:
            break
        gap = max(gap // stride, 1)
# ========== 3. Knuth 序列版本 ==========
def shell_sort_knuth(arr: List[int]) -> None:
    """Sort ``arr`` in place with Shell sort using Knuth's gaps (1, 4, 13, 40, ...).

    Lists of at most 32 elements are handed to ``insert_sort`` instead.
    """
    size = len(arr)
    if size <= 32:
        insert_sort(arr)
        return

    # Grow h to the first Knuth gap that is not below size // 3.
    third = size // 3
    h = 1
    while h < third:
        h = 3 * h + 1

    # Floor division by 3 retraces the sequence; 1 // 3 == 0 ends the loop.
    while h:
        for right in range(h, size):
            moving = arr[right]
            left = right - h
            while left >= 0 and arr[left] > moving:
                arr[left + h] = arr[left]
                left -= h
            arr[left + h] = moving
        h //= 3
# ========== 4. Ciura 序列版本 ==========
def shell_sort_ciura(arr: List[int]) -> None:
    """Sort ``arr`` in place with Shell sort using Ciura's gap sequence.

    The base gaps 1, 4, 10, 23, 57, 132, 301, 701 are extended upward by
    ``int(2.25 * previous)`` while the next gap is still shorter than the
    list. The whole sequence is then applied from the largest gap down to 1,
    so the extended gaps do their coarse ordering before the finer passes.
    Running them after gap 1 would only walk an already-sorted list.

    Lists of at most 32 elements are handed to ``insert_sort`` instead.
    """
    n = len(arr)
    if n <= 32:
        insert_sort(arr)
        return

    gaps = [1, 4, 10, 23, 57, 132, 301, 701]
    while True:
        next_gap = int(2.25 * gaps[-1])
        if next_gap >= n:
            break
        gaps.append(next_gap)

    for gap in reversed(gaps):
        if gap >= n:
            continue  # nothing to compare at this distance
        for i in range(gap, n):
            temp = arr[i]
            j = i - gap
            while j >= 0 and arr[j] > temp:
                arr[j + gap] = arr[j]
                j -= gap
            arr[j + gap] = temp
# ========== 辅助函数 ==========
def is_sorted(arr: List[int]) -> bool:
    """Return True when ``arr`` is in non-decreasing order.

    Lists with fewer than two elements count as sorted.
    """
    for left, right in zip(arr, arr[1:]):
        if left > right:
            return False
    return True
def time_sort(sort_func: Callable, data: List[int]) -> float:
    """Time one call of ``sort_func`` on a copy of ``data``.

    Args:
        sort_func: Callable that sorts a list in place.
        data: Input values; left unmodified because the sort runs on a copy.

    Returns:
        Elapsed time in milliseconds.

    Raises:
        RuntimeError: If the copy is not in non-decreasing order afterwards.
    """
    working = data.copy()

    started = time.perf_counter_ns()
    sort_func(working)
    elapsed_ns = time.perf_counter_ns() - started

    # The check runs after the clock has stopped, so it is not timed.
    if not is_sorted(working):
        raise RuntimeError("Sorting failed!")
    return elapsed_ns / 1_000_000
# ========== 主测试与绘图 ==========
def main():
    """Benchmark the three Shell sort variants over a range of sizes and plot the timings.

    For every size a fresh random list is generated and each variant is timed
    on its own copy via ``time_sort``. The resulting curves are saved to
    ``shell_sort_comparison.png`` and shown on screen.
    """
    # Sizes 500, 1000, ..., 100000.
    # NOTE(review): the earlier comment said the range ends at 10000, but the
    # code has always gone up to 100000 - confirm which is intended.
    sizes = list(range(500, 100001, 500))

    def stride_two(values):
        shell_sort_stride(values, stride=2)

    algorithms = {
        "Shell (stride=2)": stride_two,
        "Shell (Knuth)": shell_sort_knuth,
        "Shell (Ciura)": shell_sort_ciura,
    }
    results = {label: [] for label in algorithms}

    for size in sizes:
        print(f"Testing size: {size}")
        # Random values in [0, size]; every algorithm sees the same list.
        sample = [random.randint(0, size) for _ in range(size)]
        for label, sorter in algorithms.items():
            results[label].append(time_sort(sorter, sample))

    # One curve per algorithm.
    plt.figure(figsize=(10, 6))
    for label, timings in results.items():
        plt.plot(sizes, timings, label=label, marker='o')
    plt.title("Shell Sort Variants Performance Comparison")
    plt.xlabel("Array Size")
    plt.ylabel("Time (ms)")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("shell_sort_comparison.png", dpi=150)
    plt.show()


if __name__ == "__main__":
    main()