AOJ 3042 - Aizu Competitive Programming Camp 2018 Day 2 D Gridgedge
解法
縦と横に分けて考えます。
座標aからbに移動するとき、考えられるパターンは3つあり、
そのままaからbに直接いくパターン
座標0に移動してからbにいくパターン
座標r-1(およびc-1)に移動してからbにいくパターン
になります。
また、座標をワープしてから移動する2、3番目のパターンは、必ず最初にワープを行わなければなりません。
ので、ワープする順番は特に考える必要がありません(この部分はコンテスト中に気づいていないため、ワープするパターンで場合分けをしてしまったのでコード量が増えています)。
あとは、3つのパターンのうち最も少ないコストで移動できるパターンを縦と横について求め、その縦と横のパターンの組み合わせについて、個数制限のある重複順列のパターン数を求めれば答えになります。
コンテスト後に書き直したコードはこちらになります。
#include <bits/stdc++.h> #define N (long long)(1e9 + 7) #define MAX 250005 using namespace std; long long factorial[MAX] = {0}, finverse[MAX] = {0}, inverse[MAX] = {0}; void smodfact() { factorial[0] = factorial[1] = 1; finverse[0] = finverse[1] = 1; inverse[1] = 1; for(int i = 2; i < MAX; ++i) { factorial[i] = factorial[i - 1] * i % N; inverse[i] = N - (inverse[N % i] * (N / i)) % N; finverse[i] = finverse[i - 1] * inverse[i] % N; } } long long calccomb(long long n, long long k) { if(n == k && n == 0) return 1; if(n < 0 || k < 0 || n < k) return 0; return factorial[n] * finverse[k] % N * finverse[n - k] % N; } long long rc[2], a[2], b[2], ch[2], data[2][3] = {0}, cost = 0, ans = 0; long long solve(); long long calc(int x, int y); long long setd(int x); int main() { smodfact(); cin >> rc[0] >> rc[1] >> a[0] >> a[1] >> b[0] >> b[1]; ans = solve(); cout << cost << " " << ans << endl; return 0; } long long solve() { long long xy[2], ans = 0, nowx, nowy; for(int i = 0; i < 2; ++i) xy[i] = setd(i); cost = xy[0] + xy[1]; for(int i = 0; i < 3; ++i) for(int j = 0; j < 3; ++j) { nowx = data[0][i]; nowy = data[1][j]; if(xy[0] == nowx && xy[1] == nowy) { ans += calc(i, j); ans %= N; } } return ans; } long long setd(int x) { long long minn; data[x][0] = max(a[x] - b[x], b[x] - a[x]); minn = data[x][0]; data[x][1] = b[x] + 1; minn = min(minn, data[x][1]); data[x][2] = rc[x] - b[x]; minn = min(minn, data[x][2]); return minn; } long long calc(int x, int y) { return factorial[data[0][x] + data[1][y]] * finverse[data[0][x]] % N * finverse[data[1][y]] % N; }
コンテスト中に提出コードはこちらになります。
#include <bits/stdc++.h> #define N (long long)(1e9 + 7) #define MAX 250005 using namespace std; long long factorial[MAX] = {0}, finverse[MAX] = {0}, inverse[MAX] = {0}; void smodfact() { factorial[0] = factorial[1] = 1; finverse[0] = finverse[1] = 1; inverse[1] = 1; for(int i = 2; i < MAX; ++i) { factorial[i] = factorial[i - 1] * i % N; inverse[i] = N - (inverse[N % i] * (N / i)) % N; finverse[i] = finverse[i - 1] * inverse[i] % N; } } long long calccomb(long long n, long long k) { if(n == k && n == 0) return 1; if(n < 0 || k < 0 || n < k) return 0; return factorial[n] * finverse[k] % N * finverse[n - k] % N; } long long rc[2], a[2], b[2], ch[2], data[2][3][2] = {0}, cost = 0, ans = 0; long long solve(); long long calc(int x, int y); long long setd(int x); int main() { smodfact(); cin >> rc[0] >> rc[1] >> a[0] >> a[1] >> b[0] >> b[1]; ans = solve(); cout << cost << " " << ans << endl; return 0; } long long solve() { long long xy[2], ans = 0, nowx, nowy; for(int i = 0; i < 2; ++i) xy[i] = setd(i); cost = xy[0] + xy[1]; for(int i = 0; i < 3; ++i) for(int j = 0; j < 3; ++j) { nowx = data[0][i][0]; nowy = data[1][j][0]; if(xy[0] == nowx && xy[1] == nowy) { ans += calc(i, j); ans %= N; } } return ans; } long long setd(int x) { long long minn; data[x][0][0] = max(a[x] - b[x], b[x] - a[x]); minn = data[x][0][0]; data[x][1][0] = b[x] + 1; data[x][1][1] = 1; minn = min(minn, data[x][1][0]); data[x][2][0] = rc[x] - b[x]; data[x][2][1] = 1; minn = min(minn, data[x][2][0]); return minn; } long long calc(int x, int y) { long long ans = 0; if(x == 0 && x == y) return factorial[data[0][0][0] + data[1][0][0]] * finverse[data[0][0][0]] % N * finverse[data[1][0][0]] % N; if(x != 0) { for(int i = 0; i <= data[1][y][0]; ++i) { ans += factorial[data[0][x][0] + data[1][y][0] - i - 1] * finverse[data[0][x][0] - 1] % N * finverse[data[1][y][0] - i] % N; ans %= N; } } else if(y != 0) { for(int i = 0; i <= data[0][x][0]; ++i) { ans += factorial[data[0][x][0] + data[1][y][0] - i - 1] * finverse[data[0][x][0] - i] % N * finverse[data[1][y][0] - 1] % N; ans %= N; } } while(ans < 0) ans += N; return ans; }