# Kth Smallest Element in a BST

Given a binary search tree, write a function `kthSmallest` to find the **k**th smallest element in it.

**Note:**\
You may assume k is always valid, 1 ≤ k ≤ BST's total elements.

**Example 1:**

```
Input:
 root = [3,1,4,null,2], k = 1
   3
  / \
 1   4
  \
   2

Output:
 1
```

**Example 2:**

```
Input:
 root = [5,3,6,2,4,null,null,1], k = 3
       5
      / \
     3   6
    / \
   2   4
  /
 1

Output:
 3
```

**Follow up:**\
What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the `kthSmallest` routine?
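One common way to approach the follow-up is to store the size of each node's subtree in the node itself and keep it current on every insert (and delete). The kth smallest can then be found by walking a single root-to-leaf path in O(h) time, where h is the tree height. This is a minimal sketch under stated assumptions: `SizedNode` and `OrderStatisticBst` are hypothetical names, duplicates are ignored, deletion is omitted, and the tree is not self-balancing, so h can degrade to n.

```java
// Hypothetical node type that caches the number of nodes in its subtree (itself included).
class SizedNode {
    int val;
    int size = 1;
    SizedNode left, right;
    SizedNode(int val) { this.val = val; }
}

// Hypothetical size-augmented BST used only to illustrate the follow-up idea.
class OrderStatisticBst {
    private SizedNode root;

    static int size(SizedNode node) {
        return node == null ? 0 : node.size;
    }

    // Inserts val (ignoring duplicates) and refreshes cached sizes along the insertion path.
    void insert(int val) {
        root = insert(root, val);
    }

    private SizedNode insert(SizedNode node, int val) {
        if (node == null) return new SizedNode(val);
        if (val < node.val) {
            node.left = insert(node.left, val);
        } else if (val > node.val) {
            node.right = insert(node.right, val);
        } else {
            return node; // duplicate: tree unchanged
        }
        node.size = 1 + size(node.left) + size(node.right);
        return node;
    }

    // Walks one path from the root using cached left-subtree sizes: O(h) per query.
    int kthSmallest(int k) {
        SizedNode node = root;
        while (node != null) {
            int leftSize = size(node.left);
            if (k <= leftSize) {
                node = node.left;            // answer is among the smaller values
            } else if (k == leftSize + 1) {
                return node.val;             // current node is the kth smallest
            } else {
                k -= leftSize + 1;           // skip the left subtree and this node
                node = node.right;
            }
        }
        throw new IllegalArgumentException("k is out of range");
    }

    public static void main(String[] args) {
        OrderStatisticBst tree = new OrderStatisticBst();
        // Inserting in this order reproduces the tree from Example 2.
        for (int v : new int[] {5, 3, 6, 2, 4, 1}) tree.insert(v);
        System.out.println(tree.kthSmallest(3)); // 3
        tree.insert(0);
        System.out.println(tree.kthSmallest(3)); // 2
    }
}
```

Deletion would need the same size bookkeeping along the removal path, and a self-balancing variant (for example an AVL or red-black tree) would keep h at O(log n); both are left out of this sketch.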

## Analysis

An in-order traversal of a BST visits the nodes in ascending order of their values. Therefore the kth node visited by an in-order traversal holds the kth smallest element, and the traversal can stop as soon as that node is reached.
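As a quick check of this property, here is a minimal sketch that builds the tree from Example 2 by hand and collects its values with a recursive in-order traversal. `TreeNode` mirrors the commented LeetCode definition; `InorderDemo`, `inorder`, and `example2` are hypothetical helpers for illustration only.

```java
import java.util.ArrayList;
import java.util.List;

// Minimal node type, mirroring the LeetCode definition.
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
    TreeNode(int x) { val = x; }
}

// Hypothetical demo class, not part of the LeetCode solution.
class InorderDemo {
    // Appends the subtree's values in in-order (left subtree, node, right subtree).
    static void inorder(TreeNode node, List<Integer> out) {
        if (node == null) return;
        inorder(node.left, out);
        out.add(node.val);
        inorder(node.right, out);
    }

    // Builds the tree from Example 2: [5,3,6,2,4,null,null,1].
    static TreeNode example2() {
        TreeNode root = new TreeNode(5);
        root.left = new TreeNode(3);
        root.right = new TreeNode(6);
        root.left.left = new TreeNode(2);
        root.left.right = new TreeNode(4);
        root.left.left.left = new TreeNode(1);
        return root;
    }

    public static void main(String[] args) {
        List<Integer> values = new ArrayList<>();
        inorder(example2(), values);
        System.out.println(values);            // [1, 2, 3, 4, 5, 6]
        System.out.println(values.get(3 - 1)); // 3, the answer for k = 3
    }
}
```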

<https://leetcode.com/problems/kth-smallest-element-in-a-bst/discuss/63660/3-ways-implemented-in-JAVA-(Python):-Binary-Search-in-order-iterative-and-recursive>

<https://leetcode.com/problems/kth-smallest-element-in-a-bst/discuss/63783/Two-Easiest-In-Order-Traverse-(Java/)>
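The title of the first linked discussion mentions a binary-search approach. One way that idea can be realized, sketched here under our own naming and not necessarily matching the code in that post: count the nodes in the left subtree; if k is at most that count, the answer is on the left; if k equals count + 1, it is the current node; otherwise search the right subtree for the (k - count - 1)th smallest. `CountingSearch` and `countNodes` are hypothetical names.

```java
// Minimal node type, mirroring the LeetCode definition.
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
    TreeNode(int x) { val = x; }
}

// Hypothetical class illustrating a count-based binary search over the BST.
class CountingSearch {
    // Number of nodes in the subtree rooted at node (recomputed on every call).
    static int countNodes(TreeNode node) {
        if (node == null) return 0;
        return 1 + countNodes(node.left) + countNodes(node.right);
    }

    // Descends one level per iteration, using the left-subtree size to pick a side.
    static int kthSmallest(TreeNode root, int k) {
        TreeNode node = root;
        while (node != null) {
            int leftCount = countNodes(node.left);
            if (k <= leftCount) {
                node = node.left;            // answer is among the smaller values
            } else if (k == leftCount + 1) {
                return node.val;             // current node is the kth smallest
            } else {
                k -= leftCount + 1;          // skip the left subtree and this node
                node = node.right;
            }
        }
        throw new IllegalArgumentException("k is out of range");
    }

    public static void main(String[] args) {
        // Example 1: [3,1,4,null,2], k = 1
        TreeNode root = new TreeNode(3);
        root.left = new TreeNode(1);
        root.right = new TreeNode(4);
        root.left.right = new TreeNode(2);
        System.out.println(kthSmallest(root, 1)); // 1
    }
}
```

Because the left subtree is recounted at every level, this costs O(n) on a balanced tree but O(n^2) on a skewed one; caching those counts in the nodes is one way to address the follow-up question.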

## Solution

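Both solutions below use a recursive DFS in-order traversal. The first collects values into a list until it holds k of them; the second only counts visited nodes, records the value of the kth one, and stops the recursion as soon as it is found.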

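```java
import java.util.ArrayList;
import java.util.List;

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    public int kthSmallest(TreeNode root, int k) {
        if (root == null) return 0; // defensive: cannot happen when 1 <= k <= number of nodes
        List<Integer> topK = new ArrayList<>();
        helper(root, topK, k);
        // topK now holds the k smallest values in ascending order
        return topK.get(k - 1);
    }

    // In-order traversal that stops collecting once k values have been seen.
    private void helper(TreeNode root, List<Integer> topK, int k) {
        if (root == null || topK.size() >= k) return;
        helper(root.left, topK, k);   // smaller values first
        if (topK.size() < k) {
            topK.add(root.val);       // visit the current node
        } else {
            return;                   // already have k values
        }
        helper(root.right, topK, k);  // then larger values
    }
}
```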

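```java
class Solution {

    int count = 0;   // number of nodes visited so far, in in-order sequence
    int result = 0;  // value of the kth visited node

    public int kthSmallest(TreeNode root, int k) {
        count = 0;
        result = 0;
        dfs(root, k);
        return result;
    }

    // Returns true once the kth smallest node has been found, so callers stop early.
    private boolean dfs(TreeNode x, int k) {
        if (x == null) return false;

        // Search the left subtree (smaller values) first.
        if (dfs(x.left, k)) {
            return true;
        }

        // Visit the current node.
        count++;
        if (count == k) {
            result = x.val;
            return true;
        }

        // Then search the right subtree (larger values).
        return dfs(x.right, k);
    }
}
```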
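The first linked discussion also lists an iterative in-order version. A minimal sketch of that idea with an explicit stack follows; it is not necessarily identical to the post's code, and `IterativeSolution` is a hypothetical name. It avoids deep recursion on skewed trees and stops as soon as the kth node is popped.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Minimal node type, mirroring the LeetCode definition.
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
    TreeNode(int x) { val = x; }
}

// Hypothetical class name; same contract as kthSmallest above.
class IterativeSolution {
    // Iterative in-order traversal: the kth popped node is the kth smallest.
    public int kthSmallest(TreeNode root, int k) {
        Deque<TreeNode> stack = new ArrayDeque<>();
        TreeNode node = root;
        while (node != null || !stack.isEmpty()) {
            // Push the path down to the leftmost unvisited node.
            while (node != null) {
                stack.push(node);
                node = node.left;
            }
            node = stack.pop(); // next value in ascending order
            if (--k == 0) {
                return node.val;
            }
            node = node.right;  // continue with the larger values
        }
        throw new IllegalArgumentException("k is out of range");
    }
}
```

Each call does O(h + k) work, where h is the height of the tree, and the stack holds at most h nodes at a time.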
