求赞,题解
2025-08-09 19:18:14
发布于:广东
10阅读
0回复
0点赞
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#include<cstdio>
#include<cstring>
#define LL long long
#define md 998244353
#define STC static_cast<LL>
#define reg register
struct matrix{
	int a[167][167];
	int r,c;
	inline matrix(){memset(a,0,sizeof a);}
	matrix operator*(const matrix&b)const{
		matrix c;
		c.r=r,c.c=b.c;
		for(int i=0;i<r;++i)
		for(reg int j=0;j<b.c;++j){
			reg __int128 tmp=0;
			for(reg int k=0;k<b.r;++k)
			tmp+=STC(a[i][k])*b.a[k][j];
			c.a[i][j]=tmp%md;
		}
		return c;
	}
}p[62],a;
int T,m,k,num[9][9][9],inv[10];
void init3(){
	int cnt=0;
	for(int i=0;i<=k;++i)
	for(int j=k-i;~j;--j)
	for(int s=k-i-j;~s;--s)
	num[i][j][s]=cnt++;
	for(int i=0;i<=k;++i)
	for(int j=k-i;~j;--j)
	for(int s=k-i-j;~s;--s){
		int&id=num[i][j][s];
		const int ni=inv[i+j+s+1];
		if(i)p->a[id][num[i-1][j][s]]=STC(i)*ni%md;
		if(j){
			if(i+j+s<k)p->a[id][num[i+1][j-1][s+1]]=STC(j)*ni%md;else
			p->a[id][num[i+1][j-1][s]]=STC(j)*ni%md;
		}
		if(s){
			if(i+j+s<k)p->a[id][num[i][j+1][s]]=STC(s)*ni%md;else
			p->a[id][num[i][j+1][s-1]]=STC(s)*ni%md;
		}
		p->a[id][cnt]=p->a[id][id]=ni;
	}
	p->a[cnt][cnt]=1;
	p->r=p->c=cnt+1;
	a.r=1,a.c=cnt+1;
}
void init2(){
	int cnt=0;
	for(int i=0;i<=k;++i)
	for(int j=k-i;~j;--j)
	num[i][j][0]=cnt++;
	for(int i=0;i<=k;++i)
	for(int j=k-i;~j;--j){
		int&id=num[i][j][0];
		const int ni=inv[i+j+1];
		if(i)p->a[id][num[i-1][j][0]]=STC(i)*ni%md;
		if(j){
			if(i+j<k)p->a[id][num[i+1][j][0]]=STC(j)*ni%md;else
			p->a[id][num[i+1][j-1][0]]=STC(j)*ni%md;
		}
		p->a[id][cnt]=p->a[id][id]=ni;
	}
	p->a[cnt][cnt]=1;
	p->r=p->c=cnt+1;
	a.r=1,a.c=cnt+1;
}
int main(){
	inv[1]=1;
	for(int i=2;i<10;++i)
	inv[i]=STC(md-md/i)*inv[md%i]%md;
	scanf("%d%d%d",&T,&m,&k);
	if(m==3)init3();else
	init2();
	for(int i=1;i<62;++i)
	p[i]=p[i-1]*p[i-1];
	while(T--){
		LL n;
		scanf("%lld",&n);
		memset(*a.a,0,sizeof*a.a);
		if(m==1)
		a.a[0][num[1][0][0]]=1;else
		if(m==2)
		a.a[0][num[0][1][0]]=1;else
		a.a[0][num[0][0][1]]=1;
		for(int i=0;i<62;++i)
		if(n>>i&1)a=a*p[i];
		printf("%d\n",a.a[0][a.c-1]);
	}
	return 0;
}
这里空空如也






有帮助,赞一个