백준 25549: 트리의 MEX
https://www.acmicpc.net/problem/25549
이 문제를 접하게 된 건 약간 특이하다. 이전에 틀렸던 문제들을 둘러보다가, 해당 문제의 알고리즘 태그들 중에 오랜만에 보는 태그가 있었고, 기억이 별로 나지 않아서, 복습 겸 이 태그가 달린 문제 중 하나를 골라서 풀었다.
분리 집합을 공부하면 각 분리집합의 크기를 구하는 방법을 배우게 된다.
int find(int a)
{
if (a == P[a]) return a;
return P[a] = find(P[a]);
}
void merge(int a, int b)
{
a = find(a); b = find(b);
if (a == b) return;
P[b] = a;
cnt[a] += cnt[b];
}
이 코드에서의 cnt가 바로 분리집합의 크기이다.
그런데, 필요한 것이 분리집합에 속한 유일한 원소들이라면 어떻게 해야할까?
예로, S1 = {1, 3, 6}, S2 = {2, 3, 8} 이라면, 필요한 것은 {1, 2, 3, 6, 8} 이다.
이러면 크기처럼 단순 숫자 덧셈으로는 표현할 수 없다.
따라서 중복을 허용하지 않는 자료구조를 활용해서 집합을 구현한 뒤, 합치는 작업을 해야한다.
이 때, 사용하는 기법이 Small to Large Trick이다.
Small To Large Trick은 분리 집합을 합칠 때 사용하는 트릭으로 큰 집합을 작은 집합으로 옮기는 것보다. 작은 집합을 큰 집합으로 옮기는게 시간이 덜 걸린다는 당연한 사실을 유창하게 말한 것이 Small To Large Trick이다.
중복을 허용하지 않는 자료구조는 C++의 Set을 사용한다고 할 때, 합치는 작업을 다음과 같이 구현을 할 수가 있다.
for (int i = 1; i <= N; i++)
{
cin >> C[i];
S[i].insert(C[i]);
}
...
int p1 = find(P[num]);
int p2 = find(num);
if (S[p2].size() > S[p1].size()) swap(p1, p2);
merge(p1, p2)
for (auto it = S[p2].begin(); it != S[p2].end(); it++)
S[p].insert(*it);
이렇게하면 더욱 빠르게 집합을 합칠 수 있다.
그렇다면 이것을 사용해서 이 문제를 어떻게 풀 수 있을까?
필자의 사전엔 분리집합에 끊는다란 개념이 없다. 그렇다면, 리프 노드부터 루트 노드까지 올라가며 노드를 합치는 작업을 하고, 해당 집합 내에 없는 수 중 음수가 아닌, 가장 작은 수를 찾으면 된다.
다행히 set은 정렬된 채로 값을 저장한다. 따라서, 가장 첫번째 원소를 확인해서, 그것이 0인지 확인한다.
0이 아니라면, 해당 노드의 MEX는 0이다.
0이라면, 그 다음 원소를 확인한다. 만약 그게 (이전 원소 + 1) 보다 크다면, 해당 노드의 MEX는 (다음 노드 - 1)이다.
아니라면, 다음 원소를 차례차례 확인한다.
만약 모든 원소가 전부 (이전 원소 + 1)이라면, (끝 원소 + 1)이 MEX이다.
if (*(S[s].begin()) != 0) ans[now] = 0;
else
{
for (auto iter = S[s].begin(); iter != S[s].end(); iter++)
{
auto iter2 = iter;
iter2++;
if (iter2 == S[s].end() || (*iter) + 1 != *iter2)
{
ans[now] = (*iter) + 1;
break;
}
}
}
이 작업을 리프 노드부터 실행한다. 리프노드는 어떻게 구할 수 있을까?
필자는 여기서 위상정렬이 떠올랐다.
특정 노드를 가르키는 갯수를 저장한 뒤에, 0부터 차례차례 노드를 확인하면 된다. 확인한 뒤에 부모 노드를 가르키는 갯수를 빼고, 만약 그것이 0이면 확인할 노드에 추가하는 방식이다.
queue<int> q;
for (int i = 1; i <= N; i++)
if (indegree[i] == 0)
q.push(i);
....
if (P[now] >= 0)
{
indegree[P[now]]--;
int p1 = find(P[now]);
int p2 = find(now);
if (S[p1].size() > S[p2].size()) swap(p1, p2);
parent[p1] = p2;
for (auto iter = S[p1].begin(); iter != S[p1].end(); iter++)
S[p2].insert(*iter);
if (indegree[P[now]] == 0)
q.push(P[now]);
}
그렇게 나온 전체코드는 다음과 같다.
#include <iostream>
#include <set>
#include <queue>
using namespace std;
const int MAX = 200001;
set<int> S[MAX];
int N, P[MAX], V, indegree[MAX], parent[MAX], ans[MAX];
int find(int a)
{
if (parent[a] == a) return a;
return parent[a] = find(parent[a]);
}
int main()
{
ios_base::sync_with_stdio(false);
cin.tie(NULL); cout.tie(NULL);
cin >> N;
for (int i = 1; i <= N; i++)
{
parent[i] = i;
cin >> P[i];
if (P[i] < 0) continue;
indegree[P[i]]++;
}
for (int i = 1; i <= N; i++)
{
cin >> V;
S[i].insert(V);
}
queue<int> q;
for (int i = 1; i <= N; i++)
if (indegree[i] == 0)
q.push(i);
while (!q.empty())
{
int now = q.front();
q.pop(); int s = find(now);
if (*(S[s].begin()) != 0) ans[now] = 0;
else
{
for (auto iter = S[s].begin(); iter != S[s].end(); iter++)
{
auto iter2 = iter;
iter2++;
if (iter2 == S[s].end() || (*iter) + 1 != *iter2)
{
ans[now] = (*iter) + 1;
break;
}
}
}
if (P[now] >= 0)
{
indegree[P[now]]--;
int p1 = find(P[now]);
int p2 = find(now);
if (S[p1].size() > S[p2].size()) swap(p1, p2);
parent[p1] = p2;
for (auto iter = S[p1].begin(); iter != S[p1].end(); iter++)
S[p2].insert(*iter);
if (indegree[P[now]] == 0)
q.push(P[now]);
}
}
for (int i = 1; i <= N; i++)
cout << ans[i] << '\n';
}
알고리즘 복습하는 데 좋은 문제였다.