diff --git a/BinarySearchTree.h b/BinarySearchTree.h index e4a7940..11f76fc 100644 --- a/BinarySearchTree.h +++ b/BinarySearchTree.h @@ -10,12 +10,12 @@ private: BinaryTreeNode* _search(BinaryTreeNode* treePtr, const T& target) const; - //BinaryNode* _remove(BinaryNode* nodePtr, const ItemType target, bool& success); + BinaryTreeNode* _remove(BinaryTreeNode* nodePtr, const T& target); public: void insert(const T& item); - void remove(const T &item) { throw std::logic_error("Not implemented: BinarySearchTree.remove()"); }; + bool remove(const T& item); bool search(const T& target, T& returnedItem) const; }; @@ -54,6 +54,60 @@ BinaryTreeNode* BinarySearchTree::_insert(BinaryTreeNode* nodePtr, Bina return nodePtr; } +template +bool BinarySearchTree::remove(const T& item) { + BinaryTreeNode* removed = _remove(this->root, item); + if (removed) { + this->size--; + return true; + } + return false; +} + +template +BinaryTreeNode* BinarySearchTree::_remove(BinaryTreeNode* nodePtr, const T& target) { + BinaryTreeNode* parNode = nullptr; + BinaryTreeNode* currNode = nodePtr; + + while (currNode) { + if (currNode->getItem() == target) { + if (!currNode->getLeftPtr() && !currNode->getRightPtr()) { // if target is leaf node + if (!parNode) this->root = nullptr; + else if (parNode->getLeftPtr() == currNode) parNode->setLeftPtr(nullptr); + else parNode->setRightPtr(nullptr); + } + else if (!currNode->getRightPtr()) { // if target has only left child + if (!parNode) this->root = currNode->getLeftPtr(); + else if (parNode->getLeftPtr() == currNode) parNode->setLeftPtr(currNode->getLeftPtr()); + else parNode->setRightPtr(currNode->getLeftPtr()); + } + else if (!currNode->getLeftPtr()) { // if target has only right child + if (!parNode) this->root = currNode->getRightPtr(); + else if (parNode->getLeftPtr() == currNode) parNode->setLeftPtr(currNode->getRightPtr()); + else parNode->setRightPtr(currNode->getRightPtr()); + } + else { // if target has both left and right child + BinaryTreeNode* sucNode = currNode->getRightPtr(); + while (sucNode->getLeftPtr()) + sucNode = sucNode->getLeftPtr(); + T sucData = sucNode->getItem(); + _remove(this->root, sucData); + currNode->setItem(sucData); + } + return currNode; + } + else if (currNode->getItem() < target) { + parNode = currNode; + currNode = currNode->getRightPtr(); + } + else { + parNode = currNode; + currNode = currNode->getLeftPtr(); + } + } + return nullptr; +} + template bool BinarySearchTree::search(const T& target, T& returnedItem) const { diff --git a/BinaryTree.h b/BinaryTree.h index 2ef44b9..59842aa 100644 --- a/BinaryTree.h +++ b/BinaryTree.h @@ -32,7 +32,7 @@ public: // abstract functions to be implemented by derived class virtual void insert(const T& newData) = 0; - //virtual bool remove(const T &data) = 0; + virtual bool remove(const T &data) = 0; virtual bool search(const T& target, T& returnedItem) const = 0; private: