题解
2025-08-14 09:55:15
发布于:浙江
思路分析
如果我们希望直接解得这个情况数,似乎并不容易,因此考虑使用间接的求法。观察一下题目条件,要求合法的解满足这样三个条件
-
菜数
-
每种烹饪方式至多选择一次
-
每种食材至多选择次
显然,其中最复杂的是3,那么我们可以考虑先求出所有仅满足1,2的情况,然后再去把不满足3的排除。
第一次dp,求满足1,2的情况
我们还是正常的思考,先想一个表示选取前个烹饪法的有效菜品数,因为此时似乎不只是受的影响,还需要涉及之前的决策情况。所以这个时候,我们无法从直接推出答案,需要引入一个新的信息,帮助我们求解。我们考虑引入一个表示目前已选取的烹饪法数。因此我们定义为。
现在我们考虑转移的过程,对于每一步。考虑两种决策
-
如果当前一步选择使用,则产生dp_1(i-1,j-1)\cdot \mbox{当前烹饪法可能产生选择数}种方案,其中当前烹饪法可能产生选择数我们可以用一个数组额外存储这一列的数字和,记作.
-
如果当前一步不选择,那么还能继承上一步的情况数,即
综上得到这样一个转移方程 $$dp_1(i,j)=dp(i-1,j-1)+dp(i-1,j)\cdot s(i)$$
于是全部满足1,2的解可以在复杂度内被找到。
第二次dp,如何求不合法解数
要求不满足条件3的情况,我们可以发现,当某种食材用了超过次,考虑极端情况,其使用了次,则其他食材的使用总次数为对于这个式子,考虑两种情况
-
如果是奇数,设,则原式
-
如果是偶数,设,原式
因为
所以原式
综上,我们可以得知此时其他食材的使用总次数是
那么显然,至多有一个食材会被过度使用。我们考虑枚举这个食材
我们定义一个,继续先考虑与之前一样。然而同样,我们也无法直接统计出之前所有过度情况数,于是引入另一个参数。我们引入一个过度使用的食材的参数,然后似乎还需要记录其使用次数和其他食材的种类和使用次数等信息。
这会造成不可接受的时间和空间复杂度,我们必须要思考如何消掉不必要的信息。
首先,关于当前删去哪个节点的信息是不必要的,
因为它不参与到状态转移的过程中,而仅仅只是被枚举考虑的量,并且我们的执行累加,并不会因为状态重叠这类的问题造成程序错误。因此我们可以正确的这样操作。同理,其余未被选择的食材数也只需要记录其数量即可,因为其他的信息不参与和影响转移。于是我们终于得到了较为简单的
我们简单推导一下转移方程
-
不选择第种烹饪方式并让被枚举到的过度使用,情况数
-
选择第种烹饪方式,并让过度使用(进一步的过度),情况数
-
选择第种烹饪方式,但不选择,产生情况数
综上,可得的转移方程为
几乎完美,除了时间。这个算法的复杂度是,仍超我们可以接受的范围
优化
不难发现,不合法的充分必要条件是
::: proof
Proof. 当时
类似于之前证明"仅有一个食材过度使用"的方法,可得
故,充分性得证
现在假设已有
则,,
此时,情况不合法,必要性得证 ◻
:::
就这样,我们可以只考虑,但是为了避免负索引造成的问题,我们再加上一个总食材数,于是我们得到了降维的状态
这样,我们的复杂度就只剩下了,完美AC
代码实现
#include <stdio.h>
#include <string.h>
const int mod=998244353;
const int maxn=105;
const int maxm=2005;
int n, m;
long long a[maxn][maxm];
long long s[maxn];
long long dp1[maxn][maxn];
long long dp2[maxn][maxn<<1];
int main(){
scanf("%d%d",&n,&m);
for (int i=1;i<= n;i++) {
for (int j=1;j<=m;j++) {
scanf("%lld",&a[i][j]);
s[i]=(s[i]+a[i][j])%mod;
}
}
// 计算总方案数
dp1[0][0]=1;
for (int i=1;i<=n;i++) {
for (int j=0;j<=i;j++) {
dp1[i][j]=dp1[i-1][j];
if (j>0){
dp1[i][j]=
(dp1[i][j]+dp1[i-1][j-1]*s[i])
%mod;
}
}
}
long long total=0;
for (int j=1;j<=n;j++) {
total=
(total+dp1[n][j])
%mod;
}
// 计算不合法方案数
long long invalid = 0;
for (int c=1;c<=m;c++) {
memset(dp2,0,sizeof(dp2));
dp2[0][maxn]=1;
for (int i=1;i<=n;i++) {
for (int d=-i;d<=i;d++) {
int idx=d+maxn;
dp2[i][idx]=dp2[i-1][idx];
if (d>-i){
dp2[i][idx]=(dp2[i][idx]+dp2[i-1][idx-1]
*a[i][c])
%mod;
}
if (d<i){
long long other=(s[i]-a[i][c]+mod)
%mod;
dp2[i][idx]=
(dp2[i][idx]+dp2[i-1][idx+1]*other)
%mod;
}
}
}
for (int d=1;d<=n;d++) {
invalid=(invalid+dp2[n][d+maxn])
%mod;
}
}
// 最终结果
long long ans=(total-invalid+mod)%mod;
printf("%lld\n",ans);
return 0;
}
这里空空如也
有帮助,赞一个