cubelover의 블로그

Splay Tree

자료구조2016. 9. 4. 00:39

Splay Tree는 Binary Search Tree의 한 종류이다.


삽입, 삭제, 검색 등의 쿼리를 amortized O(log n)에 처리 가능하며 Splay 연산을 이용해서 구간에 대한 쿼리가 자유롭고 AVL Tree나 Red-Black Tree와 같은 다른 Binary Search Tree보다 구현이 단순한 편이기 때문에 알아두면 좋다.


~


Splay Tree는 쿼리로 들어온 노드에 대해 splay 연산을 행해서 amortized O(log n) 시간에 동작하는 자료구조이다.


Splay 연산은 임의의 노드 x를 루트로 만드는 연산으로, 아래와 같은 과정을 통해 이루어진다. 여기서 Rotate(x)는 다음처럼 x를 x의 부모 위치로 올리는 연산이다.



1. x가 루트이면, 루트를 만드는 데 성공했으므로 종료한다.

2. x의 부모 p가 루트이면, Rotate(x)를 행하고 종료한다. (Zig Step)

3. x의 조부모를 g라고 하면, 다음 두 가지 경우가 있다.

3-1. g→p의 방향과 p→x의 방향이 같은 경우, Rotate(p) 이후 Rotate(x)를 행한다. (Zig-Zig Step)

3-2. g→p의 방향과 p→x의 방향이 다른 경우, Rotate(x)를 두 번 행한다. (Zig-Zag Step)

4. 1로 돌아가서 루트가 될 때까지 반복한다.


Zig-Zig Step과 Zig-Zag Step이 어떻게 동작하는지는 다음 그림을 보면 알 수 있다.


Zig-Zig Step



Zig-Zag Step



어떻게 이런 단순한(?) 과정을 통해 amortized O(log n)의 시간복잡도가 나오는지는 여기에 설명되어 있다.


~


노드 구조체


먼저 가장 기본이 되는 노드 구조체를 만든다.


struct node {
	node *l, *r, *p;
} *tree;


Splay Tree를 만들기 위해서는 부모를 가리키는 포인터, 자식들을 가리키는 포인터 총 세 개만 있으면 된다.


삽입, 삭제, 검색 쿼리를 위해서는 key가 필요하지만 그 부분은 아래에서 다시 다룬다.


~


회전


다음으로 회전을 만든다.


void Rotate(node *x) {
	node *p = x->p;
    node *b;
	if (x == p->l) {
		p->l = b = x->r;
		x->r = p;
	} else {
		p->r = b = x->l;
		x->l = p;
	}
	x->p = p->p;
	p->p = x;
    if (b) b->p = p;
	(x->p ? p == x->p->l ? x->p->l : x->p->r : tree) = x;
}


~


Splay


이제 Splay 연산을 만든다.


void Splay(node *x) {
    while (x->p) {
        node *p = x->p;
        node *g = p->p;
        if (g) Rotate((x == p->l) == (p == g->l) ? p : x);
        Rotate(x);
    }
}


~


삽입, 삭제, 검색


이로써 Splay Tree를 완성했다(?) 임의의 노드에 대해 Splay 연산을 행하면 해당 노드가 루트가 되고, 이 때 inorder 순서가 유지된다. 이제 남은 것은 key를 추가하는 것이다.


먼저 노드 구조체에 key라는 변수를 추가한다. 다음처럼 될 것이다.


struct node {
	node *l, *r, *p;
	int key;
} *tree;


이제 가장 단순한 Insert 함수를 만들자. 다음처럼 될 것이다.


void Insert(int key) {
	node *p = tree, **pp;
	if (!p) {
		node *x = new node;
		tree = x;
		x->l = x->r = x->p = NULL;
		x->key = key;
		return;
	}
	while (1) {
		if (key == p->key) return;
		if (key < p->key) {
			if (!p->l) {
				pp = &p->l;
				break;
			}
			p = p->l;
		} else {
			if (!p->r) {
				pp = &p->r;
				break;
			}
			p = p->r;
		}
	}
	node *x = new node;
	*pp = x;
	x->l = x->r = NULL;
	x->p = p;
	x->key = key;
	Splay(x);
}


마지막에 Splay 연산이 들어갔음에 유의하라. Splay Tree는 이처럼 삽입, 삭제, 검색한 노드에 대해 Splay 연산을 행함으로써 amortized O(log n) 시간복잡도로 만든다.


다음으로 Find 함수를 만들자. Find 함수를 호출한 뒤 루트가 해당 key를 가진 노드가 되므로 다른 부가적인 연산을 행하기 쉽다.


