题目描述
这是 LeetCode 上的 96. 不同的二叉搜索树 ,难度为 中等。
Tag : 「树」、「二叉搜索树」、「动态规划」、「区间 DP」、「数学」、「卡特兰数」
给你一个整数 n
,求恰由 n
个节点组成且节点值从 1
到 n
互不相同的 二叉搜索树 有多少种?
返回满足题意的二叉搜索树的种数。
示例 1:
ini
输入:n = 3
输出:5
示例 2:
ini
输入:n = 1
输出:1
提示:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 < = n < = 19 1 <= n <= 19 </math>1<=n<=19
区间 DP
沿用 95. 不同的二叉搜索树 II 的基本思路,只不过本题不是求具体方案,而是求个数。
除了能用 95. 不同的二叉搜索树 II 提到的「卡特兰数」直接求解 <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n 个节点的以外,本题还能通过常规「区间 DP」的方式进行求解。
求数量使用 DP,求所有具体方案使用爆搜,是极其常见的一题多问搭配。
定义 <math xmlns="http://www.w3.org/1998/Math/MathML"> f [ l ] [ r ] f[l][r] </math>f[l][r] 为使用数值范围在 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ l , r ] [l, r] </math>[l,r] 之间的节点,所能构建的 BST
个数。
不失一般性考虑 <math xmlns="http://www.w3.org/1998/Math/MathML"> f [ l ] [ r ] f[l][r] </math>f[l][r] 该如何求解,仍用 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ l , r ] [l, r] </math>[l,r] 中的根节点 i
为何值,作为切入点进行思考。
根据「BST
定义」及「乘法原理」可知: <math xmlns="http://www.w3.org/1998/Math/MathML"> [ l , i − 1 ] [l, i - 1] </math>[l,i−1] 相关节点构成的 BST
子树只能在 i
的左边,而 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ i + 1 , r ] [i + 1, r] </math>[i+1,r] 相关节点构成的 BST
子树只能在 i
的右边。所有的左右子树相互独立,因此以 i
为根节点的 BST
数量为 <math xmlns="http://www.w3.org/1998/Math/MathML"> f [ l ] [ i − 1 ] × f [ i + 1 ] [ r ] f[l][i - 1] \times f[i + 1][r] </math>f[l][i−1]×f[i+1][r],而 i
共有 <math xmlns="http://www.w3.org/1998/Math/MathML"> r − l + 1 r - l + 1 </math>r−l+1 个取值( <math xmlns="http://www.w3.org/1998/Math/MathML"> i ∈ [ l , r ] i \in [l, r] </math>i∈[l,r])。
即有:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> f [ l ] [ r ] = ∑ i = l r ( f [ l ] [ i − 1 ] × f [ i + 1 ] [ r ] ) f[l][r] = \sum_{i = l}^{r} (f[l][i - 1] \times f[i + 1][r]) </math>f[l][r]=i=l∑r(f[l][i−1]×f[i+1][r])
不难发现,求解区间 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ l , r ] [l, r] </math>[l,r] 的 BST
数量 <math xmlns="http://www.w3.org/1998/Math/MathML"> f [ l ] [ r ] f[l][r] </math>f[l][r] 依赖于比其小的区间 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ l , i − 1 ] [l, i - 1] </math>[l,i−1] 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ i + 1 , r ] [i + 1, r] </math>[i+1,r],这引导我们使用「区间 DP」的方式进行递推。
不了解区间 DP 的同学,可看 前置 🧀
一些细节:由于我们 i
的取值可能会取到区间中的最值 l
和 r
,为了能够该情况下, <math xmlns="http://www.w3.org/1998/Math/MathML"> f [ l ] [ i − 1 ] f[l][i - 1] </math>f[l][i−1] 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> f [ i + 1 ] [ r ] f[i + 1][r] </math>f[i+1][r] 能够顺利参与转移,起始我们需要先对所有满足 <math xmlns="http://www.w3.org/1998/Math/MathML"> i > = j i >= j </math>i>=j 的 <math xmlns="http://www.w3.org/1998/Math/MathML"> f [ i ] [ j ] f[i][j] </math>f[i][j] 初始化为 1
。
Java 代码:
Java
class Solution {
public int numTrees(int n) {
int[][] f = new int[n + 10][n + 10];
for (int i = 0; i <= n + 1; i++) {
for (int j = 0; j <= n + 1; j++) {
if (i >= j) f[i][j] = 1;
}
}
for (int len = 2; len <= n; len++) {
for (int l = 1; l + len - 1 <= n; l++) {
int r = l + len - 1;
for (int i = l; i <= r; i++) {
f[l][r] += f[l][i - 1] * f[i + 1][r];
}
}
}
return f[1][n];
}
}
C++ 代码:
C++
class Solution {
public:
int numTrees(int n) {
vector<vector<int>> f(n + 2, vector<int>(n + 2, 0));
for (int i = 0; i <= n + 1; i++) {
for (int j = 0; j <= n + 1; j++) {
if (i >= j) f[i][j] = 1;
}
}
for (int len = 2; len <= n; len++) {
for (int l = 1; l + len - 1 <= n; l++) {
int r = l + len - 1;
for (int i = l; i <= r; i++) {
f[l][r] += f[l][i - 1] * f[i + 1][r];
}
}
}
return f[1][n];
}
};
Python 代码:
Python
class Solution(object):
def numTrees(self, n):
f = [[0] * (n + 2) for _ in range(n + 2)]
for i in range(n + 2):
for j in range(n + 2):
if i >= j:
f[i][j] = 1
for length in range(2, n + 1):
for l in range(1, n - length + 2):
r = l + length - 1
for i in range(l, r + 1):
f[l][r] += f[l][i - 1] * f[i + 1][r]
return f[1][n]
TypeScript 代码:
TypeScript
function numTrees(n: number): number {
const f = new Array(n + 2).fill(0).map(() => new Array(n + 2).fill(0));
for (let i = 0; i <= n + 1; i++) {
for (let j = 0; j <= n + 1; j++) {
if (i >= j) f[i][j] = 1;
}
}
for (let len = 2; len <= n; len++) {
for (let l = 1; l + len - 1 <= n; l++) {
const r = l + len - 1;
for (let i = l; i <= r; i++) {
f[l][r] += f[l][i - 1] * f[i + 1][r];
}
}
}
return f[1][n];
};
- 时间复杂度: <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( n 3 ) O(n^3) </math>O(n3)
- 空间复杂度: <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( n 2 ) O(n^2) </math>O(n2)
区间 DP(优化)
求解完使用 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 1 , n ] [1, n] </math>[1,n] 共 <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n 个连续数所能构成的 BST
个数后,再来思考一个问题:使用 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ L , R ] [L, R] </math>[L,R] 共 <math xmlns="http://www.w3.org/1998/Math/MathML"> n = R − L + 1 n = R - L + 1 </math>n=R−L+1 个连续数,所能构成的 BST
个数又是多少。
答案是一样的。
由 <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n 个连续数构成的 BST
个数仅与数值个数有关系,与数值大小本身并无关系。
由于可知,我们上述的「区间 DP」必然进行了大量重复计算,例如 <math xmlns="http://www.w3.org/1998/Math/MathML"> f [ 1 ] [ 3 ] f[1][3] </math>f[1][3] 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> f [ 2 ] [ 4 ] f[2][4] </math>f[2][4] 同为大小为 <math xmlns="http://www.w3.org/1998/Math/MathML"> 3 3 </math>3 的区间,却被计算了两次。
调整我们的状态定义:定义 <math xmlns="http://www.w3.org/1998/Math/MathML"> f [ k ] f[k] </math>f[k] 为考虑连续数个数为 <math xmlns="http://www.w3.org/1998/Math/MathML"> k k </math>k 时,所能构成的 BST
的个数。
不失一般性考虑 <math xmlns="http://www.w3.org/1998/Math/MathML"> f [ i ] f[i] </math>f[i] 如何计算,仍用 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 1 , i ] [1, i] </math>[1,i] 中哪个数值作为根节点进行考虑。假设使用数值 <math xmlns="http://www.w3.org/1998/Math/MathML"> j j </math>j 作为根节点,则有 <math xmlns="http://www.w3.org/1998/Math/MathML"> f [ j − 1 ] × f [ i − j ] f[j - 1] \times f[i - j] </math>f[j−1]×f[i−j] 个 BST
可贡献到 <math xmlns="http://www.w3.org/1998/Math/MathML"> f [ i ] f[i] </math>f[i],而 <math xmlns="http://www.w3.org/1998/Math/MathML"> j j </math>j 共有 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 个取值( <math xmlns="http://www.w3.org/1998/Math/MathML"> j ∈ [ 1 , i ] j \in [1, i] </math>j∈[1,i])。
即有:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> f [ i ] = ∑ j = 1 i ( f [ j − 1 ] × f [ i − j ] ) f[i] = \sum_{j = 1}^{i}(f[j - 1] \times f[i - j]) </math>f[i]=j=1∑i(f[j−1]×f[i−j])
同时有初始化 <math xmlns="http://www.w3.org/1998/Math/MathML"> f [ 0 ] = 1 f[0] = 1 </math>f[0]=1,含义为没有任何连续数时,只有"空树"一种合法方案。
Java 代码:
Java
class Solution {
public int numTrees(int n) {
int[] f = new int[n + 10];
f[0] = 1;
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= i; j++) {
f[i] += f[j - 1] * f[i - j];
}
}
return f[n];
}
}
C++ 代码:
C++
class Solution {
public:
int numTrees(int n) {
vector<int> f(n + 10, 0);
f[0] = 1;
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= i; j++) {
f[i] += f[j - 1] * f[i - j];
}
}
return f[n];
}
};
Python 代码:
Python
class Solution:
def numTrees(self, n: int) -> int:
f = [0] * (n + 10)
f[0] = 1
for i in range(1, n + 1):
for j in range(1, i + 1):
f[i] += f[j - 1] * f[i - j]
return f[n]
TypeScript 代码:
TypeScript
function numTrees(n: number): number {
const f = new Array(n + 10).fill(0);
f[0] = 1;
for (let i = 1; i <= n; i++) {
for (let j = 1; j <= i; j++) {
f[i] += f[j - 1] * f[i - j];
}
}
return f[n];
};
- 时间复杂度: <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( n 2 ) O(n^2) </math>O(n2)
- 空间复杂度: <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( n ) O(n) </math>O(n)
数学 - 卡特兰数
在「区间 DP(优化)」中的递推过程,正是"卡特兰数"的 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( n 2 ) O(n^2) </math>O(n2) 递推过程。通过对常规「区间 DP」的优化,我们得证 95. 不同的二叉搜索树 II 中「给定 <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n 个节点所能构成的 BST
的个数为卡特兰数」这一结论。
对于精确求卡特兰数,存在时间复杂度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( n ) O(n) </math>O(n) 的通项公式做法,公式为 <math xmlns="http://www.w3.org/1998/Math/MathML"> C n + 1 = C n ⋅ ( 4 n + 2 ) n + 2 C_{n+1} = \frac{C_n \cdot (4n + 2)}{n + 2} </math>Cn+1=n+2Cn⋅(4n+2)。
Java 代码:
Java
class Solution {
public int numTrees(int n) {
if (n <= 1) return 1;
long ans = 1;
for (int i = 0; i < n; i++) ans = ans * (4 * i + 2) / (i + 2);
return (int)ans;
}
}
C++ 代码:
C++
class Solution {
public:
int numTrees(int n) {
if (n <= 1) return 1;
long long ans = 1;
for (int i = 0; i < n; i++) ans = ans * (4 * i + 2) / (i + 2);
return (int)ans;
}
};
Python 代码:
Python
class Solution:
def numTrees(self, n: int) -> int:
if n <= 1:
return 1
ans = 1
for i in range(n):
ans = ans * (4 * i + 2) // (i + 2)
return ans
TypeScript 代码:
TypeScript
function numTrees(n: number): number {
if (n <= 1) return 1;
let ans = 1;
for (let i = 0; i < n; i++) ans = ans * (4 * i + 2) / (i + 2);
return ans;
};
- 时间复杂度: <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( n ) O(n) </math>O(n)
- 空间复杂度: <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( 1 ) O(1) </math>O(1)
最后
这是我们「刷穿 LeetCode」系列文章的第 No.96
篇,系列开始于 2021/01/01,截止于起始日 LeetCode 上共有 1916 道题目,部分是有锁题,我们将先把所有不带锁的题目刷完。
在这个系列文章里面,除了讲解解题思路以外,还会尽可能给出最为简洁的代码。如果涉及通解还会相应的代码模板。
为了方便各位同学能够电脑上进行调试和提交代码,我建立了相关的仓库:github.com/SharingSour... 。
在仓库地址里,你可以看到系列文章的题解链接、系列文章的相应代码、LeetCode 原题链接和其他优选题解。
更多更全更热门的「笔试/面试」相关资料可访问排版精美的 合集新基地 🎉🎉