Skip to main content

Treap

/// Treap over int32_t keys: every node satisfies the binary-search-tree order
/// on `k` and the max-heap order on the randomly drawn priority `p`.
/// Trees are passed around as raw `Node*` roots; a null pointer is the empty tree.
namespace Treap {
// Priority source. Seeded from the monotonic clock, so priorities differ between runs.
std::mt19937 gen(static_cast<std::mt19937::result_type>(
    std::chrono::steady_clock::now().time_since_epoch().count()));
std::uniform_int_distribution<uint32_t> dis;

/// One tree node. A node owns both of its subtrees: destroying it destroys
/// everything below it, so detach (null out) any child that must survive.
class Node {
public:
    int32_t k;   // search key
    uint32_t p;  // heap priority
    Node *l = nullptr, *r = nullptr;
    ~Node() { delete l; delete r; }
};

/// Splits `t` into {keys <= k, keys > k}. The nodes of `t` are relinked in
/// place, so `t` must not be used as a tree afterwards.
std::tuple<Node*, Node*> split(Node* t, int32_t k) {
    if (!t) return {nullptr, nullptr};
    if (k < t->k) {
        // t belongs to the right part; only its left subtree can hold keys <= k.
        auto [tl, tr] = split(t->l, k);
        t->l = tr;
        return {tl, t};
    }
    // t belongs to the left part; only its right subtree can hold keys > k.
    auto [tl, tr] = split(t->r, k);
    t->r = tl;
    return {t, tr};
}

/// Joins two treaps into one and returns the new root. Assumes every key in
/// `tl` is not greater than every key in `tr` (the function never compares keys).
Node* merge(Node* tl, Node* tr) {
    if (!tl) return tr;
    if (!tr) return tl;
    if (tl->p > tr->p) {
        tl->r = merge(tl->r, tr);
        return tl;
    }
    tr->l = merge(tl, tr->l);
    return tr;
}

/// Inserts key `k` into the treap rooted at `t` (updated in place).
/// Duplicate keys are not rejected; each call adds a new node.
void insert(Node* &t, int32_t k) {
    uint32_t p = dis(gen);
    if (!t) { t = new Node {k, p}; return; }
    if (p > t->p) {
        // The new node outranks this subtree's root: it becomes the root and
        // adopts the two halves of the old subtree.
        auto [tl, tr] = split(t, k);
        t = new Node {k, p, tl, tr};
    } else {
        insert(k < t->k ? t->l : t->r, k);
    }
}

/// Removes one node with key `k` from the treap rooted at `t` (updated in
/// place) and frees it. Does nothing when the key is not present.
void erase(Node* &t, int32_t k) {
    if (!t) return;  // reached a leaf's child: the key is not in the tree
    if (t->k == k) {
        Node* victim = t;
        t = merge(victim->l, victim->r);
        // The subtrees live on inside the merged tree; detach them so that
        // ~Node does not free them together with the removed node.
        victim->l = nullptr;
        victim->r = nullptr;
        delete victim;
        return;
    }
    erase(k < t->k ? t->l : t->r, k);
}
}  // namespace Treap

Treap with Implicit Key

ImplicitTreap
// Provides std::out_of_range, thrown by get/set/removeAt on an invalid index.
#include <stdexcept>

// Priority source shared by every ImplicitTreap node. The seed is the fixed
// value 3, so the sequence of priorities is the same on every run.
std::mt19937 gen(3);
std::uniform_int_distribution<> dis(0, INT_MAX);

/// Index-addressed sequence of E stored in a treap with implicit keys: a
/// node's position is not stored but derived from subtree sizes, and nodes are
/// heap-ordered on a random priority.
///
/// The container owns its nodes: they are freed on destruction, copies are
/// deep, and a moved-from container is left empty.
template<typename E>
class ImplicitTreap {
private:
    class Node {
    public:
        E value;
        Node* left = nullptr;
        Node* right = nullptr;
        int32_t size = 1;             // number of nodes in this subtree, itself included
        int32_t priority = dis(gen);  // heap key; higher priorities sit closer to the root

        explicit Node(E element): value(std::move(element)) {}

        void setLeft(Node* node) {
            left = node;
            applySizeInvariant();
        }

        void setRight(Node* node) {
            right = node;
            applySizeInvariant();
        }

        // Recomputes `size` from the children; call after relinking either child.
        void applySizeInvariant() {
            size = 1 + getSize(left) + getSize(right);
        }
    };

    Node* root = nullptr;

    // Size of a possibly empty subtree.
    static int32_t getSize(Node* node) {
        if (node == nullptr) return 0;
        return node->size;
    }

    // Frees every node of the subtree rooted at `node`.
    static void destroy(Node* node) noexcept {
        if (node == nullptr) return;
        destroy(node->left);
        destroy(node->right);
        delete node;
    }