bool Find(int key) {
	node *p = tree;
	if (!p) return false;
	while (p) {
		if (key == p->key) break;
		if (key < p->key) {
			if (!p->l) break;
			p = p->l;
		} else {
			if (!p->r) break;
			p = p->r;
		}
	}
	Splay(p);
	return key == p->key;
}


마지막으로 Delete 함수이다. 삭제를 위해서 해당하는 노드를 Splay한 뒤, 자식이 0개 또는 1개인 경우 그냥 삭제하고 2개인 경우 두 서브트리를 붙여준다.


void Delete(int key) {
	if (!Find(key)) return;
	node *p = tree;
	if (p->l) {
		if (p->r) {
			tree = p->l;
			tree->p = NULL;
			node *x = tree;
			while (x->r) x = x->r;
			x->r = p->r;
			p->r->p = x;
			Splay(x);
			delete p;
			return;
		}
		tree = p->l;
		tree->p = NULL;
		delete p;
		return;
	}
	if (p->r) {
		tree = p->r;
		tree->p = NULL;
		delete p;
		return;
	}
	delete p;
	tree = NULL;
}


~


K번째 원소 찾기


사실 삽입, 삭제, 검색 연산은 set과 map에서도 지원하는 기능이기 때문에 이것을 위해 Binary Search Tree를 구현할 필요는 없다. 다만 Binary Search Tree를 직접 구현하면 K번째 원소를 O(log n)에 찾는 것이 가능하다는 것이 큰 장점이다.


K번째 원소를 찾기 위해서 노드 구조체에 서브트리에 있는 노드 개수를 저장하는 변수를 만들 필요가 있다. key가 없어도 동작하므로 여기서는 제외하고 설명한다.


struct node {
	node *l, *r, *p;
	int cnt;
} *tree;


cnt 변수의 갱신을 위해 Update 함수를 만든다. 이렇게 따로 분리해두면 K번째 원소 찾기 외에 구간 합 구하기 등 다른 기능들을 추가하기 편하다.


void Update(node *x) {
	x->cnt = 1;
	if (x->l) x->cnt += x->l->cnt;
	if (x->r) x->cnt += x->r->cnt;
}


그리고 Rotate 함수의 맨 마지막 부분에 다음을 추가하자.


void Rotate(node *x) {
	// ...
	Update(p);
	Update(x);
}


이로써 K번째 원소를 찾을 준비가 완료되었다. 이제 K번째 원소를 찾는 함수를 만들자.


void Find_Kth(int k) {
	node *x = tree;
	while (1) {
		while (x->l && x->l->cnt > k) x = x->l;
		if (x->l) k -= x->l->cnt;
		if (!k--) break;
		x = x->r;
	}
	Splay(x);
}


이 함수를 호출하면 K번째 원소가 루트가 된다. 여기서 K는 0-based이다.


~


구간 합 구하기


이제 본격적으로 Splay Tree를 문제 풀이에 사용해보자. 먼저 가장 간단한 구간 합 구하기부터 한다.


먼저 노드 구조체와 Update 함수를 고쳐서 합을 구할 수 있도록 한다. 또한 K번째를 찾는 기능도 필요하므로 이 또한 구현되어 있어야 한다.


struct node {
	node *l, *r, *p;
	int cnt;
	int sum, value;
} *tree;

void Update(node *x) {
	x->cnt = 1;
	x->sum = x->value;
	if (x->l) {
		x->cnt += x->l->cnt;
		x->sum += x->l->sum;
	}
	if (x->r) {
		x->cnt += x->r->cnt;
		x->sum += x->r->sum;
	}
}


처음에 노드를 구간 길이만큼 만들고 적당히 연결해 주면 초기화가 끝난다. 다음처럼 구현할 수 있다.


void Initialize(int n) {
	node *x;
	int i;
	tree = x = new node;
	x->l = x->r = x->p = NULL;
	x->cnt = n;
	x->sum = x->value = 0;
	for (i = 1; i < n; i++) {
		x->r = new node;
		x->r->p = x;
		x = x->r;
		x->l = x->r = NULL;
		x->cnt = n - i;
		x->sum = x->value = 0;
	}
}


먼저 i번째 원소에 값 z를 더하는 것을 구현하자. i번째 원소를 찾고 sum과 value에 z를 더해주면 된다.


void Add(int i, int z) {
	Find_Kth(i);
	tree->sum += z;
	tree->value += z;
}


이제 구간에 대한 합을 구해보자. 먼저 L부터 R까지의 구간(inclusive)을 한 노드로 모아주기 위해 다음 과정을 거친다.



즉, L-1번째 원소에 Splay 연산을 행한 뒤 오른쪽 서브트리에서 R-L+1번째 원소(전체에서 R+1)번째 원소에 Splay 연산을 행하면 L부터 R까지의 구간이 한 노드에 모이게 된다.


