给定一个二叉搜索树 root (BST),请将它的每个节点的值替换成树中大于或者等于该节点值的所有节点值之和。
提醒一下, 二叉搜索树 满足下列约束条件:
节点的左子树仅包含键 小于 节点键的节点。
节点的右子树仅包含键 大于 节点键的节点。
左右子树也必须是二叉搜索树。
示例 1:
输入:[4,1,6,0,2,5,7,null,null,null,3,null,null,null,8]
输出:[30,36,21,36,35,26,15,null,null,null,33,null,null,null,8]
示例 2:
输入:root = [0,null,1]
输出:[1,null,1]
提示:
树中的节点数在 [1, 100] 范围内。
0 <= Node.val <= 100
树中的所有值均 不重复 。
法一:反向中序遍历即可,用一个变量记录累加和,以下是递归遍历:
cpp
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode() : val(0), left(nullptr), right(nullptr) {}
* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
* TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
* };
*/
class Solution {
public:
TreeNode* bstToGst(TreeNode* root) {
int sum = 0;
reverseInorderTraversal(root, sum);
return root;
}
private:
void reverseInorderTraversal(TreeNode *node, int &sum)
{
if (node == nullptr)
{
return;
}
reverseInorderTraversal(node->right, sum);
sum += node->val;
node->val = sum;
reverseInorderTraversal(node->left, sum);
}
};
如果树中有n个节点,此算法时间复杂度为O(n),空间复杂度为树的高度O(logn)(平均情况,最差情况为O(n))。
法二:法一的迭代版:
cpp
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode() : val(0), left(nullptr), right(nullptr) {}
* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
* TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
* };
*/
class Solution {
public:
TreeNode* bstToGst(TreeNode* root) {
stack<TreeNode *> s;
int sum = 0;
TreeNode *cur = root;
while (cur || !s.empty())
{
while (cur)
{
s.push(cur);
cur = cur->right;
}
cur = s.top();
s.pop();
sum += cur->val;
cur->val = sum;
cur = cur->left;
}
return root;
}
};
如果树中有n个节点,此算法时间复杂度为O(n),空间复杂度为树的高度O(logn)(平均情况,最差情况为O(n))。
法三:Morris遍历,要点在于找到当前节点的前驱节点,对于反向中序遍历,某节点的前驱节点是右子树的最左节点:
cpp
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode() : val(0), left(nullptr), right(nullptr) {}
* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
* TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
* };
*/
class Solution {
public:
TreeNode* bstToGst(TreeNode* root) {
int sum = 0;
TreeNode *cur = root;
while (cur)
{
if (!cur->right)
{
sum += cur->val;
cur->val = sum;
cur = cur->left;
continue;
}
TreeNode *leftestOfRight = getLeftestOfRight(cur);
if (leftestOfRight->left)
{
sum += cur->val;
cur->val = sum;
leftestOfRight->left = nullptr;
cur = cur->left;
}
else
{
leftestOfRight->left = cur;
cur = cur->right;
}
}
return root;
}
private:
TreeNode *getLeftestOfRight(TreeNode *node)
{
TreeNode *cur = node->right;
while (cur->left && cur->left != node)
{
cur = cur->left;
}
return cur;
}
};
如果树中有n个节点,此算法时间复杂度为O(n),空间复杂度为O(1)。
法四:由于输入的是二叉搜索树,因此我们可以每次找到比当前值小的前一个节点,依次更新:
cpp
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode() : val(0), left(nullptr), right(nullptr) {}
* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
* TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
* };
*/
class Solution {
public:
TreeNode* bstToGst(TreeNode* root) {
TreeNode *curTarget = getMax(root);
int sum = 0;
int lastNum = curTarget->val;
while (curTarget)
{
lastNum = curTarget->val;
sum += curTarget->val;
curTarget->val = sum;
curTarget = getLess(root, lastNum);
}
return root;
}
private:
TreeNode *getMax(TreeNode *node)
{
while (node->right)
{
node = node->right;
}
return node;
}
TreeNode *getLess(TreeNode *root, int target)
{
TreeNode *lessTarget = nullptr;
while (root)
{
if (root->val > target)
{
root = root->left;
}
else if (root->val < target)
{
if (!lessTarget || root->val > lessTarget->val)
{
lessTarget = root;
}
root = root->right;
}
else
{
break;
}
}
return lessTarget;
}
};
如果树中有n个节点,此算法时间复杂度为O(nlgn),空间复杂度为O(1)。