Kth Smallest Element in a BST

Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

Note: You may assume k is always valid, 1 ≤ k ≤ BST's total elements.

Example 1:

Input:
 root = [3,1,4,null,2], k = 1
   3
  / \
 1   4
  \
   2

Output:
 1

Example 2:

Input:
 root = [5,3,6,2,4,null,null,1], k = 3
       5
      / \
     3   6
    / \
   2   4
  /
 1

Output:
 3

Follow up: What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?
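
One common way to approach the follow-up is to store, in every node, the number of nodes in its subtree. A query can then decide at each node whether the kth element lies to the left, at the node itself, or to the right, so it follows a single root-to-leaf path instead of traversing the tree. The sketch below is only an illustration under assumptions: the names SizedBst, Node, size and insert are hypothetical and not part of the original problem, duplicates are ignored, deletion is omitted, and the tree is not rebalanced, so a query costs O(h) for a tree of height h, which is O(log n) only if the tree happens to stay balanced.

```java
class SizedBst {
    // Hypothetical node type: a BST node that also tracks its subtree size.
    static class Node {
        int val;
        int size = 1; // number of nodes in the subtree rooted here
        Node left, right;
        Node(int val) { this.val = val; }
    }

    static int size(Node n) {
        return n == null ? 0 : n.size;
    }

    // Plain (unbalanced) BST insert that keeps the size fields up to date.
    static Node insert(Node root, int val) {
        if (root == null) return new Node(val);
        if (val < root.val) {
            root.left = insert(root.left, val);
        } else if (val > root.val) {
            root.right = insert(root.right, val);
        } // equal values are ignored in this sketch
        root.size = 1 + size(root.left) + size(root.right);
        return root;
    }

    // Walks one path from the root, using subtree sizes to pick a direction.
    static int kthSmallest(Node root, int k) {
        Node cur = root;
        while (cur != null) {
            int leftSize = size(cur.left);
            if (k <= leftSize) {
                cur = cur.left;            // answer is in the left subtree
            } else if (k == leftSize + 1) {
                return cur.val;            // current node is the kth smallest
            } else {
                k -= leftSize + 1;         // skip the left subtree and this node
                cur = cur.right;
            }
        }
        throw new IllegalArgumentException("k is out of range");
    }
}
```

Building the tree from Example 2 with repeated calls to SizedBst.insert and then calling SizedBst.kthSmallest(root, 3) gives 3, matching the expected output. Supporting deletes would need the same size bookkeeping on the way back up, and a self-balancing tree would be needed to guarantee logarithmic height.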

Analysis

An in-order traversal of a BST visits the nodes in ascending order. So traverse the tree in order and return the kth value visited: that value is the kth smallest element in the BST.

https://leetcode.com/problems/kth-smallest-element-in-a-bst/discuss/63660/3-ways-implemented-in-JAVA-(Python):-Binary-Search-in-order-iterative-and-recursive

https://leetcode.com/problems/kth-smallest-element-in-a-bst/discuss/63783/Two-Easiest-In-Order-Traverse-(Java)
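
The in-order idea described above can also be written without recursion, using an explicit stack. The following is a minimal sketch, not the approach taken from the linked discussions: the class name IterativeKthSmallest and its nested Node type are hypothetical stand-ins for LeetCode's Solution and TreeNode, chosen so the snippet compiles on its own.

```java
import java.util.ArrayDeque;
import java.util.Deque;

class IterativeKthSmallest {
    // Hypothetical minimal node type, mirroring LeetCode's TreeNode.
    static class Node {
        int val;
        Node left, right;
        Node(int val) { this.val = val; }
    }

    static int kthSmallest(Node root, int k) {
        Deque<Node> stack = new ArrayDeque<>();
        Node cur = root;
        while (cur != null || !stack.isEmpty()) {
            // Go as far left as possible, remembering the path.
            while (cur != null) {
                stack.push(cur);
                cur = cur.left;
            }
            // The top of the stack is the next node in ascending order.
            cur = stack.pop();
            if (--k == 0) {
                return cur.val; // stop as soon as the kth node is visited
            }
            cur = cur.right;
        }
        throw new IllegalArgumentException("k is larger than the number of nodes");
    }
}
```

Because the loop returns as soon as the kth node is popped, it visits at most h + k nodes for a tree of height h, and the stack never holds more than h nodes.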

Solution

DFS in-order traversal

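import java.util.ArrayList;
import java.util.List;

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    public int kthSmallest(TreeNode root, int k) {
        // Guard only: the problem guarantees 1 <= k <= number of nodes.
        if (root == null) return 0;
        // Holds the smallest values seen so far, in ascending order.
        List<Integer> topK = new ArrayList<>();
        helper(root, topK, k);
        return topK.get(k - 1);
    }

    // In-order traversal that stops collecting once k values are gathered.
    private void helper(TreeNode root, List<Integer> topK, int k) {
        if (root == null) return;
        helper(root.left, topK, k);
        // Already have k values: skip this node and its right subtree.
        if (topK.size() >= k) return;
        topK.add(root.val);
        helper(root.right, topK, k);
    }
}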
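
DFS in-order traversal with a counter (no extra list)

class Solution {
    int count = 0;  // number of nodes visited so far, in ascending order
    int result = 0; // value of the kth visited node

    public int kthSmallest(TreeNode root, int k) {
        // Reset state so the same instance can be reused across calls.
        count = 0;
        result = 0;
        dfs(root, k);
        return result;
    }

    // Returns true once the kth node has been found, so callers stop early.
    boolean dfs(TreeNode x, int k) {
        if (x == null) return false;

        if (dfs(x.left, k)) {
            return true;
        }

        count++;
        if (count == k) {
            result = x.val;
            return true;
        }

        return dfs(x.right, k);
    }
}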
