On Github lionell / data_structures
Created by Ruslan Sakevych / @lionell
// API example for the sum segment tree: positions are 0-based and query
// ranges are inclusive on both ends.
int a[] = {1, 2, 3, 4, 5, 6};
build();
sum(1, 4); // 2 + 3 + 4 + 5 = 14
sum(4, 5); // 5 + 6 = 11
update(2, 10); // position 2 now holds 10
sum(2, 3); // 10 + 4 = 14
// Node storage for the segment tree. build() addresses it from index 1,
// with the children of node v at 2 * v and 2 * v + 1.
int t[4 * MAX_N]; // MEMORY LIMIT: four ints per array element
// Builds the sum tree for the segment [tl, tr] of a[] into the subtree
// rooted at node v. A leaf copies its array element; an inner node stores
// the sum of its two children.
void build(int v = 1, int tl = 0,
           int tr = n - 1) {
    if (tl != tr) {
        const int mid = (tl + tr) / 2;
        const int lc = 2 * v;      // left child
        const int rc = 2 * v + 1;  // right child
        build(lc, tl, mid);
        build(rc, mid + 1, tr);
        t[v] = t[lc] + t[rc];
        return;
    }
    t[v] = a[tl];
}
// Returns the sum of the elements on the inclusive range [l, r].
//
// v is the current node and [tl, tr] the segment it covers; callers normally
// rely on the defaults and pass only l and r. The query range is clipped to
// each child, so an empty range (l > r) simply contributes 0.
int sum(int l, int r, int v = 1,
        int tl = 0, int tr = n - 1) {
    if (l > r) {
        return 0;
    }
    if (tl == l && tr == r) {
        return t[v];
    }
    int tm = (tl + tr) / 2;
    // The right-hand call previously carried a stray `a` argument, which does
    // not match this function's parameter list.
    return sum(l, min(r, tm), 2 * v, tl, tm)
         + sum(max(tm + 1, l), r, 2 * v + 1, tm + 1, tr);
}
// Point assignment: stores val at position i and refreshes the sums on the
// path from that leaf back up to node v. [tl, tr] is the segment of node v.
void update(int i, int val, int v = 1,
            int tl = 0, int tr = n - 1) {
    if (tl == tr) {
        t[v] = val;
        return;
    }
    const int mid = (tl + tr) / 2;
    const int lc = 2 * v;
    const int rc = 2 * v + 1;
    if (i > mid) {
        update(i, val, rc, mid + 1, tr);
    } else {
        update(i, val, lc, tl, mid);
    }
    // Exactly one child changed; recompute this node from both.
    t[v] = t[lc] + t[rc];
}
... int tm = tl + (tr - tl) / 2; // instead of int tm = (tl + tr) / 2; -- tl + tr can overflow for large bounds ...
x >> 1 // instead of x / 2
x << 1 // instead of x * 2
The recursive call in update is almost a tail call: the only work left after it is recomputing t[v].
So we can easily convert it into a loop.
Now, we need some universal method for queries.
Let's define some function that can combine useful information from child nodes.
// Node storage for the generic segment tree. Nodes are addressed from 1 with
// children at 2 * v and 2 * v + 1, so indices can exceed MAX_N; 4 * MAX_N is
// the usual safe bound for this layout.
T t[4 * MAX_N];
// Merges the values of two child nodes into the value of their parent.
// build() and update() call it as combine(t[2 * v], t[2 * v + 1]), i.e. the
// left child is passed first. The body is a placeholder to fill in per problem.
T combine(T l, T r) {
...
}
// Produces the node value of type T for a single element. build() calls it as
// make(a[tl]) for a leaf and update() as make(val). Parameters and body are
// placeholders to fill in per problem.
T make(...) {
...
}
// Generic build: same traversal as the sum tree, but leaves are produced by
// make() and inner nodes by combine(). The parameter list is elided on the
// slide; the body uses v, tl and tr.
void build(...) {
    if (tl != tr) {
        int tm = tl + (tr - tl) / 2;
        build(2 * v, tl, tm);
        build(2 * v + 1, tm + 1, tr);
        // was: t[v] = t[2 * v] + t[2 * v + 1];
        t[v] = combine(t[2 * v], t[2 * v + 1]);
        return;
    }
    // was: t[v] = a[tl];
    t[v] = make(a[tl]);
}
void update(int i, T val, ...) {
if (tl == tr) {
// t[v] = val;
t[v] = make(val);
} else {
int tm = tl + (tr - tl) / 2;
if (i <= tm) {
update(i, val, 2 * v, tl, tm);
} else {
update(i, val, 2 * v + 1, tm + 1, tr);
}
// t[v] = t[2 * v] + t[2 * v + 1];
t[v] = combine(t[2 * v], t[2 * v + 1]);
}
}
API example
// API example: inc(l, r, x) adds x to every position of the inclusive range
// [l, r]; get(i) reads one position. Trailing comments show the logical array.
int a[] = {0, 0, 0, 0, 0};
inc(1, 3, 1); // a = {0, 1, 1, 1, 0}
inc(0, 2, -2); // a = {-2, -1, -1, 1, 0}
get(1); // a[1] == -1
// Range increment: adds x to every position of the inclusive range [l, r].
//
// The increment is recorded on the nodes whose segment exactly matches a
// piece of the range; nothing is pushed down. get() later adds up the
// increments found on the root-to-leaf path. The trailing parameters are
// elided on the slide; the body uses v, tl and tr.
void inc(int l, int r, int x, ...) {
    if (l > r) {
        return;
    }
    if (tl == l && tr == r) {
        t[v] += x;
    } else {
        int tm = tl + (tr - tl) / 2;
        inc(l, min(r, tm), x, 2 * v, tl, tm);
        // Argument order is (l, r, x, ...); the right-hand call previously
        // passed x and r swapped.
        inc(max(tm + 1, l), r, x, 2 * v + 1, tm + 1, tr);
    }
}
int get(int i, ...) {
if (tl == tr) {
return a[i];
}
int tm = tl + (tr - tl) / 2;
if (i <= tm) {
return t[v] + get(i, 2 * v, tl, tm);
}
return t[v] + get(i, 2 * v + 1, tm + 1, tr);
}
API example
// API example: let(l, r, x) assigns x to every position of the inclusive
// range [l, r]; get(i) reads one position.
int a[] = {0, 0, 0, 0, 0};
let(3, 4, 1); // a = {0, 0, 0, 1, 1}
let(2, 3, 7); // a = {0, 0, 7, 7, 1}
get(3); // a[3] == 7
// Pushes a pending assignment from node v down to both children.
// The value -1 in t[v] marks "no pending assignment"; after the push the
// node is reset to that marker.
void push(int v) {
    if (t[v] != -1) {
        t[2 * v + 1] = t[v];
        t[2 * v] = t[v];
        t[v] = -1;
    }
}
// Range assignment: sets every position of the inclusive range [l, r] to x.
//
// A node whose segment matches exactly just stores x; otherwise any pending
// assignment on the node is pushed to its children first so it is not lost,
// and the range is split between them. The trailing parameters are elided on
// the slide; the body uses v, tl and tr.
void let(int l, int r, int x, ...) {
    if (l > r) {
        return;
    }
    if (tl == l && tr == r) {
        t[v] = x;
    } else {
        push(v);
        int tm = tl + (tr - tl) / 2;
        let(l, min(r, tm), x, 2 * v, tl, tm);
        // Argument order is (l, r, x, ...); the right-hand call previously
        // passed x and r swapped.
        let(max(l, tm + 1), r, x, 2 * v + 1, tm + 1, tr);
    }
}
// Point query for the range-assignment tree: walks from node v down to the
// leaf for position i, pushing pending assignments along the way, and returns
// the leaf's value. The trailing parameters are elided on the slide; the body
// uses v, tl and tr.
int get(int i, ...) {
    // Each step of the original recursion is a tail call, so the descent is
    // written as a loop over (v, tl, tr).
    while (tl != tr) {
        push(v);
        int tm = tl + (tr - tl) / 2;
        if (i <= tm) {
            v = 2 * v;
            tr = tm;
        } else {
            v = 2 * v + 1;
            tl = tm + 1;
        }
    }
    return t[v];
}
API example
// API example: range assignment (let) combined with range sum (sum).
// Positions are 0-based and ranges are inclusive.
int a[] = {1, 2, 0, 4, 0};
build();
let(3, 4, 1); // a = {1, 2, 0, 1, 1}
let(2, 3, 7); // a = {1, 2, 7, 7, 1}
sum(1, 3); // 2 + 7 + 7 = 16
pii get(int v, int tl, int tr) {
return t[v].second == INF
? t[v].first
: t[v].second * (tr - tl + 1);
}
// Pushes the pending assignment of node v (t[v].second) down to both
// children and clears it on v. INF marks "no pending assignment".
// Only the .second fields are written; t[v].first is left as it was.
void push(int v) {
    if (t[v].second != INF) {
        t[2 * v + 1].second = t[v].second;
        t[2 * v].second = t[v].second;
        t[v].second = INF;
    }
}
void let(int l, int r, int x, ...) {
if (l > r) {
return;
}
if (tl == l && tr == r) {
t[v].second = x;
} else {
push(v);
int tm = tl + (tr - tl) / 2;
let(l, min(r, tm), x, 2 * v, tl, tm);
let(max(l, tm + 1), r, x, 2 * v + 1, tm + 1, tr);
t[v] = {get(2 * v, tl, tm)
+ get(2 * v + 1, tm + 1, tr), INF};
}
}
int sum(int l, int r, ...) {
if (l > r) {
return 0;
}
if (tl == l && tr == r) {
return get(v, tl, tr);
}
push(v);
int tm = tl + (tr - tl) / 2;
return sum(l, min(r, tm), 2 * v, tl, tm)
+ sum(max(l, tm + 1), r, 2 * v + 1, tm + 1, tr);
}
API example
// API example for the treap. Trailing comments show the logical key set.
// split(key, l, r) puts keys <= key into l and the rest into r.
t->insert(1); // t = {1}
...
t->insert(6); // t = {1, 2, 3, 4, 5, 6}
t->remove(3); // t = {1, 2, 4, 5, 6}
t->split(2, l, r); // l = {1, 2}, r = {4, 5, 6}
r->sum(); // 4 + 5 + 6 = 15
t = merge(l, r); // t = {1, 2, 4, 5, 6}
// Treap node. Keys (x) are ordered as in a binary search tree; priorities (y)
// are ordered as in a heap, with the larger y closer to the root (see merge).
// merge and split never modify existing nodes: they allocate new ones along
// the path they change and share the untouched subtrees.
struct Treap {
    int x;         // key
    int y;         // priority
    Treap *left;
    Treap *right;
    // Children default to nullptr so a single-node treap can be created with
    // just a key and a priority.
    Treap(int x, int y, Treap *left = nullptr, Treap *right = nullptr);
    static Treap *merge(Treap *l, Treap *r);
    // l and r are output parameters and therefore taken by reference.
    void split(int key, Treap *&l, Treap *&r);
    Treap *insert(int x);
    Treap *remove(int x);
};

