

最近在做一个在线PPT编辑器,其中状态管理用到了Zustand,撤销重做功能用的是zundo,好奇是如何实现的。于是看看源码,写篇文章总结一下。
"撤销/重做"(Undo/Redo)功能是提升用户体验不可或缺的一环。对于使用Zustand这款轻量级状态管理库的开发者来说,zundo 提供了一个极其优雅和简单的解决方案。
zundo 是一个专门为 Zustand 设计的中间件,它几乎以"零配置"的方式为Zustand的状态管理添加了强大的历史记录功能。
如何在项目中使用 Zundo
1. 安装
Bash
npm install zundo
# 或者
yarn add zundo
2. 封装Store
只需要将 create 函数(Zustand store 的定义)用 zundo 中间件包裹起来即可
假设有一个简单的计数器 store:
TypeScript
import create from 'zustand';
// 你的 store 定义
const createStore = (set) => ({
count: 0,
increment: () => set((state) => ({ count: state.count + 1 })),
decrement: () => set((state) => ({ count: state.count - 1 })),
});
要为其添加 undo/redo 功能,只需这样做:
TypeScript
import { create } from 'zustand';
import { temporal } from 'zundo';
// 你的 store 定义
const createStore = (set, get) => ({
count: 0,
increment: () => set((state) => ({ count: state.count + 1 })),
decrement: () => set((state) => ({ count: state.count - 1 })),
// ... 其他 actions
});
// 使用temporal封装
export const useStore = create(temporal(createStore));
3. 在组件中使用
zundo 会自动向你的 store 中注入 undo、redo、clear 等 action,以及 pastStates 和 futureStates 两个状态数组
JavaScript
function App() {
// 从 store 中获取状态和 actions
const { count, increment, decrement } = useStore();
// 从 zundo 获取 undo/redo
const { undo, redo, clear, pastStates, futureStates } = useStore.temporal.getState();
return (
<div>
<h1>Count: {count}</h1>
<button onClick={increment}>+</button>
<button onClick={decrement}>-</button>
<hr />
<button onClick={undo} disabled={pastStates.length === 0}>
Undo (撤销)
</button>
<button onClick={redo} disabled={futureStates.length === 0}>
Redo (重做)
</button>
<button onClick={clear}>Clear History (清除历史)</button>
</div>
);
}
实现原理
- 在
create时通过一个 mutator 把独立的时间旅行子 store 注入到主 store 的store.temporal上。 - 拦截所有状态更新(外部
store.setState和config内部的set),在更新完成后决定是否把这次变更记录到历史。 - 历史以两个栈维护:
pastStates(用于 undo)与futureStates(用于 redo),并提供undo/redo/clear/pause/resume等 API
index.ts
ts
import { createStore } from 'zustand';
import { temporalStateCreator } from './temporal';
import type {
StateCreator,
StoreMutatorIdentifier,
Mutate,
StoreApi,
} from 'zustand';
import type {
TemporalState,
_TemporalState,
Write,
ZundoOptions,
} from './types';
type Zundo = <
TState,
Mps extends [StoreMutatorIdentifier, unknown][] = [],
Mcs extends [StoreMutatorIdentifier, unknown][] = [],
UState = TState,
>(
config: StateCreator<TState, [...Mps, ['temporal', unknown]], Mcs>,
options?: ZundoOptions<TState, UState>,
) => StateCreator<
TState,
Mps,
[['temporal', StoreApi<TemporalState<UState>>], ...Mcs]
>;
declare module 'zustand/vanilla' {
interface StoreMutators<S, A> {
temporal: Write<S, { temporal: A }>;
}
}
export const temporal = (<TState>(
config: StateCreator<TState, [], []>,
options?: ZundoOptions<TState>,
): StateCreator<TState, [], []> => {
// 增强用户传入的 config:通过注入 temporal 子 store,为外层 store 增加"时间旅行"能力
const configWithTemporal = (
set: StoreApi<TState>['setState'],
get: StoreApi<TState>['getState'],
store: Mutate<
StoreApi<TState>,
[['temporal', StoreApi<TemporalState<TState>>]]
>,
) => {
// 创建一个独立的 temporal 子 store,用来维护历史/未来队列;支持 wrapTemporal 进行自定义包装
store.temporal = createStore(
options?.wrapTemporal?.(temporalStateCreator(set, get, options)) ||
temporalStateCreator(set, get, options),
);
// 取出内部的 _handleSet;若用户提供 handleSet,则对其进行柯里化以在外层 set 完成后统一写入时间线
const curriedHandleSet =
options?.handleSet?.(
(store.temporal.getState() as _TemporalState<TState>)
._handleSet as StoreApi<TState>['setState'],
) || (store.temporal.getState() as _TemporalState<TState>)._handleSet;
// 在每次状态变更后调用:基于 partialize 获取最小化状态,使用 diff/equality 判断是否需要记录
const temporalHandleSet = (pastState: TState) => {
if (!store.temporal.getState().isTracking) return;
const currentState = options?.partialize?.(get()) || get();
const deltaState = options?.diff?.(pastState, currentState);
if (
// 当 diff 返回 null 或 equality 判断两者相等时,认为无实际变更,跳过记录
!(
(
deltaState === null ||
options?.equality?.(pastState, currentState)
)
)
) {
// 将一次变更写入时间线:pastState -> currentState;deltaState 用于缩减历史体积
curriedHandleSet(
pastState,
undefined as unknown as Parameters<typeof set>[1],
currentState,
deltaState,
);
}
};
const setState = store.setState;
// 代理原始 setState:先读取变更前的 pastState,再执行真实 setState,最后记录到时间线
store.setState = (...args) => {
// 先 get 再 set:保证拿到变更前的最新快照(callback 可能返回部分状态)
const pastState = options?.partialize?.(get()) || get();
setState(...(args as Parameters<typeof setState>));
temporalHandleSet(pastState);
};
return config(
// 同样代理用户传入的 set:确保每次变更都进入时间线(先取快照,再 set,最后记录)
(...args) => {
// 先 get 再 set:保证拿到变更前的最新快照
const pastState = options?.partialize?.(get()) || get();
set(...(args as Parameters<typeof set>));
temporalHandleSet(pastState);
},
get,
store,
);
};
// 返回增强后的配置函数,使该 store 具备 temporal 能力(类型层面暴露 mutator)
return configWithTemporal as StateCreator<TState, [], []>;
}) as unknown as Zundo;
export type { ZundoOptions, Zundo, TemporalState };
temporal.ts
ts
import type { StateCreator, StoreApi } from 'zustand';
import type { _TemporalState, ZundoOptions } from './types';
export const temporalStateCreator = <TState>(
userSet: StoreApi<TState>['setState'],
userGet: StoreApi<TState>['getState'],
options?: ZundoOptions<TState>,
) => {
// 构造 temporal 子 store 的状态机:维护 pastStates/futureStates,提供 undo/redo 等能力
const stateCreator: StateCreator<_TemporalState<TState>, [], []> = (
set,
get,
) => {
return {
// 初始化历史与未来队列(可通过 options 预设)
pastStates: options?.pastStates || [],
futureStates: options?.futureStates || [],
// 撤销 steps 步:从过去队列尾部取出状态并应用到用户 store
undo: (steps = 1) => {
if (get().pastStates.length) {
// 必须先拿到变更前的快照:userGet 在 userSet 之前调用
const currentState = options?.partialize?.(userGet()) || userGet();
const statesToApply = get().pastStates.splice(-steps, steps);
// 应用最新的一条历史状态,并将当前状态与剩余历史(反转以维持时间顺序)推入未来队列
const nextState = statesToApply.shift()!;
userSet(nextState);
set({
pastStates: get().pastStates,
futureStates: get().futureStates.concat(
// 若提供 diff,则使用 current->next 的 delta 作为存储单位;否则存完整 currentState
options?.diff?.(currentState, nextState) || currentState,
statesToApply.reverse(),
),
});
}
},
// 重做 steps 步:从未来队列尾部取出状态并应用到用户 store
redo: (steps = 1) => {
if (get().futureStates.length) {
// 必须先拿到变更前的快照:userGet 在 userSet 之前调用
const currentState = options?.partialize?.(userGet()) || userGet();
const statesToApply = get().futureStates.splice(-steps, steps);
// 应用最新的一条未来状态,并将当前状态与剩余未来(反转以维持时间顺序)推入过去队列
const nextState = statesToApply.shift()!;
userSet(nextState);
set({
pastStates: get().pastStates.concat(
// 若提供 diff,则使用 current->next 的 delta;否则存完整 currentState
options?.diff?.(currentState, nextState) || currentState,
statesToApply.reverse(),
),
futureStates: get().futureStates,
});
}
},
// 清空时间线
clear: () => set({ pastStates: [], futureStates: [] }),
// 记录开关:暂停/恢复时间线写入
isTracking: true,
pause: () => set({ isTracking: false }),
resume: () => set({ isTracking: true }),
// 动态设置保存回调(如持久化)
setOnSave: (_onSave) => set({ _onSave }),
// 内部属性与写入逻辑
_onSave: options?.onSave,
_handleSet: (pastState, replace, currentState, deltaState) => {
// 容量限制:若超过 limit,则丢弃最早的过去状态
if (options?.limit && get().pastStates.length >= options?.limit) {
get().pastStates.shift();
}
// 触发保存回调(可用于外部持久化或统计)
get()._onSave?.(pastState, currentState);
set({
// 将 delta(若存在)或完整 pastState 记录到过去队列,同时清空未来队列
pastStates: get().pastStates.concat(deltaState || pastState),
futureStates: [],
});
},
};
};
// 对外以普通 StateCreator 形式暴露,隐藏内部 temporal 扩展的具体细节
return stateCreator as StateCreator<_TemporalState<TState>, [], []>;
};
历史记录是快照还是操作?
- 默认:存的是"快照",即在没有提供
diff时,记录项为"过去的跟踪状态子集 - 可选:如果提供了
diff,则存的是"最小补丁"