题解
2025-08-10 09:48:55
发布于:广东
17阅读
0回复
0点赞
#include <bits/stdc++.h>//万能头
using namespace std;
using ll = long long;
const int MOD = 998244353;

namespace {

// Weighted count, modulo MOD, of non-empty selections where:
//   * every row i contributes either nothing or exactly one entry (i, j);
//   * picking entry (i, j) multiplies the selection's weight by rows[i][j];
//   * no single column supplies strictly more than half of the chosen entries.
//
// rows: rows[i][j] is the weight of picking column j in row i. All rows must
//       have the same length. Weights are expected to be non-negative, so that
//       the `% MOD` reductions below stay non-negative.
// Returns the count in [0, MOD). An empty matrix yields 0.
//
// Method (complementary counting):
//   answer = (all non-empty selections) - (selections where some column is
//            used by more than half of the picked rows).
// At most one column can exceed half at a time, so the bad selections for
// different columns are disjoint and can simply be summed column by column.
ll count_balanced_selections(const std::vector<std::vector<ll>>& rows) {
    const int n = static_cast<int>(rows.size());
    const int m = n == 0 ? 0 : static_cast<int>(rows[0].size());

    // row_sum[i] = total weight of row i, i.e. the ways row i can contribute one entry.
    std::vector<ll> row_sum(n, 0);
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < m; ++j) {
            row_sum[i] = (row_sum[i] + rows[i][j]) % MOD;
        }
    }

    // Each row either contributes nothing (the "+1") or one of its entries.
    // Subtract 1 at the end to drop the all-empty selection.
    ll total = 1;
    for (const ll s : row_sum) {
        total = total * (s + 1) % MOD;
    }
    total = (total - 1 + MOD) % MOD;

    // For a fixed column j, the DP state is
    //   d = (#rows picking column j) - (#rows picking any other column),
    // stored at index d + offset. Since |d| <= n, 2n + 1 slots always suffice.
    // This is what removes the fixed-size-array overflow for large inputs.
    const int offset = n;
    const int width = 2 * n + 1;
    std::vector<ll> prev(width);
    std::vector<ll> cur(width);
    ll invalid = 0;

    for (int j = 0; j < m; ++j) {
        std::fill(prev.begin(), prev.end(), 0);
        prev[offset] = 1;
        for (int i = 0; i < n; ++i) {
            const ll pick_j = rows[i][j] % MOD;
            const ll pick_other = (row_sum[i] - pick_j + MOD) % MOD;
            std::fill(cur.begin(), cur.end(), 0);
            // After i rows, only d in [offset - i, offset + i] can be non-zero.
            // Therefore d + 1 <= 2n and d - 1 >= 0 stay in range.
            for (int d = offset - i; d <= offset + i; ++d) {
                const ll ways = prev[d];
                if (ways == 0) {
                    continue;
                }
                cur[d] = (cur[d] + ways) % MOD;                        // row i picks nothing
                cur[d + 1] = (cur[d + 1] + ways * pick_j) % MOD;      // row i picks column j
                cur[d - 1] = (cur[d - 1] + ways * pick_other) % MOD;  // row i picks another column
            }
            std::swap(prev, cur);
        }
        // d > 0: column j supplies strictly more than half of the chosen entries.
        for (int d = offset + 1; d < width; ++d) {
            invalid = (invalid + prev[d]) % MOD;
        }
    }

    return (total - invalid + MOD) % MOD;
}

}  // namespace

// Reads n and m, then an n x m matrix of weights, from standard input.
// Writes the count from count_balanced_selections (no trailing newline).
int main() {
    // Input is read only through std::cin, so decoupling it from C stdio and
    // from std::cout is safe and speeds up bulk reading.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    int m = 0;
    std::cin >> n >> m;
    // Treat a malformed (negative) size as empty rather than throwing from the
    // vector constructor. For such input the answer is 0, as with zero rows.
    n = std::max(n, 0);
    m = std::max(m, 0);

    std::vector<std::vector<ll>> rows(n, std::vector<ll>(m));
    for (auto& row : rows) {
        for (ll& weight : row) {
            std::cin >> weight;
        }
    }

    std::cout << count_balanced_selections(rows);
    return 0;
}
这里空空如也
有帮助,赞一个