
Making A Binary Search Tree in C++

This article is about implementing a Binary Search Tree (BST) in C++. I’ll skip the part about defining what a BST is since that’s a horse that’s been beaten many times. I am new to C++, so my implementation may have flaws. I welcome and encourage critique from other programmers :)

Draft 1

We start by implementing a TreeNode struct.

struct TreeNode
{
    // member vars
    int data;
    TreeNode* left;
    TreeNode* right;
    
    // constructor
    TreeNode(int data): data(data), left(nullptr), right(nullptr) {}
};

Notes:

  • Each TreeNode has three member variables:
    • data, an int storing the node’s value. In the future we can use template programming so that data can be any comparable type.
    • left, a pointer to the left child node (also a TreeNode)
    • right, a pointer to the right child node (also a TreeNode)
  • Our constructor lets us build a new TreeNode by providing a single int value, and it sets left and right to nullptr.

Let’s play with it. We’ll start by making a single node with the value 5.

#include <iostream>

struct TreeNode
{
    // member vars
    int data;
    TreeNode* left;
    TreeNode* right;
    
    // constructor
    TreeNode(int data): data(data), left(nullptr), right(nullptr) {}
};


int main() {
    // Make a new TreeNode
    TreeNode foo(5);
    
    // Print info about foo
    std::cout <<
    "data: " << foo.data <<
    ", left: " << foo.left <<
    ", right: " << foo.right <<
    std::endl;
    
    return 0;
}

Challenge

Build a binary tree with 5 as the root node connected to 4 (left child) and 6 (right child).

Solution

#include <iostream>

// struct TreeNode{...}

int main() {
    
    // Make the tree
    //    5
    //   / \
    //  4   6
    
    // Make the nodes
    TreeNode root(5);
    TreeNode leftChild(4);
    TreeNode rightChild(6);
    
    // Connect nodes
    root.left = &leftChild;
    root.right = &rightChild;
    
    // Print info about the root
    std::cout <<
    "data: " << root.data <<
    ", left: " << root.left->data <<
    ", right: " << root.right->data <<
    std::endl;
    
    return 0;
}

data: 5, left: 4, right: 6

Draft 2

Right now, our model has no way to represent an empty BST. To allow for empty trees, we’ll make a BSTree class that stores a pointer to the root TreeNode (which might be null). This class has the added benefit of distinguishing the tree as a whole from its nodes or subtrees.

#include <iostream>

// struct TreeNode{...}

class BSTree
{
public:
    // constructors
    BSTree(): root(nullptr) {}
    BSTree(TreeNode* rootNode): root(rootNode) {}
    
    // member functions
    void Print();
    
private:
    TreeNode* root;
};

Challenge

Implement the member function Print().

Solution

This method won’t print the prettiest trees, but it’ll be good enough to visualize the structure of small trees.

#include <iostream>
#include <string>

// struct TreeNode{...}

class BSTree
{
public:
    // constructors
    BSTree(): root(nullptr) {}
    BSTree(TreeNode* rootNode): root(rootNode) {}
    
    // member functions
    void Print();
    
private:
    TreeNode* root;
    std::string SubTreeAsString(TreeNode* node);  // Helper method for Print()
};

/// Print the tree
void BSTree::Print(){
    if(this->root == nullptr){
        std::cout << "{}" << std::endl;
    } else{
        std::cout << this->SubTreeAsString(this->root) << std::endl;
    }
}

/// Return the subtree starting at '*node' as a string
std::string BSTree::SubTreeAsString(TreeNode* node){
    std::string leftStr = (node->left == nullptr) ? "{}" : SubTreeAsString(node->left);
    std::string rightStr = (node->right == nullptr) ? "{}" : SubTreeAsString(node->right);
    std::string result = "{" + std::to_string(node->data) + ", " + leftStr + ", " + rightStr + "}";
    return result;
}

And some tests…

int main() {
    
    // -----------------------------
    // Make and print an empty tree
    
    BSTree emptyTree {};
    emptyTree.Print();
    
    // -----------------------------
    // Make and print the tree
    //    5
    //   / \
    //  4   6
    
    // Make the nodes
    TreeNode root(5);
    TreeNode leftChild(4);
    TreeNode rightChild(6);
    
    // Connect nodes
    root.left = &leftChild;
    root.right = &rightChild;
    
    // Make and print the tree
    BSTree myTree {&root};
    myTree.Print();
    
    return 0;
}

{}
{5, {4, {}, {}}, {6, {}, {}}}

Notes:

  • Here we represent each node as a set of three elements, {data, leftChild, rightChild} with empty trees represented as {}.
  • We use a private helper method SubTreeAsString(TreeNode* node) that recursively pieces together a string from the current node’s data and the string representations of its children’s subtrees.
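A small variation to consider: if the helper itself returns "{}" for a null pointer, neither Print() nor the child checks inside the helper need a special case. Below is a sketch written as a free function; SubtreeToString is a hypothetical name, and TreeNode is repeated from Draft 1 so the snippet compiles on its own.

```cpp
#include <string>

// Same TreeNode as in Draft 1, repeated so this snippet compiles on its own.
struct TreeNode
{
    int data;
    TreeNode* left;
    TreeNode* right;

    TreeNode(int data): data(data), left(nullptr), right(nullptr) {}
};

// Hypothetical free-function variant of SubTreeAsString.
// An empty subtree becomes "{}"; otherwise the result is
// "{data, <left subtree>, <right subtree>}", built recursively.
std::string SubtreeToString(const TreeNode* node)
{
    if (node == nullptr) {
        return "{}";
    }
    return "{" + std::to_string(node->data) + ", " +
           SubtreeToString(node->left) + ", " +
           SubtreeToString(node->right) + "}";
}
```

With this version, Print() could shrink to a single std::cout << SubtreeToString(this->root) << std::endl; for both empty and non-empty trees.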

Draft 3

Here we add an Insert(int val) method for inserting new nodes into the tree. With this method in place, we’ll be able to construct trees naturally by initializing an empty tree and then doing a series of inserts, instead of awkwardly making node instances and manually stitching them together as we’ve been doing. Note that if a user tries to add a value that already exists in the tree, we’ll warn them about it and then do nothing.

#include <iostream>
#include <string>

// struct TreeNode{...}

class BSTree
{
public:
    // constructors
    BSTree(): root(nullptr) {}
    BSTree(TreeNode* rootNode): root(rootNode) {}
    
    // member functions
    void Print();
    void Insert(int val);
    
private:
    TreeNode* root;
    std::string SubTreeAsString(TreeNode* node);  // Helper method for Print()
    void Insert(int val, TreeNode* node); // Helper method for Insert(int val)
};

/// Insert a new value into the tree
void BSTree::Insert(int val) {
    if(root == nullptr){
        this->root = new TreeNode(val);
    } else{
        this->Insert(val, this->root);
    }
}

/// Insert a new value into the subtree starting at node
void BSTree::Insert(int val, TreeNode* node) {
    
    // Check if node's value equals val
    // If so, warn the user and then exit function
    if(val == node->data){
        std::cout << "Warning: Value already exists, so nothing will be done." << std::endl;
        return;
    }
    
    // Check if val is < or > this node's value
    if(val < node->data){
        if(node->left == nullptr){
            // Make a new node as the left child of this node
            node->left = new TreeNode(val);
        } else{
            // Recursively call Insert() on this node's left child
            this->Insert(val, node->left);
        }
    } else{
        if(node->right == nullptr){
            // Make a new node as the right child of this node
            node->right = new TreeNode(val);
        } else{
            // Recursively call Insert() on this node's right child
            this->Insert(val, node->right);
        }
    }
}

Notes:

  • Here we use a public Insert(int val) method for inserting a new node into the tree and a private Insert(int val, TreeNode* node) helper method for inserting a new node into the subtree starting at the given node.
  • Insert(int val) handles the special case of an empty tree: it makes a new node that becomes the root.
  • Insert(int val, TreeNode* node) handles the recursive logic: if val is less than the current node’s value, insert val into the subtree starting at the current node’s left child; otherwise, insert it into the subtree starting at its right child. This method uses arm’s-length recursion, whereby each call to Insert(int val, TreeNode* node) checks whether the left/right child is null before recursively calling Insert(int val, TreeNode* node) on a pointer to that child node. If the child is null, then instead of making the recursive call, we stop, make the new node right there, and point the current node at it.
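The same arm’s-length idea (inspect a child before stepping into it) also works with a loop instead of recursion. Here is a sketch under that assumption, written as a free function with the hypothetical name InsertIterative that returns false instead of printing a warning when the value is already present; TreeNode is repeated from Draft 1.

```cpp
// Same TreeNode as in Draft 1, repeated so this snippet compiles on its own.
struct TreeNode
{
    int data;
    TreeNode* left;
    TreeNode* right;

    TreeNode(int data): data(data), left(nullptr), right(nullptr) {}
};

// Hypothetical loop-based version of the arm's-length Insert.
// Returns true if val was inserted, false if it was already in the tree.
bool InsertIterative(TreeNode*& root, int val)
{
    // Special case: empty tree, so the new node becomes the root
    if (root == nullptr) {
        root = new TreeNode(val);
        return true;
    }

    TreeNode* node = root;
    while (true) {
        if (val == node->data) {
            return false;  // duplicate: leave the tree unchanged
        }
        if (val < node->data) {
            // Look at the left child before stepping into it
            if (node->left == nullptr) {
                node->left = new TreeNode(val);
                return true;
            }
            node = node->left;
        } else {
            // Look at the right child before stepping into it
            if (node->right == nullptr) {
                node->right = new TreeNode(val);
                return true;
            }
            node = node->right;
        }
    }
}
```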

Let’s try it.

int main() {
    
    BSTree myTree {};
    myTree.Print();
    
    myTree.Insert(5);
    myTree.Print();
    
    myTree.Insert(4);
    myTree.Print();
    
    myTree.Insert(6);
    myTree.Print();
    
    return 0;
}

{}
{5, {}, {}}
{5, {4, {}, {}}, {}}
{5, {4, {}, {}}, {6, {}, {}}}

Indeed, our Insert(int val) works! But there’s a nagging issue...

Draft 4

Suppose we want to add a new value, 3, to the tree we just created. Let’s visualize how this process will work.

When we call myTree.Insert(3), this generates three Insert calls: the public Insert(3), which calls Insert(int val, TreeNode* node) on the root (5), which in turn calls it on node 4, where 3 is attached as a new left child. Each of these helper calls receives its own copy of a node pointer (the new, red arrows in the gif above). This raises the question: why not traverse the existing node pointers instead of making copies? Indeed, such a design turns out to be simpler, because a function that holds a reference to the pointer stored in the tree can assign a new node to it directly.

The trick to making this work is to change Insert(int val, TreeNode* node) to Insert(int val, TreeNode*& node), passing each node pointer by reference instead of by value. With this modification in place, the logic for both Insert(int val) and Insert(int val, TreeNode*& node) gets simpler, because we can recursively traverse the tree until we reach a nullptr and then insert a new node right there, as opposed to our previous arm’s-length recursion technique.
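To see why the reference matters, here is a minimal sketch using two hypothetical functions and the Draft 1 TreeNode: one receives a copy of the caller’s pointer and one receives a reference to it. Only the second can change a pointer that lives in the tree, which is what lets the reference-based Insert assign a fresh node directly to root, node->left, or node->right.

```cpp
// Same TreeNode as in Draft 1, repeated so this snippet compiles on its own.
struct TreeNode
{
    int data;
    TreeNode* left;
    TreeNode* right;

    TreeNode(int data): data(data), left(nullptr), right(nullptr) {}
};

// Receives a COPY of the caller's pointer: the assignment only changes the
// local copy, so the caller's pointer is left untouched.
void PointAtByValue(TreeNode* ptr, TreeNode* target)
{
    ptr = target;
}

// Receives a REFERENCE to the caller's pointer: the assignment changes the
// very pointer the caller passed in (e.g. a node's left or right member).
void PointAtByReference(TreeNode*& ptr, TreeNode* target)
{
    ptr = target;
}
```

Calling PointAtByValue(root.left, &child) leaves root.left as nullptr, while PointAtByReference(root.left, &child) actually links the child into the tree.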

#include <iostream>
#include <string>

// struct TreeNode {...}

class BSTree
{
public:
    // constructors
    BSTree(): root(nullptr) {}
    BSTree(TreeNode* rootNode): root(rootNode) {}
    
    // member functions
    void Print();
    void Insert(int val);
    
private:
    TreeNode* root;
    std::string SubTreeAsString(TreeNode* node);  // Helper method for Print()
    void Insert(int val, TreeNode*& node); // Helper method for Insert(int val)
};

/// Insert a new value into the tree
void BSTree::Insert(int val) {
    this->Insert(val, this->root);
}

/// Insert a new value into the subtree starting at node
void BSTree::Insert(int val, TreeNode*& node) {
    
    if(node == nullptr){
        // Case: node is a nullptr
        // Make a new TreeNode for it to point to
        node = new TreeNode(val);
    } else{
        if(val < node->data){
            // Case: val is < node's data
            this->Insert(val, node->left);
        } else if(val > node->data){
            // Case: val is > node's data
            this->Insert(val, node->right);
        } else{
            // Case: val is equal to node's data
            std::cout << "Warning: Value already exists, so nothing will be done." << std::endl;
        }
    }
}

Draft 5

One use case for a BST is a dynamic set. For example, we could use a BST to create a dictionary of the unique words in a book. At the end of this process, we might want to check if a certain word is in the dictionary (and therefore in the book). Along these lines, let’s implement a Contains(int val) method that checks if a value exists in our BSTree.
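For comparison, the C++ standard library already offers this dynamic-set behavior through std::set, which the major implementations (libstdc++, libc++, MSVC) build on a red-black tree, a self-balancing BST. Below is a sketch of the word dictionary, assuming words are simply whitespace-separated tokens (no punctuation stripping or case folding); UniqueWords is a hypothetical name.

```cpp
#include <set>
#include <sstream>
#include <string>

// Hypothetical helper: collect the unique whitespace-separated words of text.
// std::set keeps them sorted and silently ignores duplicate inserts.
std::set<std::string> UniqueWords(const std::string& text)
{
    std::set<std::string> words;
    std::istringstream stream(text);
    std::string word;
    while (stream >> word) {
        words.insert(word);
    }
    return words;
}
```

Checking whether a word appears in the book is then words.count(word) == 1 (or words.contains(word) in C++20), which plays the same role as the Contains method we build next.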

#include <iostream>
#include <string>

// struct TreeNode {...}

class BSTree
{
public:
    // constructors
    BSTree(): root(nullptr) {}
    BSTree(TreeNode* rootNode): root(rootNode) {}
    
    // member functions
    void Print();
    void Insert(int val);
    bool Contains(int val);
    
private:
    TreeNode* root;
    std::string SubTreeAsString(TreeNode* node);  // Helper method for Print()
    void Insert(int val, TreeNode*& node); // Helper method for Insert(int val)
    bool Contains(int val, TreeNode*& node); // Helper method for Contains(int val)
};

/// Check if the given value exists in the BSTree
bool BSTree::Contains(int val) {
    return Contains(val, this->root);
}

/// Check if the given value exists in the subtree
/// starting at node
bool BSTree::Contains(int val, TreeNode*& node) {
    if(node == nullptr){
        return false;
    } else if(val == node->data){
        return true;
    } else if(val < node->data){
        return this->Contains(val, node->left);
    } else{
        return this->Contains(val, node->right);
    }
}

Now let’s test it.

int main() {
    
    BSTree myTree {};
    myTree.Insert(5);
    myTree.Insert(4);
    myTree.Insert(6);
    
    std::cout << std::boolalpha << myTree.Contains(4) << std::endl;
    std::cout << std::boolalpha <<myTree.Contains(2) << std::endl;
    
    return 0;
}

true
false

Looks good.
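Since Contains only reads the tree, it doesn’t actually need a reference to a pointer. A TreeNode*& parameter also can’t bind to a pointer member of a const object, which becomes a problem once methods are marked const. Here is a sketch of a loop-based alternative that takes a const TreeNode*; ContainsIterative is a hypothetical name, and TreeNode is repeated from Draft 1.

```cpp
// Same TreeNode as in Draft 1, repeated so this snippet compiles on its own.
struct TreeNode
{
    int data;
    TreeNode* left;
    TreeNode* right;

    TreeNode(int data): data(data), left(nullptr), right(nullptr) {}
};

// Hypothetical read-only search: walk down from node, going left or right
// by comparison, until val is found or we fall off the tree.
bool ContainsIterative(const TreeNode* node, int val)
{
    while (node != nullptr) {
        if (val == node->data) {
            return true;
        }
        node = (val < node->data) ? node->left : node->right;
    }
    return false;
}
```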

Draft 6

Now let’s implement Remove(int val) for removing a single node from a tree. In determining the logic for removing a node, we need to consider five cases.

  1. val doesn’t exist
    We notify the user and then do nothing.

  2. val exists at a leaf node
    We delete the node.

  3. val exists at a node with a left child but not a right child
    We make the node’s parent point at the node’s left child and then delete the node.

  4. val exists at a node with a right child but not a left child
    We make the node’s parent point at the node’s right child and then delete the node.

  5. val exists at a node with left and right children
    This is the tricky case, but the solution is elegantly simple. We replace the node’s value with the minimum value in its right subtree. Then we delete that node (i.e., the min-value node from the right subtree we just found). Convince yourself that the resulting tree is still a valid Binary Search Tree. (Note that there are other solutions to this problem.)
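One way to "convince yourself" mechanically is to run case 5 on a sample tree and check the BST invariant afterwards. The sketch below uses free functions with hypothetical names (Insert, Remove, IsValidBST) over the Draft 1 TreeNode rather than the BSTree class. Remove follows the five cases above but silently ignores a missing value instead of printing a message, and IsValidBST checks that every value lies strictly between bounds inherited from its ancestors.

```cpp
#include <climits>

// Same TreeNode as in Draft 1, repeated so this snippet compiles on its own.
struct TreeNode
{
    int data;
    TreeNode* left;
    TreeNode* right;

    TreeNode(int data): data(data), left(nullptr), right(nullptr) {}
};

// Hypothetical helper: insert by reference to pointer, ignoring duplicates.
void Insert(TreeNode*& node, int val)
{
    if (node == nullptr) node = new TreeNode(val);
    else if (val < node->data) Insert(node->left, val);
    else if (val > node->data) Insert(node->right, val);
}

// Hypothetical helper: remove val following the five cases above.
void Remove(TreeNode*& node, int val)
{
    if (node == nullptr) return;  // case 1: val not found
    if (val < node->data) { Remove(node->left, val); return; }
    if (val > node->data) { Remove(node->right, val); return; }

    if (node->left != nullptr && node->right != nullptr) {
        // Case 5: copy the minimum of the right subtree into this node,
        // then remove that minimum from the right subtree.
        TreeNode* minNode = node->right;
        while (minNode->left != nullptr) minNode = minNode->left;
        node->data = minNode->data;
        Remove(node->right, node->data);
        return;
    }

    // Cases 2-4: replace the node with its only child (nullptr for a leaf).
    TreeNode* trash = node;
    node = (node->left != nullptr) ? node->left : node->right;
    delete trash;
}

// Check the BST invariant: every value must lie strictly between the
// bounds inherited from its ancestors.
bool IsValidBST(const TreeNode* node, long long lo, long long hi)
{
    if (node == nullptr) return true;
    if (node->data <= lo || node->data >= hi) return false;
    return IsValidBST(node->left, lo, node->data) &&
           IsValidBST(node->right, node->data, hi);
}
```

Starting from the tree built by inserting 5, 3, 8, 1, 4, 7, 9, 6, removing 5 copies 6 (the minimum of the right subtree) into the root and deletes the original leaf 6, and IsValidBST(root, LLONG_MIN, LLONG_MAX) still returns true.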

#include <iostream>
#include <string>

// struct TreeNode{...}

class BSTree
{
public:
    // constructors
    BSTree(): root(nullptr) {}
    BSTree(TreeNode* rootNode): root(rootNode) {}
    
    // member functions
    void Print();
    void Insert(int val);
    bool Contains(int val);
    void Remove(int val);
    
private:
    TreeNode* root;
    std::string SubTreeAsString(TreeNode* node);  // Helper method for Print()
    void Insert(int val, TreeNode*& node);  // Helper method for Insert(int val)
    bool Contains(int val, TreeNode*& node);  // Helper method for Contains(int val)
    void Remove(int val, TreeNode*& node);  // Helper method for Remove(int val)
    TreeNode*& FindMin(TreeNode*& node); // Helper method for Remove(int val)
};

/// Remove given value from the tree
void BSTree::Remove(int val) {
    this->Remove(val, this->root);
}

/// Remove given value from the subtree starting at node
void BSTree::Remove(int val, TreeNode*& node) {
    if(node == nullptr){
        // Case: nullptr
        
        std::cout << "val not found in tree" << std::endl;
        
    } else if(val == node->data){
        // Found value
        
        TreeNode* trash = nullptr;
        if(node->left == nullptr && node->right == nullptr){
            // Case: node is a leaf
            
            trash = node;
            node = nullptr;
            
        } else if(node->left != nullptr && node->right == nullptr){
            // Case: node has a left subtree (but not right)
            // Point node's parent at node's left subtree
            
            trash = node;
            node = node->left;
            
        } else if(node->left == nullptr && node->right != nullptr){
            // Case: node has a right subtree (but not left)
            
            trash = node;
            node = node->right;
            
        } else{
            // Case: node has left and right subtrees
            
            TreeNode*& minNode = this->FindMin(node->right); // returns a reference to the pointer in the tree
            node->data = minNode->data;
            this->Remove(minNode->data, minNode);
        }
        
        // Free memory
        if(trash != nullptr) delete trash;
        
    } else if(val < node->data){
        // Case: remove val from this node's left subtree
        this->Remove(val, node->left);
    } else{
        // Case: remove val from this node's right subtree
        this->Remove(val, node->right);
    }
}

/// Search the subtree starting at node and return a pointer to the minimum-value node
/// The returned pointer will be a reference of an actual pointer in the tree, not a copy
TreeNode*& BSTree::FindMin(TreeNode*& node) {
    if(node == nullptr){
        throw "Min value not found";
    } else if(node->left == nullptr){
        return node;
    } else{
        return this->FindMin(node->left);
    }
}

Notes:

  • void BSTree::Remove(int val, TreeNode*& node) does the heavy lifting
  • TreeNode*& BSTree::FindMin(TreeNode*& node) is a helper method that finds and returns a reference to the tree’s pointer that points at the smallest node in the subtree starting at the given node.
  • Whenever we remove a node, we also delete its TreeNode object, freeing its memory on the heap.
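One gap remains in this draft: BSTree has no destructor, so any nodes still in the tree when a BSTree goes out of scope are never freed. Here is a sketch of a post-order cleanup that a destructor such as ~BSTree() { DeleteSubtree(root); } could call, assuming every node was allocated with new (not true for the stack-allocated nodes handed to the BSTree(TreeNode*) constructor in Draft 2). DeleteSubtree is a hypothetical name, and the liveCount counter is added only to make the effect observable.

```cpp
// Hypothetical variant of the Draft 1 TreeNode that counts live instances,
// only so the effect of the cleanup can be observed.
struct TreeNode
{
    static int liveCount;

    int data;
    TreeNode* left;
    TreeNode* right;

    TreeNode(int data): data(data), left(nullptr), right(nullptr) { ++liveCount; }
    ~TreeNode() { --liveCount; }
};

int TreeNode::liveCount = 0;

// Hypothetical helper: delete every node in the subtree (children before
// their parent, i.e. post-order), then null out the caller's pointer so it
// doesn't dangle.
void DeleteSubtree(TreeNode*& node)
{
    if (node == nullptr) return;
    DeleteSubtree(node->left);
    DeleteSubtree(node->right);
    delete node;
    node = nullptr;
}
```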

Draft 7

Finally, we’ll identify and implement a number of improvements, giving our rough implementation a more polished feel.

  1. Currently we’re using raw pointers, but it’d be better to replace those raw pointers with smart pointers so that we don’t need to worry about memory leaks and don’t have to manage the deletion of objects ourselves when we move nodes around.
  2. There’s really no reason to expose TreeNode to the user. It’d be better to declare TreeNode as a private nested struct inside our BSTree class.
  3. Right now our BSTree can only hold ints. With template programming, we can let our users build a BSTree from any comparable type.
  4. Many of our methods read but don’t modify the tree or their inputs. We should declare such methods, and the references they take, as const.
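On item 3, note that the solution below compares values with <, > and ==, so T must support all three. The standard ordered containers such as std::set get by with < alone: two values are treated as equivalent when neither is less than the other. Here is a sketch of Insert and Contains written that way, with a hypothetical Node template and a hypothetical Point type that defines only operator<.

```cpp
#include <memory>

// Hypothetical node template for this sketch (not the nested TreeNode).
template <typename T>
struct Node
{
    T data;
    std::unique_ptr<Node> left;
    std::unique_ptr<Node> right;

    explicit Node(const T& data): data(data) {}
};

// Insert using only operator<. Returns false if an equivalent value exists.
template <typename T>
bool Insert(std::unique_ptr<Node<T>>& node, const T& val)
{
    if (node == nullptr) {
        node = std::make_unique<Node<T>>(val);
        return true;
    }
    if (val < node->data) return Insert(node->left, val);
    if (node->data < val) return Insert(node->right, val);
    return false;  // neither is less than the other: treat as equal
}

// Search using only operator<.
template <typename T>
bool Contains(const std::unique_ptr<Node<T>>& node, const T& val)
{
    if (node == nullptr) return false;
    if (val < node->data) return Contains(node->left, val);
    if (node->data < val) return Contains(node->right, val);
    return true;
}

// Hypothetical type that defines operator< (lexicographic) but not > or ==.
struct Point
{
    int x;
    int y;
};

bool operator<(const Point& a, const Point& b)
{
    return a.x < b.x || (a.x == b.x && a.y < b.y);
}
```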

Challenge

Implement those improvements.

Solution

#include <iostream>
#include <string>
#include <sstream>  // ostringstream
#include <memory>  // unique_ptr

template <typename T>
class BSTree
{
public:
    // constructors
    BSTree(): root(nullptr) {}
    
    // member functions
    void Print() const;
    void Insert(T val);
    bool Contains(T val) const;
    void Remove(T val);
    
private:
    
    struct TreeNode
    {
        // member vars
        T data;
        std::unique_ptr<TreeNode> left;
        std::unique_ptr<TreeNode> right;
        
        // constructor
        TreeNode(T data): data(data), left(nullptr), right(nullptr) {}
    };
    
    std::unique_ptr<TreeNode> root;
    std::string SubTreeAsString(const std::unique_ptr<TreeNode>& node) const;  // Helper method for Print()
    void Insert(T val, std::unique_ptr<TreeNode>& node);  // Helper method for Insert(T val)
    bool Contains(T val, const std::unique_ptr<TreeNode>& node) const;  // Helper method for Contains(T val); const& so it works inside a const method
    void Remove(T val, std::unique_ptr<TreeNode>& node);  // Helper method for Remove(T val)
    std::unique_ptr<TreeNode>& FindMin(std::unique_ptr<TreeNode>& node); // Helper method for Remove(T val)
};

/// Print the tree
template <typename T>
void BSTree<T>::Print() const {
    if(this->root == nullptr){
        std::cout << "{}" << std::endl;
    } else{
        std::cout << this->SubTreeAsString(this->root) << std::endl;
    }
}

/// Return the subtree starting at node as a string
template <typename T>
std::string BSTree<T>::SubTreeAsString(const std::unique_ptr<TreeNode>& node) const {
    std::string leftStr = (node->left == nullptr) ? "{}" : SubTreeAsString(node->left);
    std::string rightStr = (node->right == nullptr) ? "{}" : SubTreeAsString(node->right);
    // ostringstream works for any T with operator<<, whereas std::to_string only accepts numbers
    std::ostringstream result;
    result << "{" << node->data << ", " << leftStr << ", " << rightStr << "}";
    return result.str();
}

/// Insert a new value into the tree
template <typename T>
void BSTree<T>::Insert(T val) {
    this->Insert(val, this->root);
}

/// Insert a new value into the subtree starting at node
template <typename T>
void BSTree<T>::Insert(T val, std::unique_ptr<TreeNode>& node) {

    if(node == nullptr){
        // Case: node is a nullptr
        // Make a new TreeNode for it to point to
        node = std::make_unique<TreeNode>(val);
    } else{
        if(val < node->data){
            // Case: val is < node's data
            this->Insert(val, node->left);
        } else if(val > node->data){
            // Case: val is > node's data
            this->Insert(val, node->right);
        } else{
            // Case: val is equal to node's data
            std::cout << "Warning: Value already exists, so nothing will be done." << std::endl;
        }
    }
}

/// Check if the given value exists in the BSTree
template <typename T>
bool BSTree<T>::Contains(T val) const {
    return Contains(val, this->root);
}

/// Check if the given value exists in the subtree
/// starting at node
template <typename T>
bool BSTree<T>::Contains(T val, const std::unique_ptr<TreeNode>& node) const {
    if(node == nullptr){
        return false;
    } else if(val == node->data){
        return true;
    } else if(val < node->data){
        return this->Contains(val, node->left);
    } else{
        return this->Contains(val, node->right);
    }
}

/// Remove given value from the tree
template <typename T>
void BSTree<T>::Remove(T val) {
    this->Remove(val, this->root);
}

/// Remove given value from the subtree starting at node
template <typename T>
void BSTree<T>::Remove(T val, std::unique_ptr<TreeNode>& node) {
    if(node == nullptr){
        // Case: nullptr

        std::cout << "val not found in tree" << std::endl;

    } else if(val == node->data){
        // Found value

        if(node->left == nullptr && node->right == nullptr){
            // Case: node is a leaf

            node = nullptr;

        } else if(node->left != nullptr && node->right == nullptr){
            // Case: node has a left subtree (but not right)
            // Point node's parent at node's left subtree

            node = std::move(node->left);

        } else if(node->left == nullptr && node->right != nullptr){
            // Case: node has a right subtree (but not left)

            node = std::move(node->right);

        } else{
            // Case: node has left and right subtrees

            std::unique_ptr<TreeNode>& minNode = this->FindMin(node->right); // returns a reference to the actual pointer in the tree
            node->data = minNode->data;
            this->Remove(minNode->data, minNode);
        }

    } else if(val < node->data){
        // Case: remove val from this node's left subtree
        this->Remove(val, node->left);
    } else{
        // Case: remove val from this node's right subtree
        this->Remove(val, node->right);
    }
}

/// Search the subtree starting at node and return a pointer to the minimum-value node
/// The returned pointer will be a reference of an actual pointer in the tree, not a copy
template <typename T>
std::unique_ptr<typename BSTree<T>::TreeNode>& BSTree<T>::FindMin(std::unique_ptr<TreeNode>& node) {
    if(node == nullptr){
        throw "Min value not found";
    } else if(node->left == nullptr){
        return node;
    } else{
        return this->FindMin(node->left);
    }
}

Special thanks to Marty Stepp and his fantastic Stanford lectures on implementing Binary Search Trees in C++.