    // Deep-copies a subtree, keeping values, sizes and priorities (hence the shape).
    static Node* clone(const Node* source) {
        if (source == nullptr) return nullptr;
        Node* copy = new Node(*source);
        // The member-wise copy also copied the child pointers; drop them so the
        // two trees never share nodes, then rebuild the children as copies.
        copy->left = nullptr;
        copy->right = nullptr;
        try {
            copy->left = clone(source->left);
            copy->right = clone(source->right);
        } catch (...) {
            destroy(copy);  // do not leak the partially built copy
            throw;
        }
        return copy;
    }

    // Throws std::out_of_range unless 0 <= index < getSize().
    void requireValidIndex(int32_t index) const {
        if (index < 0 || index >= getSize(root)) {
            throw std::out_of_range("ImplicitTreap: index out of range");
        }
    }

    // Returns the node at position `index` without modifying the tree.
    Node* nodeAt(int32_t index) const {
        requireValidIndex(index);
        Node* node = root;
        while (true) {
            const int32_t leftSize = getSize(node->left);
            if (index == leftSize) return node;
            if (index < leftSize) {
                node = node->left;
            } else {
                // Skip the left subtree and this node, then continue on the right.
                index -= leftSize + 1;
                node = node->right;
            }
        }
    }

    // Splits the subtree into {positions < key, positions >= key}.
    // `parentImplicitKey` is the number of elements that precede this subtree.
    std::pair<Node*, Node*> split(Node* node, int32_t key, int32_t parentImplicitKey = 0) {
        if (node == nullptr) return {nullptr, nullptr};
        const int32_t implicitKey = parentImplicitKey + getSize(node->left);

        if (key <= implicitKey) {
            auto [left, right] = split(node->left, key, parentImplicitKey);
            node->setLeft(right);
            return {left, node};
        }
        auto [left, right] = split(node->right, key, 1 + implicitKey);
        node->setRight(left);
        return {node, right};
    }

    // Removes and frees the node at position `key`; returns the new subtree
    // root together with the removed value. The caller must have validated
    // `key`, so the descent never reaches a null node.
    std::pair<Node*, E> erase(Node* node, int32_t key, int32_t parentImplicitKey = 0) {
        const int32_t implicitKey = parentImplicitKey + getSize(node->left);

        if (key == implicitKey) {
            Node* merged = merge(node->left, node->right);
            E removed = std::move(node->value);
            delete node;
            return {merged, std::move(removed)};
        }
        if (key < implicitKey) {
            auto [newLeft, removedElement] = erase(node->left, key, parentImplicitKey);
            node->setLeft(newLeft);
            return {node, std::move(removedElement)};
        }
        auto [newRight, removedElement] = erase(node->right, key, 1 + implicitKey);
        node->setRight(newRight);
        return {node, std::move(removedElement)};
    }

    // Concatenates two subtrees (all of `leftNode` before all of `rightNode`)
    // and returns the new root; the higher-priority root stays on top.
    Node* merge(Node* leftNode, Node* rightNode) {
        if (leftNode == nullptr) return rightNode;
        if (rightNode == nullptr) return leftNode;

        if (leftNode->priority > rightNode->priority) {
            leftNode->setRight(merge(leftNode->right, rightNode));
            return leftNode;
        }
        rightNode->setLeft(merge(leftNode, rightNode->left));
        return rightNode;
    }

public:
    ImplicitTreap() = default;
    ImplicitTreap(const ImplicitTreap& other): root(clone(other.root)) {}
    ImplicitTreap(ImplicitTreap&& other) noexcept: root(other.root) { other.root = nullptr; }

    // Takes its argument by value, so this one operator serves both copy and
    // move assignment; the previous contents are freed when `other` dies.
    ImplicitTreap& operator=(ImplicitTreap other) noexcept {
        std::swap(root, other.root);
        return *this;
    }

    ~ImplicitTreap() { destroy(root); }

    /// Number of stored elements.
    int32_t getSize() const {
        return getSize(root);
    }

    /// Appends `element` at the end of the sequence.
    void add(E element) {
        root = merge(root, new Node(std::move(element)));
    }

    /// Inserts `element` so that it ends up at position `index`.
    /// An index <= 0 prepends and an index >= getSize() appends.
    void add(int32_t index, E element) {
        // Allocate before splitting: if construction throws, the tree is untouched.
        Node* fresh = new Node(std::move(element));
        auto [left, right] = split(root, index);
        root = merge(merge(left, fresh), right);
    }

    /// Returns a copy of the element at `index`.
    /// @throws std::out_of_range if `index` is not in [0, getSize()).
    E get(int32_t index) const {
        return nodeAt(index)->value;
    }

    /// Replaces the element at `index` with `element`.
    /// @throws std::out_of_range if `index` is not in [0, getSize()).
    void set(int32_t index, E element) {
        nodeAt(index)->value = std::move(element);
    }

    /// Removes the element at `index` and returns it.
    /// @throws std::out_of_range if `index` is not in [0, getSize()).
    E removeAt(int32_t index) {
        requireValidIndex(index);
        auto result = erase(root, index);
        root = result.first;
        return std::move(result.second);
    }
};