2 분 소요

문제

문제 바로가기

접근방법

N이 40일 때, 완전탐색을 하게 되면 시간초과가 발생한다. 따라서 수열을 반으로 나누어 접근해야한다. 반으로 나누게 되면 최대길이가 20이기 때문에 시간초과를 피할 수 있다.

다만 주의해야 할 점은 반으로 나눈 2개의 부분수열이 모두 공집합일 때를 고려해야하기 때문에 S가 0일 때는 1을 빼줘야 한다.

이분탐색과 투포인터 두 가지 방법으로 풀 수 있다.

풀이


이분탐색

import java.io.*;
import java.util.*;

public class Main {
    static int N, S;
    static int[] arr;
    static List<Integer> list1 = new ArrayList<>();
    static List<Integer> list2 = new ArrayList<>();
    static long cnt = 0;
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
        StringTokenizer st = new StringTokenizer(br.readLine());
        N = Integer.parseInt(st.nextToken());
        S = Integer.parseInt(st.nextToken());
        st = new StringTokenizer(br.readLine());

        arr = new int[N];
        for(int i = 0; i < N; i++) {
            arr[i] = Integer.parseInt(st.nextToken());
        }

        dfs(0, N/2, 0, list1);
        dfs(N/2, N, 0, list2);

        Collections.sort(list2);

        for(int i = 0; i < list1.size(); i++) {
            cnt += upper(list2, S - list1.get(i)) - lower(list2, S - list1.get(i));
        }

        if(S == 0) {
           cnt--;
        }

        bw.write(String.valueOf(cnt));
        bw.flush();
        bw.close();
        br.close();
    }

    public static int lower(List<Integer> list, int n) {
        int left = 0;
        int right = list.size();
        while(left < right) {
            int mid = (left + right) / 2;
            if(list.get(mid) >= n)
                right = mid;
            else
                left = mid + 1;
        }
        return right;
    }

    public static int upper(List<Integer> list, int n) {
        int left = 0;
        int right = list.size();
        while(left < right) {
            int mid = (left + right) / 2;
            if(list.get(mid) > n)
                right = mid;
            else
                left = mid + 1;
        }
        return right;
    }

    public static void dfs(int start, int end, int sum, List<Integer> list) {
        if(start == end) {
            list.add(sum);
            return;
        }
        dfs(start+1, end, sum + arr[start], list);
        dfs(start+1, end, sum, list);
    }
}



투포인터

import java.io.*;
import java.util.*;

public class Main {
    static int N, S;
    static int[] arr;
    static List<Integer> list1 = new ArrayList<>();
    static List<Integer> list2 = new ArrayList<>();
    static long cnt = 0;
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
        StringTokenizer st = new StringTokenizer(br.readLine());
        N = Integer.parseInt(st.nextToken());
        S = Integer.parseInt(st.nextToken());
        st = new StringTokenizer(br.readLine());

        arr = new int[N];
        for(int i = 0; i < N; i++) {
            arr[i] = Integer.parseInt(st.nextToken());
        }

        dfs(0, N/2, 0, list1);
        dfs(N/2, N, 0, list2);

        Collections.sort(list1);
        Collections.sort(list2);

        calc();

        if(S == 0) {
           cnt--;
        }

        bw.write(String.valueOf(cnt));
        bw.flush();
        bw.close();
        br.close();
    }

    public static void calc() {
        int leftPointer = 0;
        int rightPointer = list2.size() - 1;
        while(leftPointer < list1.size() && rightPointer >= 0) {
            int leftSum = list1.get(leftPointer);
            int rightSum = list2.get(rightPointer);
            long leftCnt = 0;
            long rightCnt = 0;
            if(leftSum + rightSum == S) {
                while(leftPointer < list1.size() && list1.get(leftPointer) == leftSum) {
                    leftCnt++;
                    leftPointer++;
                }
                while(rightPointer >= 0 && list2.get(rightPointer) == rightSum) {
                    rightCnt++;
                    rightPointer--;
                }
                cnt += leftCnt * rightCnt;
            }
            else if(leftSum + rightSum > S) {
                rightPointer--;
            }
            else {
                leftPointer++;
            }
        }
    }

    public static void dfs(int start, int end, int sum, List<Integer> list) {
        if(start == end) {
            list.add(sum);
            return;
        }
        dfs(start+1, end, sum + arr[start], list);
        dfs(start+1, end, sum, list);
    }
}

댓글남기기