void Interval(int l, int r) {
	Find_Kth(l - 1);
	node *x = tree;
	tree = x->r;
	tree->p = NULL;
	Find_Kth(r - l + 1);
	x->r = tree;
	tree->p = x;
	tree = x;
}


Splay를 하는 과정에서 Rotate 함수를 호출하고, Rotate 함수에서 Update 함수를 호출하므로 해당하는 노드의 sum 값이 곧 구간의 합이 된다.


L이 구간 왼쪽 끝이거나 R이 구간 오른쪽 끝인 경우 예외처리를 해 줘야 하는데, 이는 더미노드를 2개 더 만듦으로써 쉽게 해결할 수 있다.


int Sum(int l, int r) {
	Interval(l, r);
	return tree->r->l->sum;
}


~


Lazy Propagation


Splay Tree에서 구간을 한 노드로 만들 수 있으므로 Lazy Propagation 또한 가능하다. 먼저 노드 구조체에 해당 구간에 더해진 값을 저장할 변수를 하나 추가하자.


struct node {
	node *l, *r, *p;
	int cnt;
	int sum, value, lazy;
} *tree;


그리고 lazy를 뿌려주는 함수를 작성한다.


void Lazy(node *x) {
	x->value += x->lazy;
	if (x->l) {
		x->l->lazy += x->lazy;
		x->l->sum += x->l->cnt * x->lazy;
	}
	if (x->r) {
		x->r->lazy += x->lazy;
		x->r->sum += x->r->cnt * x->lazy;
	}
	x->lazy = 0;
}


Lazy는 자식으로 내려갈 때마다 호출을 해 주어야 한다. 이 경우에는 Find_Kth 함수에만 추가해주면 된다.


void Find_Kth(int k) {
	node *x = tree;
	Lazy(x);
	while (1) {
		while (x->l && x->l->cnt > k) {
			x = x->l;
			Lazy(x);
		}
		if (x->l) k -= x->l->cnt;
		if (!k--) break;
		x = x->r;
		Lazy(x);
	}
	Splay(x);
}


이제 L부터 R까지의 구간에 값 z를 더해주는 함수를 만들자. L부터 R까지의 구간을 한 노드로 모아주고 값을 더해주면 된다.


void Add(int l, int r, int z) {
	Intrerval(l, r);
	node *x = tree->r->l;
	x->sum += x->cnt * z;
	x->lazy += z;
}


~


구간 뒤집기


Splay Tree를 이용하면 구간 뒤집기를 할 수 있다. 또한 구간 뒤집기를 하면서 구간 쿼리도 가능하다(!)


L부터 R까지의 구간을 뒤집는다는 것은 다음처럼 바뀐다는 뜻이다.


..., L-1, L, L+1, ..., R-1, R, R+1, ... → ..., L-1, R, R-1, ..., L+1, L, R+1, ...


L부터 R까지의 구간을 한 노드로 모으면 L과 R 사이에 있는 어떤 원소 K가 해당 구간을 나타낼 것이다.


..., L-1, L, L+1, ..., K-1, K, K+1, ..., R-1, R, R+1, ... → ..., L-1, R, R-1, ..., K+1, K, K-1, ..., L+1, L, R+1, ...


이를 주의깊게 보면, [L, K-1](K의 왼쪽 서브트리)과 [K+1, R](K의 오른쪽 서브트리)을 바꾼 뒤 각 구간을 다시 뒤집어 준 것과 같다는 것을 알 수 있다. 따라서 Lazy Propagation으로 이를 처리할 수 있다.


먼저 해당 구간이 뒤집혔는지 아닌지를 나타내는 변수를 만든다.


struct node {
	node *l, *r, *p;
	bool inv;
} *tree;


그리고 Lazy 함수를 다음처럼 만든다.


void Lazy(node *x) {
	if (!x->inv) return;
	node *t = x->l;
	x->l = x->r;
	x->r = t;
	x->inv = false;
	if (x->l) x->l->inv = !x->l->inv;
	if (x->r) x->r->inv = !x->r->inv;
}


마지막으로 구간을 뒤집어주는 함수를 만든다.


void Reverse(int l, int r) {
	Interval(l, r);
	node *x = tree->r->l;
	x->inv = !x->inv;
}


~


Splay Tree 연습문제


https://www.acmicpc.net/problem/3421

https://www.acmicpc.net/problem/3444

https://www.acmicpc.net/problem/13159

https://www.acmicpc.net/problem/13543

'자료구조' 카테고리의 다른 글

Persistent Segment Tree  (1) 2016.11.03