// Joins two treaps and returns the root of the result. Either argument may be
// nullptr. The function never compares keys, so the caller must ensure every
// key in l precedes every key in r.
static Treap *merge(Treap *l, Treap *r) {
    if (l == nullptr) {
        return r;
    }
    if (r == nullptr) {
        return l;
    }
    if (l->y > r->y) {
        // l has the higher priority and becomes the root; r goes into its
        // right subtree.
        Treap *newRight = merge(l->right, r);
        return new Treap(l->x, l->y, l->left, newRight);
    } else {
        Treap *newLeft = merge(l, r->left);
        return new Treap(r->x, r->y, newLeft, r->right);
    }
}

// Splits this treap by key: on return l holds the keys <= key and r holds the
// keys > key (either may be nullptr).
//
// l and r are references; with by-value pointers the assignments below would
// only change local copies and the caller would never see the result.
void split(int key, Treap *&l, Treap *&r) {
    Treap *newTree = nullptr;
    if (x <= key) {
        // This node belongs to the left part; only its right subtree can
        // still contain keys > key.
        if (right == nullptr) {
            r = nullptr;
        } else {
            right->split(key, newTree, r);
        }
        l = new Treap(x, y, left, newTree);
    } else {
        if (left == nullptr) {
            l = nullptr;
        } else {
            left->split(key, l, newTree);
        }
        r = new Treap(x, y, newTree, right);
    }
}
// Returns a treap containing the keys of this one plus x. The new node gets
// a random priority.
//
// l and r start as nullptr so they are never read uninitialised, and the new
// node's children are passed explicitly as nullptr to match the four-argument
// constructor.
Treap *insert(int x) {
    Treap *l = nullptr;
    Treap *r = nullptr;
    split(x, l, r);  // expected to leave keys <= x in l and keys > x in r
    Treap *m = new Treap(x, rand(), nullptr, nullptr);
    return merge(merge(l, m), r);
}
// Returns a treap containing the keys of this one without x.
//
// The tree is cut twice: first into keys <= x (m) and keys > x (r), then m
// into keys <= x - 1 (l) and the remaining keys equal to x, which are left
// out of the result.
Treap *remove(int x) {
    Treap *l = nullptr;
    Treap *m = nullptr;
    Treap *r = nullptr;
    Treap *t = nullptr;  // receives the nodes with key x; not merged back
    split(x, m, r);
    // m is nullptr when no key is <= x; there is nothing to cut in that case.
    if (m != nullptr) {
        m->split(x - 1, l, t);
    }
    return merge(l, r);
}