가장 먼저 떠오른 건 우선순위 큐였지만 당연하게 시간 초과가 났습니다. 사실 제출할 때도 그렇게 큰 기대 없이 냈기 때문에 다른 로직을 생각해야 했고... 그러다가 합병 정렬 진행 과정이 불현듯 스쳐갔습니다. 시간 복잡도를 줄이려면 반드시 트리 구조여야 할 텐데, 합병정렬에서 분할 정복을 통해 모든 엘리먼트를 최소 단위까지 쪼갠 기억이 난 것이죠.
합병 정렬을 그림으로 그리고, 위 과정에서 정렬 대신 최솟값만 취하면 훨씬 빠르지 않을까? 하는 생각이었습니다.
쪼갠 후, 정렬 과정을 생략하고 최소 값만 모아봤습니다. 이렇게 보니 시간 초과를 극복할 수 있다는 예감이 강하게 듭니다. 이제 트리를 만들고, 범위 값을 찾는 과정을 생각해 보겠습니다.
배열이 있고, 트리의 노드는 포함한 범위 정보와 그 범위 안의 최솟값을 가지고 있습니다. 트리의 리프노드는 시작과 끝이 같은 1의 범위를 가지고 있고, 그 1의 범위를 인덱스로 사용하는 배열의 값을 가지고 있습니다. index 2-6까지의 범위 중 최솟값을 탐색해 보겠습니다.
왼쪽은 0-3의 범위, 오른쪽은 4-7의 범위를 가리키고 있으므로 양쪽 모두 탐색을 진행합니다.
0-3 위치의 왼쪽은 최대 1까지의 범위만 포함하고 있습니다. 탐색은 오른쪽으로 진행됩니다. 여기서 2-3 범위는 전체가 2-6에 포함되기 때문에 더 이상 진행되지 않습니다. 오른쪽은 이전 스텝과 마찬가지로 4-5, 6-7 양쪽으로 진행됩니다.
4-5 또한 2-6 범위에 완전히 포함되기 때문에 더이상 탐색을 진행하지 않습니다. 6-7에선 7이 범위 밖이므로, 6으로 한번 더 탐색을 진행합니다. 6은 리프노드이므로 여기서 모든 탐색이 종료됩니다.
이제 탐색이 종료됐으니, 노드가 표시하는 인덱스 정보를 해당하는 배열의 값으로 변환해 보겠습니다. 탐색이 진행된 시점의 값은 3, 1, 7이므로 세 값 중 가장 작은 값이 2-6 범위에서 가장 작은 값이 됩니다.
배열의 값과 비교했을 때 최솟값은 1로 일치합니다.
아래는 구현한 소스코드 입니다.
먼저 Java입니다.
package solution.hanghae.boj;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;
public class test10868 {
static int N, M;
static int[] numbers;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
N = Integer.parseInt(st.nextToken());
M = Integer.parseInt(st.nextToken());
numbers = new int[N];
for (int i = 0; i < N; i++) {
numbers[i] = Integer.parseInt(br.readLine());
}
Node node = new Node(0, N - 1);
StringBuilder sb = new StringBuilder();
for (int i = 0; i < M; i++) {
st = new StringTokenizer(br.readLine());
int a = Integer.parseInt(st.nextToken()) - 1;
int b = Integer.parseInt(st.nextToken()) - 1;
sb.append(node.search(a, b)).append('\n');
}
System.out.println(sb);
}
static class Node {
int start, end, min;
Node parent, left, right;
// 루트 노드용 생성자
Node(int start, int end) {
this.start = start;
this.end = end;
init();
}
// 배열 범위를 생성자 파라미터로 입력 받습니다.
Node(int start, int end, Node parent) {
this.start = start;
this.end = end;
this.parent = parent;
// 세팅 시작
init();
}
// 범위를 기준으로 좌/우로 나눠 노드를 생성해 left, right에 넣어줍니다.
// 범위가 최소 단위라면, 배열 값을 min으로 지정해줍니다.
private void init() {
if (start != end) {
int mid = (start + end) / 2;
left = new Node(start, mid, this);
right = new Node(mid + 1, end, this);
min = Math.min(left.min, right.min);
} else {
this.min = numbers[start];
}
}
// 범위를 기준으로 탐색합니다.
public int search(int s, int e) {
// 범위가 완전히 일치하면 최소값 return
if (s == this.start && e == this.end) {
return this.min;
}
int ret = 1_000_000_001;
// 시작점이 왼쪽 노드의 끝점보다 작거나 같다면 왼쪽 노드의 끝점/현재 탐색지점의 끝 점중 작은 값을 범위로 넘겨줍니다.
if (s <= left.end) {
ret = Math.min(left.search(s, Math.min(left.end, e) ), ret);
}
// 끝 점이 오른쪽 노드의 시작점보다 크거나 같다면 현재 시작점과 오른쪽 노드의 시작점 중 큰 값을 시작 지점으로 넘겨줍니다.
if (e >= right.start) {
ret = Math.min(right.search(Math.max(right.start, s), e), ret);
}
return ret;
}
}
}
위 코드를 바탕으로 python으로 옮겨 적은 코드도 첨부합니다. 사실 python은 이제 문법을 겨우 알아가는 단계여서 옮기는데 애를 좀 먹었습니다.
import sys
input = sys.stdin.readline
class SegmentTree:
def __init__(self, numbers):
self.numbers = numbers
self.root = self.Node(0, len(numbers) - 1, self.numbers)
class Node:
def __init__(self, start, end, numbers):
self.start = start
self.end = end
self.min = float("inf")
self.left = None
self.right = None
self.numbers = numbers
self.init()
def init(self):
if self.start != self.end:
mid = (self.start + self.end) // 2
self.left = SegmentTree.Node(self.start, mid, self.numbers)
self.right = SegmentTree.Node(mid + 1, self.end, self.numbers)
self.min = min(self.left.min, self.right.min)
else:
self.min = self.numbers[self.start]
def search(self, s, e):
if s == self.start and e == self.end:
return self.min
ret = 1_000_000_000
if s <= self.left.end:
ret = min(ret, self.left.search(s, min(e, self.left.end)))
if e >= self.right.start:
ret = min(ret, self.right.search(max(s, self.right.start), e))
return ret
n, m = list(map(int, input().strip().split()))
numbers = [int(input()) for _ in range(n)]
answer = []
segmentTree = SegmentTree(numbers)
for i in range(m):
a, b = list(map(int, input().strip().split()))
answer.append(segmentTree.root.search(a - 1, b - 1))
print("\n".join(str(x) for x in answer))
'Algorithm > 문제 풀이' 카테고리의 다른 글
BOJ 11726 - 2 x n 타일링 (0) | 2024.06.17 |
---|---|
BOJ 11834 - 홀짝 (0) | 2024.06.09 |
BOJ 3080 - 아름다운 이름 [Java / Python] (0) | 2024.06.02 |
BOJ 2568 - 전깃줄 2 (0) | 2024.05.29 |
BOJ 12015 - 가장 긴 증가하는 부분 수열 2 (0) | 2024.05.27 |