计算机算法设计与分析 3-23 K中值问题

解题方法

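显然这是一道动态规划（DP）题，只是状态转移方程不太容易想到。

设村庄坐标已按升序排列，定义两个状态：
- $dp_2[i][j]$：在村庄 $1..j$ 中恰好建 $i$ 个服务站、且第 $i$ 个服务站建在村庄 $j$ 时，村庄 $1..j$ 的最小总费用；
- $dp_1[i][j]$：在村庄 $1..j$ 中恰好建 $i$ 个服务站、且最后一个服务站之后的村庄都去该站时，村庄 $1..j$ 的最小总费用。

初值 $dp_2[1][j] = c_j + \sum_{t=1}^{j-1} w_t\,(x_j - x_t)$，转移为
$$dp_2[i][j] = c_j + \min_{i-1 \le k < j}\Big(dp_1[i-1][k] + \sum_{t=k+1}^{j-1} w_t\,(x_j - x_t)\Big),$$
$$dp_1[i][j] = \min_{i \le k \le j}\Big(dp_2[i][k] + \sum_{t=k+1}^{j} w_t\,(x_t - x_k)\Big),$$
答案为 $\min_{1 \le i \le K} dp_1[i][n]$。两个求和式都可以用前缀和在 $O(1)$ 内算出。

详细的推导明天再写（今天太累了），先贴代码。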

/*
 * Copyright (c) 2019 Ng Kimbing, HNU, All rights reserved. May not be used, modified, or copied without permission.
 * @Author: Ng Kimbing, HNU.
 * @LastModified:2019-05-14 T 21:15:57.514 +08:00
 */
package ACMProblems.DynamicProgramming;

import MyUtil.Matrix;

import java.io.FileInputStream;

import static ACMProblems.ACMIO.*;

/**
 * Solves the K-median (service station placement) problem on a line with dynamic programming.
 *
 * <p>Input is read from {@code serviceData.txt}: the number of villages {@code n}, the maximum
 * number of stations {@code K}, then for every village {@code i} (1-indexed) its coordinate
 * {@code x[i]}, its weight {@code w[i]} and the cost {@code c[i]} of building a station there.
 * The program prints the minimal total cost: building costs plus, for every village, its weight
 * times the distance to the station it uses. The minimum is taken over 1..K stations.
 *
 * <p>NOTE(review): the prefix-sum identities in {@code getStationLeft} / {@code getStationRight}
 * only hold when {@code x[1..n]} is sorted in non-decreasing order. The input is assumed to
 * satisfy this; confirm against the problem statement.
 */
public class ServicePointInLine {
    /** Sentinel used as "infinity" for unreachable DP states and as the start of running minima. */
    private static final int INF = 0x3f3f3f3f;

    /** Number of villages. */
    private static int n;
    /** Maximum number of service stations that may be built (K). */
    private static int maxStationNum;
    /** Village coordinates, 1-indexed. */
    private static int[] x;
    /** Village weights, 1-indexed. */
    private static int[] w;
    /** Cost of building a station at each village, 1-indexed. */
    private static int[] c;
    /** Prefix sums of {@code w}: {@code sumW[i] = w[1] + ... + w[i]}, {@code sumW[0] = 0}. */
    private static int[] sumW;
    /** Prefix sums of {@code w[i] * d(1, i)}, {@code sumOfWiMultiD1i[0] = 0}. */
    private static int[] sumOfWiMultiD1i;
    /**
     * {@code dp1[i][j]}: minimal cost for villages {@code 1..j} using exactly {@code i} stations
     * among {@code 1..j}, where the villages after the last station are served by that station.
     */
    private static int[][] dp1;
    /**
     * {@code dp2[i][j]}: minimal cost for villages {@code 1..j} using exactly {@code i} stations,
     * the {@code i}-th of which is built at village {@code j}.
     */
    private static int[][] dp2;

    /**
     * Reads the problem instance from {@code serviceData.txt}.
     *
     * <p>Also allocates the DP tables and fills both prefix-sum arrays.
     *
     * @throws Exception if the file cannot be opened or read
     */
    private static void inputData() throws Exception {
        // Every value is consumed inside this method, so the file can be closed afterwards.
        try (FileInputStream in = new FileInputStream("serviceData.txt")) {
            setStream(in);
            n = nextInt();
            maxStationNum = nextInt();
            x = new int[n + 1];
            w = new int[n + 1];
            c = new int[n + 1];
            sumW = new int[n + 1];
            dp1 = new int[maxStationNum + 1][n + 1];
            dp2 = new int[maxStationNum + 1][n + 1];
            sumOfWiMultiD1i = new int[n + 1];
            for (int i = 1; i <= n; ++i) {
                x[i] = nextInt();
                w[i] = nextInt();
                c[i] = nextInt();
                sumW[i] = sumW[i - 1] + w[i];
                // x[1] is already known at this point, so d(1, i) is well defined.
                sumOfWiMultiD1i[i] = sumOfWiMultiD1i[i - 1] + w[i] * getD(1, i);
            }
        }
    }

    /**
     * Returns {@code w[i] + ... + w[j]}.
     *
     * @param i lower bound, inclusive
     * @param j upper bound, inclusive; {@code j == i - 1} denotes the empty range and yields 0
     * @return the weight sum
     */
    private static int getWSum(int i, int j) {
        return sumW[j] - sumW[i - 1];
    }

    /**
     * Returns the distance between villages {@code i} and {@code j}.
     *
     * @param i first village index
     * @param j second village index
     * @return {@code |x[i] - x[j]|}
     */
    private static int getD(int i, int j) {
        return Math.abs(x[i] - x[j]);
    }

    /**
     * Returns {@code sum w[t] * d(1, t)} for {@code t} in {@code [left, right]}.
     *
     * <p>This is the total service cost of those villages if the station were at village 1.
     *
     * @param left  lower bound, inclusive
     * @param right upper bound, inclusive
     * @return the sum, or 0 when {@code right < left}
     */
    private static int getGoToOne(int left, int right) {
        if (right < left)
            return 0;
        return sumOfWiMultiD1i[right] - sumOfWiMultiD1i[left - 1];
    }

    /**
     * Returns the service cost of villages {@code left..right} when all of them use the station at
     * village {@code left - 1}, i.e. {@code sum w[t] * (x[t] - x[left - 1])}.
     *
     * <p>The range may be empty ({@code left == right + 1}), in which case 0 is returned. This
     * happens in {@code dp1FindMin} when the last station is at {@code j} itself.
     *
     * @param left  left bound, inclusive
     * @param right right bound, inclusive
     * @return the total cost
     */
    private static int getStationLeft(int left, int right) {
        // Empty ranges are legitimate here, so only reject ranges that are "more than empty".
        assert left <= right + 1 : "invalid range [" + left + ", " + right + "]";
        return getGoToOne(left, right) - getWSum(left, right) * getD(1, left - 1);
    }

    /**
     * Returns the service cost of villages {@code left..right} when all of them use the station at
     * village {@code right + 1}, i.e. {@code sum w[t] * (x[right + 1] - x[t])}.
     *
     * <p>The range may be empty ({@code left == right + 1}), in which case 0 is returned. This
     * happens for {@code dp2[1][1]} and when {@code k == j - 1} in {@code dp2FindMin}.
     *
     * @param left  left bound, inclusive
     * @param right right bound, inclusive
     * @return the total cost
     */
    private static int getStationRight(int left, int right) {
        assert left <= right + 1 : "invalid range [" + left + ", " + right + "]";
        return getWSum(left, right) * getD(1, right + 1) - getGoToOne(left, right);
    }

    /**
     * Computes {@code dp2[i][j]}: the {@code i}-th station is built at {@code j}.
     *
     * <p>The first {@code i - 1} stations cover villages {@code 1..k} (state
     * {@code dp1[i - 1][k]}), and villages {@code k+1..j-1} all travel right to {@code j}.
     *
     * @param i number of stations, {@code 2 <= i <= j}
     * @param j village hosting the {@code i}-th station
     * @return the minimal cost
     */
    private static int dp2FindMin(int i, int j) {
        // An earlier attempt split k+1..j-1 at the midpoint between x[k] and x[j]. That is wrong:
        // dp1[i-1][k] already sends every village after the previous station (up to k) left, so
        // k itself is the split point, and minimising over k covers every possible split.
        int currMin = INF;
        for (int k = i - 1; k < j; ++k) {
            int temp = dp1[i - 1][k] + getStationRight(k + 1, j - 1);
            if (temp < currMin)
                currMin = temp;
        }
        return currMin + c[j];
    }

    /**
     * Computes {@code dp1[i][j]}.
     *
     * <p>The last of the {@code i} stations is at some {@code k} in {@code [i, j]}, and villages
     * {@code k+1..j} all travel left to it.
     *
     * @param i number of stations, {@code 1 <= i <= j}
     * @param j last village covered
     * @return the minimal cost
     */
    private static int dp1FindMin(int i, int j) {
        int currMin = INF;
        for (int k = i; k <= j; ++k) {
            // the last station is at k
            int temp = dp2[i][k] + getStationLeft(k + 1, j);
            if (temp < currMin)
                currMin = temp;
        }
        return currMin;
    }

    /**
     * Fills both DP tables row by row and prints {@code min over i in [1, K] of dp1[i][n]}.
     */
    private static void solveProblem() {
        // Base case: a single station at j; villages 1..j-1 all travel right to it.
        for (int j = 1; j <= n; ++j) {
            dp2[1][j] = c[j] + getStationRight(1, j - 1);
            dp1[1][j] = dp1FindMin(1, j);
        }
        for (int i = 2; i <= maxStationNum; ++i) {
            for (int j = 1; j <= n; ++j) {
                if (i > j) {
                    // i distinct stations cannot be placed among only j villages.
                    dp1[i][j] = INF;
                    dp2[i][j] = INF;
                    continue;
                }
                dp2[i][j] = dp2FindMin(i, j);
                dp1[i][j] = dp1FindMin(i, j);
            }
        }
        int ans = INF;
        for (int stationNum = 1; stationNum <= maxStationNum; ++stationNum) {
            ans = Math.min(ans, dp1[stationNum][n]);
        }
        System.out.println(ans);
    }

    /**
     * Manual debugging helper that prints a few prefix-sum based costs.
     *
     * <p>It is not called by default and requires at least 5 villages to be loaded.
     */
    private static void foo() {
        int a = 3, b = 5;
        System.out.println(getGoToOne(a, b));
        System.out.println(getStationLeft(a, b));
        System.out.println(getStationRight(a, a));
    }

    public static void main(String[] args) throws Exception {
        inputData();
        solveProblem();
    }
}

相关文章: