逆元を知らない人でも確率dp mod 998244353を通せるようにするための「おまじない」(茶~緑向け)

Page content

せっかくブログを立てたので競プロの記事を書きます。

dp などの問題で確率 mod P (P は大きい素数)を求めることがしばしばあります。
初見だと「は?確率の mod?何いってんだこいつ」となりがちなので、この記事では

単純なナップサック dp のような D 問題相当の dp なら解けるけど、確率 mod P とか言われると「は?何いってんだこいつ」となる人向けに

  1. Python ユーザーが ACL の modint に相当するライブラリ等を使わずに
  2. 答えが小数になることを許容してそのまま実装し
  3. 実装した確率 dp を機械的に mod P に対応させる

というステップで解く方法について整理しました。

本記事の目的はあくまでも例題に対する実装例を示すことによる(問題を通したり解説を理解するための)「取っ掛かりをつくる」ことであり、逆元それ自体の詳細については触れません(あまりきちんとした説明もできません)。後述するけんちょんさんの記事などを読みましょう。

例題としてabc275_e: Sugoroku 4を使います。

問題文

1. 普通に実装する

mod のことは気にせず、普通に dp の遷移を考えてみてください。

遷移

dp[k][i] を、ルーレットを k 回したときにマス i にいる"確率"として考えます。

この"確率"は、小数を扱うという特徴はあるものの、初期値をdp[0][0] = 1として(スタート地点にいる確率は 1)

N = 2, M = 2のときdp[1][1] = 0.5, dp[1][2] = 0.5
(dp[0][0]からマス 1、マス 2 にそれぞれ 1/2 の確率で遷移)

といった具合に、ナップサック問題の"価値"などと同じく、ある状態からある状態への遷移に伴って変化するものとして扱うことができます。 折り返しを考慮した 遷移先を j とすると、 ルーレットの候補は等確率で M 個あることから、具体的には

dp[k+1][j] += dp[k][i] / M として計算すればよいです。

実装例

N,M,K = list(map(int,input().split()))
dp = [[0]*(N+1) for _ in range(K+1)]
dp[0][0] = 1

for k in range(K): # 回した回数
    for i in range(N): # 移動元(マスNに到達したら終了するため、Nは移動元に含まない)
        for m in range(1, M+1): # 移動するマスの数
            j = i+m if i+m <= N else N - (i+m-N) # 折り返しを考慮した移動先
            dp[k+1][j] += dp[k][i] / M

# 各移動回数におけるマスNの確率を合計する
ans = sum(dp[i][N] for i in range(K+1))

print(ans)

入力例 1

2 2 1
# => 0.5

この時点ではまだ答えは小数です。

とはいえ、ABC の多くの問題ではサンプル 1 に分数での説明を入れてくれているため、
人間に理解できるかたちでサンプルの確認ができるほか、手計算での実験やデバッグがしやすいなどのメリットもあります。

2. 普通に実装した dp を mod P に対応させる

先程のソースコードに以下の変更を加えます。

  1. 変数 mod に問題文で指定された素数(998244353)を入れる
  2. その下に inv_M = pow(M, mod-2, mod) と書く
  3. / M の部分を * inv_M に置換する
  4. 計算過程に都度 %= mod をつける

差分

N,M,K = list(map(int,input().split()))
+ mod = 998244353
+ inv_M = pow(M, mod-2, mod)
dp = [[0]*(N+1) for _ in range(K+1)]
dp[0][0] = 1

for k in range(K): # 回した回数
    for i in range(N): # 移動元(マスNに到達したら終了するため、Nは移動元に含まない)
        for m in range(1, M+1): # 移動するマスの数
            j = i+m if i+m <= N else N - (i+m-N) # 折り返しを考慮した移動先
-           dp[k+1][j] += dp[k][i] / M
+           dp[k+1][j] += dp[k][i] * inv_M
+           dp[k+1][j] %= mod

# 各移動回数におけるマスNの確率を合計する
ans = sum(dp[i][N] for i in range(K+1))
+ ans %= mod

print(ans)

上記の変更を行うことで、mod 998244353 で確率が求められるようになり、AC できました。やったね!

https://atcoder.jp/contests/abc275/submissions/41220093

3. 何をやったのか

結局のところ M で割る操作を mod 998244353 で行うことができれば、それは確率を mod 998244353 の世界で計算できることにほかなりません。 しかし、普通に M で割ろうとすると明らかに割り切れずに小数になってしまうので、とても困ります。

そこで、M で割る操作を 1/M を掛ける操作に置き換えた上で、 「1/M mod 998244353」自体をイイ感じに整数で表現することを考えます。
1/M の部分さえ整数で表現できれば、dp の遷移は全て整数同士の掛け算と足し算で完結することになるので、都度余りをとりながら計算すれば全て丸く収まる、という算段です。

そして、それを計算してくれているのが「pow(M, mod-2, mod)」のおまじないです。
あとは都度 %= mod を加えれば、この問題を解くことができました。めでたしめでたし。
(問題文の定義では既約分数がどうのこうのとよくわからんことを書いていますが、上記の計算さえイイ感じにできていれば別に dp の要素を[分子, 分母]などで管理したりする必要もありません。)

実は、ここでの「1/M mod 998244353」には名前がついていて、mod 998244353 における M の「逆元」と呼ばれています。

本記事ではこの「逆元」をフェルマーの小定理を使って求めています。
このあたりをめちゃくちゃわかりやすく解説してくれているけんちょんさんの Qiita 記事があるので、詳細はこちらに丸投げします。

また、本記事では逆元の存在条件にも触れていませんし、「イイ感じに整数で表現する」といったふんわり表現を多用しています。mod についてきちんと理解するためにも、mod についてふんわりとしか理解していない競プロ er のみなさん - えびちゃんの日記を(自分も含めて)読みましょう。