백준 2957: 이진 탐색 트리
https://www.acmicpc.net/problem/2957
이 문제를 처음 보며 생각해낸건 "각 노드의 자식 노드의 갯수를 구해서 더하면 어떨까?" 란 아이디어였다. 하지만, N이 300000이라서 최악의 경우 O(N ^ 2)란 시간복잡도를 지니기에, 초기의 아이디어를 바탕으로 이것을 어떻게 더 효율적으로 문제를 해결할 지 고민을 했었다.
문득, 나는 이 문제를 어디서 본듯한 느낌이 들었다.
그러자 하나 생각난 개념이 있었는데, 이진 탐색 트리는 배열로 나타낼 수 있다. 그림으로 나타내면 다음과 같다.
일단 서로 다른 값을 넣을 것이고, 늘 정렬되어있고, 랜덤 값에 접근을 해야하니, 말이 배열이지, 실 구현은 Set 사용이 바람직할것이다.
여기서, 이것을 활용해서 값을 어떻게 구현할 것이냐? 이다. 여기서 생각의 전환이 필요했다. 앞서 말했듯 나는 값을 넣었을 때, 각 노드의 자식노드 수의 합이라 생각했었다. 하지만 이래선 시간초과가 되어버리기에, 값을 넣을 때마다 기존의 값과 합쳐서 정답을 구해야한다란 생각에 이르렀다.
그러면 값을 넣을 때마다 얼마나 더해야할까? 그것은 값을 추가할 때마다 거친 노드의 수이다. 그렇다면? 추가된 노드의 이진탐색트리의 높이가 정답일것이다. 높이는 어떻게 구해야 할까? 배열의 근처 노드의 높이를 통해 구할 수 있다. 즉, (근처 노드 중 높이가 가장 큰 높이) + 1을 기존의 값을 더하면 정답이다!
다행히 값은 [1, N]이 중복 없이 나와서, 배열을 사용하면 높이 값을 쉽게 구할 수 있을 것이다. 그리고, 근처 노드는 이진 탐색을 활용한 set의 lower_bound를 사용하면 쉽게 구할 수 있다. 그렇게 구현한 코드는 다음과 같다.
#include <cmath>
#include <iostream>
#include <set>
using namespace std;
int H[300001], N, A;
set<int> s;
long long ans = 0;
int main()
{
ios_base::sync_with_stdio(false);
cin.tie(NULL); cout.tie(NULL);
cin >> N;
for (int i = 0; i < N; i++)
{
cin >> A;
if (i == 0)
{
H[A] = 0;
s.insert(A);
cout << 0 << '\n';
}
else
{
auto l = s.lower_bound(A);
if (l == s.end()) H[A] = H[*(--l)] + 1;
else if (l == s.begin()) H[A] = H[*l] + 1;
else
{
H[A] = H[*l] + 1;
H[A] = max(H[A], H[*(--l)] + 1);
}
ans += H[A];
s.insert(A);
cout << ans << '\n';
}
}
}
처음엔 근처 노드를 구하는데, 다소 애먹었는데, 내가 lower_bound와 upper_bound의 기능을 착각해서, 둘 다 사용하다가 예제 3에서 막히고, 기능을 다시 검색해서, 제대로 구현하였다.
내가 사용할 함수는 좀 더 자세히 알아보고 사용하도록 하자.
추신)
이 문제를 해결한 이후로 내가 이 문제와 비슷한 문제를 구글링으로 해결했던 것 같아서 찾아본 결과
https://www.acmicpc.net/problem/1539
9개월 전에 해결했던 문제인 것을 발견했다. 이 문제는 위 문제보다 좀더 직설적으로 모든 노드의 높이의 합을 구하라고 한다. 역시 PS는 문제를 많이 해결하고 보는게 많는것 같다..
#include <iostream>
#include <algorithm>
#include <vector>
#include <set>
using namespace std;
set<int> T;
int P, N, H[250001];
long long ans = 0;
int main()
{
ios_base::sync_with_stdio(false);
cin.tie(NULL); cout.tie(NULL);
cin >> N;
for (int i = 0; i < N; i++)
{
cin >> P;
if (i == 0)
{
H[P] = 1;
T.insert(P);
}
else
{
auto iter = T.lower_bound(P);
if (iter == T.begin())
H[P] = H[*T.begin()] + 1;
else if (iter == T.end())
{
--iter;
H[P] = H[*iter] + 1;
}
else
{
H[P] = H[*iter] + 1; --iter;
H[P] = max(H[P], H[*iter] + 1);
}
T.insert(iter, P);
}
}
for (int i = 0; i < N; i++) ans += H[i];
cout << ans << '\n';
}