0222. Count Complete Tree Nodes

Given the root of a complete binary tree, return the number of nodes in the tree.

According to Wikipedia, in a complete binary tree every level, except possibly the last, is completely filled, and all nodes in the last level are as far left as possible. It can have between 1 and 2^h nodes inclusive at the last level h.
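Summing the level sizes shows what this definition implies for the total node count n. Counting the root as level 0 (the convention the quoted definition uses), levels 0 through h - 1 are full and level i holds 2^i nodes, while the last level holds some k nodes with 1 <= k <= 2^h:

```latex
\sum_{i=0}^{h-1} 2^i = 2^h - 1,
\qquad
n = (2^h - 1) + k, \quad 1 \le k \le 2^h
\;\Longrightarrow\;
2^h \le n \le 2^{h+1} - 1.
```

For Example 1 ([1,2,3,4,5,6]), h = 2 and k = 3, so n = 3 + 3 = 6, inside the range [4, 7]. The upper bound is the perfect-tree case.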

Example 1:


Input: root = [1,2,3,4,5,6]
Output: 6

Example 2:

Input: root = []
Output: 0

Example 3:

Input: root = [1]
Output: 1
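The examples use LeetCode's level-order notation. Because the tree is guaranteed to be complete, the value at index i of that list has its children at indices 2i + 1 and 2i + 2, so a test tree can be built without handling null markers. A minimal sketch under that assumption; `buildComplete` is a hypothetical helper name, and the nodes live in a caller-owned `std::vector` so nothing leaks.

```cpp
#include <cstddef>
#include <vector>

// Mirrors LeetCode's TreeNode definition.
struct TreeNode {
    int val;
    TreeNode *left;
    TreeNode *right;
    TreeNode() : val(0), left(nullptr), right(nullptr) {}
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
    TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
};

// Hypothetical helper: builds a complete tree from level-order values.
// Assumes `vals` contains no null markers (true for a complete tree).
// `storage` owns every node and must outlive the returned tree.
TreeNode* buildComplete(const std::vector<int>& vals, std::vector<TreeNode>& storage) {
    storage.clear();
    storage.reserve(vals.size());
    for (int v : vals) storage.emplace_back(v);
    // In level order, node i has its children at 2i + 1 and 2i + 2.
    for (std::size_t i = 0; i < storage.size(); ++i) {
        if (2 * i + 1 < storage.size()) storage[i].left = &storage[2 * i + 1];
        if (2 * i + 2 < storage.size()) storage[i].right = &storage[2 * i + 2];
    }
    return storage.empty() ? nullptr : &storage[0];
}
```

For example, `buildComplete({1, 2, 3, 4, 5, 6}, nodes)` reproduces the tree from Example 1.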

Constraints:

  • The number of nodes in the tree is in the range [0, 5 * 10^4].
  • 0 <= Node.val <= 5 * 10^4
  • The tree is guaranteed to be complete.

Follow up: Traversing the tree to count the number of nodes in the tree is an easy solution but with O(n) complexity. Could you find a faster algorithm?
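For reference, the O(n) baseline mentioned in the follow-up is a plain recursive traversal that visits every node once and does not use completeness at all. A minimal sketch assuming LeetCode's `TreeNode` shape; `countNodesNaive` is a hypothetical name.

```cpp
// Mirrors LeetCode's TreeNode definition.
struct TreeNode {
    int val;
    TreeNode *left;
    TreeNode *right;
    TreeNode() : val(0), left(nullptr), right(nullptr) {}
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
    TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
};

// O(n) baseline: visits every node exactly once and ignores completeness.
int countNodesNaive(const TreeNode* root) {
    if (!root) return 0;
    return 1 + countNodesNaive(root->left) + countNodesNaive(root->right);
}
```

Because it works on any binary tree, it is a convenient cross-check when testing a faster solution.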


A complete binary tree has a special property: on the last level, nodes are filled from left to right. The empty (null) slots may begin at any position on the last level, but once the first null slot appears, every slot to its right must be null as well.
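This property can be checked directly with a level-order (BFS) walk that also enqueues null children: in a complete tree, no real node may appear after the first null. A minimal sketch assuming LeetCode's `TreeNode` shape; `isComplete` is a hypothetical helper, not part of the problem's required API.

```cpp
#include <queue>

// Mirrors LeetCode's TreeNode definition.
struct TreeNode {
    int val;
    TreeNode *left;
    TreeNode *right;
    TreeNode() : val(0), left(nullptr), right(nullptr) {}
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
    TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
};

// Level-order walk that keeps null children in the queue. Once a null
// slot has been seen, any later non-null node violates completeness.
bool isComplete(TreeNode* root) {
    std::queue<TreeNode*> q;
    q.push(root);
    bool seenNull = false;
    while (!q.empty()) {
        TreeNode* node = q.front();
        q.pop();
        if (!node) {
            seenNull = true;
            continue;
        }
        if (seenNull) return false;  // a node sits to the right of a gap
        q.push(node->left);
        q.push(node->right);
    }
    return true;
}
```

An empty tree counts as complete here, which matches the problem allowing zero nodes.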

For a perfect binary tree, the total number of nodes is 2^{height} - 1, where height counts levels. For a complete tree, the problem therefore reduces to finding the position on the last level (the pivot) from which all remaining leaf slots are empty/null. A binary search over the tree structure can locate this pivot without visiting every node.
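The binary search mentioned above can also be made fully explicit, sketched here under stated assumptions: number the slots on the last level from 0 to 2^d - 1 (d is the depth of the leftmost path, root at depth 0), test whether a slot holds a node by steering left or right from the root, and binary-search for the first empty slot. This is a different formulation from the height-comparison recursion described in the steps and solution here; the helper names `leftDepth`, `slotExists`, and `countNodesBinarySearch` are hypothetical.

```cpp
// Mirrors LeetCode's TreeNode definition.
struct TreeNode {
    int val;
    TreeNode *left;
    TreeNode *right;
    TreeNode() : val(0), left(nullptr), right(nullptr) {}
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
    TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
};

// Depth (in edges) of the leftmost path; the root is at depth 0.
int leftDepth(TreeNode* node) {
    int d = 0;
    while (node->left) {
        node = node->left;
        ++d;
    }
    return d;
}

// Does last-level slot `idx` (0 .. 2^d - 1) hold a node? Each step halves
// the slot range and moves to the child whose half contains `idx`.
bool slotExists(int idx, int d, TreeNode* node) {
    int lo = 0, hi = (1 << d) - 1;
    for (int i = 0; i < d; ++i) {
        int mid = lo + (hi - lo) / 2;
        if (idx <= mid) {
            node = node->left;
            hi = mid;
        } else {
            node = node->right;
            lo = mid + 1;
        }
    }
    return node != nullptr;
}

int countNodesBinarySearch(TreeNode* root) {
    if (!root) return 0;
    int d = leftDepth(root);
    if (d == 0) return 1;
    // Slot 0 always exists (it ends the leftmost path), so search 1 .. 2^d - 1.
    int lo = 1, hi = (1 << d) - 1;
    while (lo <= hi) {
        int mid = lo + (hi - lo) / 2;
        if (slotExists(mid, d, root)) lo = mid + 1;
        else hi = mid - 1;
    }
    // `lo` is now the number of filled last-level slots;
    // the full levels above hold 2^d - 1 nodes.
    return (1 << d) - 1 + lo;
}
```

This variant also runs in O(\log^2 n): O(\log n) probes, each an O(\log n) walk from the root.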

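  1. Starting from the current root, walk down the leftmost path and the rightmost path, each O(\log n) levels long, to get the left height and the right height.
  2. If the left and right heights are equal, the current subtree is a perfect binary tree, so its node count follows directly from the formula 2^{height} - 1.
  3. If left = right + 1, the pivot lies somewhere on the last level of this subtree, so recursively count both children and add 1 for the current root. Only one of those two calls recurses further, because one child subtree is always perfect and returns right after its own height check.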

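  • Time: O(\log^2 n). Each call spends O(\log n) steps walking the leftmost and rightmost paths, and the recursion descends at most O(\log n) levels.

  • Space: O(\log n) for the recursion stack, since the recursion depth is bounded by the height of the tree.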
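The time bound can be checked with a short recurrence. Let T(h) be the work on a complete subtree with h levels and c a constant per path step. A call walks two paths of length at most h; if the heights differ, one child is perfect (its call stops after its own O(h) walk) and the other is complete with at most h - 1 levels:

```latex
T(h) \le T(h-1) + c\,h
\;\Longrightarrow\;
T(h) \le c \sum_{j=1}^{h} j = \frac{c\,h(h+1)}{2} = O(h^2),
\qquad
h = \lfloor \log_2 n \rfloor + 1 \;\Longrightarrow\; T = O(\log^2 n).
```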


```cpp
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
class Solution {
public:
    int countNodes(TreeNode* root) {
        if (!root) return 0;
        // Heights (in nodes) of the leftmost and rightmost paths.
        int l = 0, r = 0;
        TreeNode *a = root, *b = root;
        while (a) {
            ++l;
            a = a->left;
        }
        while (b) {
            ++r;
            b = b->right;
        }
        // Equal heights: this subtree is perfect and has 2^l - 1 nodes.
        // A bit shift avoids the floating-point result of pow().
        if (l == r) return (1 << l) - 1;
        // Otherwise recurse; one of the two children is always perfect.
        return countNodes(root->left) + countNodes(root->right) + 1;
    }
};
```

