알고리즘 문제 풀이 일지

백준 25549: 트리의 MEX

여름하인 2024. 12. 27. 08:17

https://www.acmicpc.net/problem/25549

 

이 문제를 접하게 된 건 약간 특이하다. 이전에 틀렸던 문제들을 둘러보다가, 해당 문제의 알고리즘 태그들 중에 오랜만에 보는 태그가 있었고, 기억이 별로 나지 않아서, 복습 겸 이 태그가 달린 문제 중 하나를 골라서 풀었다.

 

분리 집합을 공부하면 각 분리집합의 크기를 구하는 방법을 배우게 된다.

int find(int a)
{
	if (a == P[a]) return a;
	return P[a] = find(P[a]);
}
void merge(int a, int b)
{
	a = find(a); b = find(b);
	if (a == b) return;
	P[b] = a;
	cnt[a] += cnt[b];
}

이 코드에서의 cnt가 바로 분리집합의 크기이다.

 

그런데, 필요한 것이 분리집합에 속한 유일한 원소들이라면 어떻게 해야할까?

예로, S1 = {1, 3, 6}, S2 = {2, 3, 8} 이라면, 필요한 것은 {1, 2, 3, 6, 8} 이다.

이러면 크기처럼 단순 숫자 덧셈으로는 표현할 수 없다.

따라서 중복을 허용하지 않는 자료구조를 활용해서 집합을 구현한 뒤, 합치는 작업을 해야한다.

 

이 때, 사용하는 기법이 Small to Large Trick이다.

Small To Large Trick은 분리 집합을 합칠 때 사용하는 트릭으로 큰 집합을 작은 집합으로 옮기는 것보다. 작은 집합을 큰 집합으로 옮기는게 시간이 덜 걸린다는 당연한 사실을 유창하게 말한 것이 Small To Large Trick이다.

 

중복을 허용하지 않는 자료구조는 C++의 Set을 사용한다고 할 때, 합치는 작업을 다음과 같이 구현을 할 수가 있다.

for (int i = 1; i <= N; i++)
{
	cin >> C[i];
	S[i].insert(C[i]);
}
...

int p1 = find(P[num]);
int p2 = find(num);
if (S[p2].size() > S[p1].size()) swap(p1, p2);
merge(p1, p2)
for (auto it = S[p2].begin(); it != S[p2].end(); it++)
	S[p].insert(*it);

이렇게하면 더욱 빠르게 집합을 합칠 수 있다.

 

그렇다면 이것을 사용해서 이 문제를 어떻게 풀 수 있을까?

필자의 사전엔 분리집합에 끊는다란 개념이 없다. 그렇다면, 리프 노드부터 루트 노드까지 올라가며 노드를 합치는 작업을 하고, 해당 집합 내에 없는 수 중 음수가 아닌, 가장 작은 수를 찾으면 된다.

다행히 set은 정렬된 채로 값을 저장한다. 따라서, 가장 첫번째 원소를 확인해서, 그것이 0인지 확인한다.

0이 아니라면, 해당 노드의 MEX는 0이다.

0이라면, 그 다음 원소를 확인한다. 만약 그게 (이전 원소 + 1) 보다 크다면, 해당 노드의 MEX는 (다음 노드 - 1)이다.

아니라면, 다음 원소를 차례차례 확인한다.

만약 모든 원소가 전부 (이전 원소 + 1)이라면, (끝 원소 + 1)이 MEX이다.

 

		if (*(S[s].begin()) != 0) ans[now] = 0;
		else
		{
			for (auto iter = S[s].begin(); iter != S[s].end(); iter++)
			{
				auto iter2 = iter;
				iter2++;
				if (iter2 == S[s].end() || (*iter) + 1 != *iter2)
				{
					ans[now] = (*iter) + 1;
                    			break;
				}
			}
		}

이 작업을 리프 노드부터 실행한다. 리프노드는 어떻게 구할 수 있을까?

필자는 여기서 위상정렬이 떠올랐다.

특정 노드를 가르키는 갯수를 저장한 뒤에, 0부터 차례차례 노드를 확인하면 된다. 확인한 뒤에 부모 노드를 가르키는 갯수를 빼고, 만약 그것이 0이면 확인할 노드에 추가하는 방식이다.

		queue<int> q;
		for (int i = 1; i <= N; i++)
			if (indegree[i] == 0)
				q.push(i);
        	....
    		if (P[now] >= 0)
		{
			indegree[P[now]]--;
			int p1 = find(P[now]);
			int p2 = find(now);
			if (S[p1].size() > S[p2].size()) swap(p1, p2);
			parent[p1] = p2;
			for (auto iter = S[p1].begin(); iter != S[p1].end(); iter++)
				S[p2].insert(*iter);
			if (indegree[P[now]] == 0)
				q.push(P[now]);
		}

그렇게 나온 전체코드는 다음과 같다.

#include <iostream>
#include <set>
#include <queue>

using namespace std;

const int MAX = 200001;
set<int> S[MAX];
int N, P[MAX], V, indegree[MAX], parent[MAX], ans[MAX];

int find(int a)
{
	if (parent[a] == a) return a;
	return parent[a] = find(parent[a]);
}

int main()
{
	ios_base::sync_with_stdio(false);
	cin.tie(NULL); cout.tie(NULL);
	cin >> N;
	for (int i = 1; i <= N; i++)
	{
		parent[i] = i;
		cin >> P[i];
		if (P[i] < 0) continue;
		indegree[P[i]]++;
	}
	for (int i = 1; i <= N; i++)
	{
		cin >> V;
		S[i].insert(V);
	}
	queue<int> q;
	for (int i = 1; i <= N; i++)
		if (indegree[i] == 0)
			q.push(i);

	while (!q.empty())
	{
		int now = q.front();
		q.pop(); int s = find(now);
		if (*(S[s].begin()) != 0) ans[now] = 0;
		else
		{
			for (auto iter = S[s].begin(); iter != S[s].end(); iter++)
			{
				auto iter2 = iter;
				iter2++;
				if (iter2 == S[s].end() || (*iter) + 1 != *iter2)
				{
					ans[now] = (*iter) + 1;
                    break;
				}
			}
		}
		if (P[now] >= 0)
		{
			indegree[P[now]]--;
			int p1 = find(P[now]);
			int p2 = find(now);
			if (S[p1].size() > S[p2].size()) swap(p1, p2);
			parent[p1] = p2;
			for (auto iter = S[p1].begin(); iter != S[p1].end(); iter++)
				S[p2].insert(*iter);
			if (indegree[P[now]] == 0)
				q.push(P[now]);
		}
	}
	for (int i = 1; i <= N; i++)
		cout << ans[i] << '\n';
}

알고리즘 복습하는 데 좋은 문제였다.