Solution -「JRKSJ R1」吊打

cirnovsky /

考虑线段树。

当然不能直接维护序列,我们来维护序列元素的指数。

具体来说,若 ai=(c2)bia_i=(c^{2})^{b_i},我们维护的就是 bib_i,当然 c2c^2 也要维护。

这里不能取一次方,否则 bib_i 为奇数时就不能开方了。

开方就是花神的做法,能开就直接开,否则就对指数打减标记。

平方简单,直接对指数打加标记。

最后统计答案就遍历线段树,然后快速幂计算即可。

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define ls (x<<1)
#define rs (x<<1|1)
const int mod=998244353;
int pwr[8000010],val[800010],tag[800010],n,m,a[200010];//,b[200010];
int qkpow(int bas,int fur,int mod){
	int res=1;
	while(fur){
		if(fur&1)res=res*bas%mod;
		bas=bas*bas%mod,fur>>=1;
	}
	return res%mod;
}
void psup(int x){pwr[x]=min(pwr[ls],pwr[rs]),val[x]=max(val[ls],val[rs]);}
void psdn(int x){
	if(tag[x]){
//		int m=(l+r)/2;
		pwr[ls]+=tag[x],pwr[rs]+=tag[x];
		tag[ls]+=tag[x],tag[rs]+=tag[x];
		tag[x]=0;
	}
}
void build(int l,int r,int x){
	if(l==r)return(void)((pwr[x]=0,val[x]=a[l]));
	int m=(l+r)/2;
	build(l,m,ls),build(m+1,r,rs);
	psup(x);
}
void open(int l,int r,int x,int ql,int qr){
	if(val[x]<=1)return;
	if(l>=ql&&r<=qr&&pwr[x]>=1)return void((tag[x]--,pwr[x]--));
	if(l==r) {
		if(pwr[x]) return void(pwr[x] --);
		return void(val[x]=sqrt(val[x]));
	}
	int m=(l+r)/2;psdn(x);
	if(m>=ql)open(l,m,ls,ql,qr);
	if(m<qr)open(m+1,r,rs,ql,qr);
	psup(x);
}
void pinf(int l,int r,int x,int ql,int qr){
	if(val[x]<=1)return;
	if(l>qr||r<ql)return;
	if(l>=ql&&r<=qr)return(void)((tag[x]++,pwr[x]++));
	int m=(l+r)/2;psdn(x);
	if(m>=ql)pinf(l,m,ls,ql,qr);
	if(m<qr)pinf(m+1,r,rs,ql,qr);
	psup(x);
}
//int get(int t){
//	int res=0;
//	if(t==1)return 0;
//	while(1){
//		int tmp=pow(t,0.5);
//		if(tmp*tmp==t)t=tmp,res++;
//		else break;
//	}
//	return res;
//}
int solve(int l,int r,int x){
	if(l==r){
		if(pwr[x]>0)return qkpow(val[x],qkpow(2,pwr[x],mod-1),mod);
		else return val[x];
	}
	int m=(l+r)/2;psdn(x);
	return (solve(l,m,ls)+solve(m+1,r,rs))%mod;
}
signed main(){
	scanf("%lld%lld",&n,&m);
//	memset(pwr,0x7f,sizeof(pwr));
	for(int i=1;i<=n;++i){
		scanf("%d",&a[i]);
//		b[i]=get(a[i]);
//		if(b[i]>0)a[i]=pow(a[i],1.0/(b[i]*2));
	}
	build(1,n,1);
	while(m--){
		int t,l,r;scanf("%lld%lld%lld",&t,&l,&r);
		if(t==1)open(1,n,1,l,r);
		else pinf(1,n,1,l,r);
	}
	printf("%lld\n",solve(1,n,1));
	return 0;
}
//don't forget to mod