알고리즘 공부/문제 풀이

[JAVA] 백준 2213 트리의 독립집합

valid_ming 2021. 11. 9. 21:32

https://www.acmicpc.net/problem/2213

 

2213번: 트리의 독립집합

첫째 줄에 트리의 정점의 수 n이 주어진다. n은 10,000이하인 양의 정수이다. 1부터 n사이의 정수가 트리의 정점이라고 가정한다. 둘째 줄에는 n개의 정수 w1, w2, ..., wn이 주어지는데, wi는 정점 i의

www.acmicpc.net

- 어려웠다.. 결국 블로그의 도움을 받아 해결하였다

- 트리의 구조 -> 어느 정점이든 루트가 될 수 있다 -> 편의상 1번 정점을 루트로 생각하고 1번 노드부터 dfs 탐색을 진행한다

- 현재 접근한 노드를 포함할 때/포함하지 않을 때의 최대 독립 집합의 가중치를 구해본다 => memo[N][2]

- 현재 노드를 집합에 포함한다면, 인접한 다음 노드는 포함되지 않아야 한다.

- 현재 노드를 집합에 포함하지 않는다면, 인접한 다음 노드를 포함되도 되고, 포함하지 않아도 된다 => 둘 중 더 큰 값을 선택한다

- 이렇게 dfs를 통해 memo 배열을 모두 채우고, memo[1][0]과 memo[1][1] 중의 큰값을 선택한다

- 이제 다시 루트 노드에서 최대 독립 집합을 구성하는 노드들을 찾을 trace 메서드를 호출한다

- trace(int cur, int inc): 현재 접근한 노드 cur, 접근한 노드를 포함하는지(1), 포함하지 않는지(0)

- 포함한다면 정답 배열에 넣고, 인접한 노드들은 포함하지 않는 버전으로 호출 => trace(next, 0)

- 포함하지 않는다면, 이전에 구했던 memo를 이용하여 더 큰 값을 가진 경우로 호출 => trace(next, ?)

 

import java.util.*;
import java.io.*;

public class Main {
    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        N = Integer.parseInt(br.readLine());
        W = new int[N+1];
        visit = new boolean[N+1];
        memo = new int[N+1][2]; //i 번째 노드를 선택한 경우와 선택하지 않은 경우
        tree = new ArrayList[N+1];

        StringTokenizer st = new StringTokenizer(br.readLine());
        for(int i=1;i<=N;i++){
            tree[i] = new ArrayList<>();
            W[i] = Integer.parseInt(st.nextToken());
        }

        for(int i=1;i<N;i++){
            st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());
            tree[a].add(b);
            tree[b].add(a);
        }

        dfs(1);

        visit = new boolean[N+1];
        if(memo[1][0] > memo[1][1]){
            System.out.println(memo[1][0]);
            trace(1, 0);
        } else{
            System.out.println(memo[1][1]);
            trace(1, 1);
        }

        Collections.sort(res);
        for(int num : res ) {
            System.out.print(num+" ");
        }

    }

    static int N;
    static int[] W;
    static boolean[] visit;
    static int[][] memo;
    static ArrayList<Integer>[] tree;
    static ArrayList<Integer> res = new ArrayList<>();

    public static void dfs(int cur){
        int child_size = tree[cur].size();

        memo[cur][0] = 0;
        memo[cur][1] = W[cur];
        visit[cur] = true;

        for(int i=0;i<child_size;i++){
            int next = tree[cur].get(i);
            if(!visit[next]){
                dfs(next);

                memo[cur][0] += Math.max(memo[next][0], memo[next][1]);
                memo[cur][1] += memo[next][0];
            }
        }
    }

    public static void trace(int cur, int inc){
        visit[cur] = true;

        if(inc==1){
            res.add(cur);
            for(int next:tree[cur]){
                if(!visit[next]) trace(next, 0);
            }
        } else {
            for(int next:tree[cur]){
                if(!visit[next]){
                    if(memo[next][1] > memo[next][0]){
                        trace(next, 1);
                    } else trace(next, 0);
                }
            }
        }
    }
}