Description
Given the root of a binary tree, return the number of nodes where the value of the node is equal to the average of the values in its subtree.
Note:
The average of n elements is the sum of the n elements divided by n and rounded down to the nearest integer.
A subtree of root is a tree consisting of root and all of its descendants.
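As a quick illustration of the rounding rule in the note above, here is a minimal sketch. The function name floor_average is hypothetical and not part of the problem statement. It relies on the constraint that values are non-negative, so integer division matches "rounded down".

```python
def floor_average(values):
    # Hypothetical helper: sum divided by count, rounded down.
    # Integer division (//) floors the result, which matches the
    # problem's definition for the non-negative values it allows.
    return sum(values) // len(values)


print(floor_average([5, 6]))              # 11 // 2 = 5
print(floor_average([4, 8, 5, 0, 1, 6]))  # 24 // 6 = 4
```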
Example 1:
Input: root = [4,8,5,0,1,null,6]
Output: 5
Explanation:
For the node with value 4: The average of its subtree is (4 + 8 + 5 + 0 + 1 + 6) / 6 = 24 / 6 = 4.
For the node with value 5: The average of its subtree is (5 + 6) / 2 = 11 / 2 = 5.
For the node with value 0: The average of its subtree is 0 / 1 = 0.
For the node with value 1: The average of its subtree is 1 / 1 = 1.
For the node with value 6: The average of its subtree is 6 / 1 = 6.
Example 2:
Input: root = [1]
Output: 1
Explanation: For the node with value 1: The average of its subtree is 1 / 1 = 1.
Constraints:
The number of nodes in the tree is in the range [1, 1000].
0 <= Node.val <= 1000
Solution
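Traverse the tree in post-order. For each node, keep track of the sum of the values and the number of nodes in its subtree (the node itself plus all of its descendants). Both are known once the node's children have been processed, so a node is counted when its value equals that sum divided by that count, rounded down.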
Time complexity: O(n)
Space complexity: O(n)
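The post-order idea can also be sketched recursively. This is an illustrative variant rather than the solution given in this write-up: the names count_average_nodes and build_tree are hypothetical, and the TreeNode class is a minimal stand-in for the LeetCode definition.

```python
from collections import deque
from typing import List, Optional, Tuple


class TreeNode:
    """Minimal stand-in for the LeetCode node definition."""
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_average_nodes(root: Optional[TreeNode]) -> int:
    """Count nodes whose value equals the floored average of their subtree."""
    count = 0

    def visit(node: Optional[TreeNode]) -> Tuple[int, int]:
        # Returns (subtree sum, subtree size); an empty subtree is (0, 0).
        nonlocal count
        if node is None:
            return 0, 0
        left_sum, left_size = visit(node.left)
        right_sum, right_size = visit(node.right)
        total = node.val + left_sum + right_sum
        size = 1 + left_size + right_size
        if node.val == total // size:
            count += 1
        return total, size

    visit(root)
    return count


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a LeetCode-style level-order list."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        if values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


print(count_average_nodes(build_tree([4, 8, 5, 0, 1, None, 6])))  # 5 (Example 1)
print(count_average_nodes(build_tree([1])))                       # 1 (Example 2)
```

With up to 1000 nodes, a degenerate chain-shaped tree could approach CPython's default recursion limit, which is one reason to prefer an explicit stack.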
Code
python3
from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def averageOfSubtree(self, root: Optional[TreeNode]) -> int:
        # Stack entries are (node, state):
        # state 0 = first visit, state 1 = both children already processed.
        stack = [(root, 0)]
        res = 0
        # Maps each processed node to (subtree sum, subtree node count).
        node_info = {}
        while stack:
            node, stat = stack.pop()
            if stat == 0:
                # Revisit this node after its children (post-order).
                stack.append((node, 1))
                if node.left:
                    stack.append((node.left, 0))
                if node.right:
                    stack.append((node.right, 0))
            else:
                # Combine the node's own value with its children's results.
                total, count = node.val, 1
                for child in (node.left, node.right):
                    if child:
                        child_sum, child_count = node_info[child]
                        total += child_sum
                        count += child_count
                node_info[node] = (total, count)
                # Integer division rounds the average down.
                if node.val == total // count:
                    res += 1
        return res