https://www.acmicpc.net/problem/25402
DFS?
처음 내가 생각한 아이디어는 단순 DFS를 사용해서, S에 속한 정점들만, 탐색하여 사이즈를 구한 뒤에, Size * (Size - 1) / 2로 쌍의 갯수를 더하므로 답을 구하는 심플한 아이디어였다. 시간 초과가 의심되지만, 혹시나 하는 마음에 코드를 구현해보았다.
#include <iostream>
#include <vector>
using namespace std;
using ll = long long;
const int MAX = 250001;
vector<int> adj[MAX];
ll N, Q, K, A, B, S[MAX];
bool v[MAX], g[MAX];
ll dfs(int here)
{
ll ret = 1;
for (int there : adj[here])
{
if (v[there] || !g[there]) continue;
v[there] = true;
ret += dfs(there);
}
return ret;
}
int main()
{
ios_base::sync_with_stdio(false);
cin.tie(NULL); cout.tie(NULL);
cin >> N;
for(int i = 0 ;i < N - 1; i++)
{
cin >> A >> B;
adj[A].push_back(B);
adj[B].push_back(A);
}
cin >> Q;
while (Q--)
{
ll ans = 0;
cin >> K;
for (int i = 0; i < K; i++)
{
cin >> S[i];
g[S[i]] = true;
}
for (int i = 0; i < K; i++)
{
if (v[S[i]]) continue;
v[S[i]] = true;
ll tmp = dfs(S[i]);
ans += (tmp * (tmp - 1)) / 2;
}
for (int i = 0; i < K; i++)
{
v[S[i]] = false;
g[S[i]] = false;
}
cout << ans << '\n';
}
}
아니나 다를까, 부분적으로 정답이였다.(24점), 다른 문제였다면 그냥 "시간초과"다.
다른 아이디어로 집합 S에 포함된, 모든 간선을 구한 뒤에, 간선이 잇는 두 정점을 DisjoingSet으로 merge하는 방식을 떠올리긴 했으나, 간선의 계산방법을 전혀 떠오르지 못해서 이 아이디어를 포기했었다.
Disjoint Set!
결국 고민하느라 시간을 써버린 나는 구글링을 거쳤다.
아니나 다를까, 바로 위의 아이디어를 활용하는 것이 정답이였다.
내가 하나 놓치고 있던 사실이 있었는데, 입력으로 주어지는 그래프는 트리란 사실이였다. 즉, 루트 노드를 제외한, 하나의 정점은 반드시 부모 노드를 가진다. 이러한 원리로, 집합 S에 포함된 모든 간선을 구할 수 있다. 부모 노드와 자식 노드가 둘 다 집합 S에 포함 되어있으면, 그 간선은 집합 S에 포함된다!
#include <iostream>
#include <vector>
using namespace std;
using ll = long long;
const int MAX = 250001;
vector<int> adj[MAX];
ll N, Q, K, A, B, S[MAX], P[MAX];
ll cnt[MAX], parent[MAX];
bool v[MAX], v2[MAX];
void dfs(int here, int p)
{
parent[here] = p;
for (int there : adj[here])
{
if (p == there) continue;
dfs(there, here);
}
}
int find(int a)
{
if (a == P[a]) return a;
return P[a] = find(P[a]);
}
void merge(int a, int b)
{
a = find(a);
b = find(b);
if (a > b) swap(a, b);
P[b] = a;
cnt[a] += cnt[b];
}
int main()
{
ios_base::sync_with_stdio(false);
cin.tie(NULL); cout.tie(NULL);
cin >> N;
for(int i = 0 ;i < N - 1; i++)
{
cin >> A >> B;
adj[A].push_back(B);
adj[B].push_back(A);
}
dfs(1, 0);
cin >> Q;
while (Q--)
{
ll ans = 0;
cin >> K;
for (int i = 0; i < K; i++)
{
cin >> S[i];
v[S[i]] = true;
P[S[i]] = S[i];
cnt[S[i]] = 1;
}
for (int i = 0; i < K; i++)
if (v[parent[S[i]]])
merge(S[i], parent[S[i]]);
for (int i = 0; i < K; i++)
{
int idx = find(S[i]);
if (!v2[idx])
{
v2[idx] = true;
ans += (cnt[idx] * (cnt[idx] - 1)) / 2;
}
}
for (int i = 0; i < K; i++)
{
v[S[i]] = false;
P[S[i]] = S[i];
cnt[S[i]] = 0;
v2[S[i]] = false;
}
cout << ans << '\n';
}
}
중간에, int형과 long long 타입 변환 문제로 한 번 틀렸다.
트리의 특징을 생각해보게 되는 다소 독특한 문제였다.
'알고리즘 문제 풀이 일지' 카테고리의 다른 글
백준 2283: 구간 자르기 (0) | 2024.11.13 |
---|---|
백준 14427: 수열과 쿼리 15 (0) | 2024.11.09 |
백준 20952: 게임 개발자 승희 (0) | 2024.10.31 |
백준 10423: 전기가 부족해 (0) | 2024.10.27 |
백준 13334: 철도 (0) | 2024.10.22 |