§ Description
Link.
求 i=1∑nj=1∑n(i+j)kf(gcd(i,j))gcd(i,j)。
§ Solution
ANS=i=1∑nj=1∑n(i+j)kμ2(gcd(i,j))gcd(i,j)=d=1∑ni=1∑nj=1∑n(i+j)kμ2(d)d[gcd(i,j)=d]=d=1∑ndk+1×μ2(d)i=1∑⌊dn⌋j=1∑⌊dn⌋(i+j)kh∣i,h∣j∑μ(h)=d=1∑ndk+1×μ2(d)h=1∑⌊dn⌋μ(h)×hk×i=1∑⌊dhn⌋j=1∑⌊dhn⌋(i+j)k=d=1∑ndk+1×μ2(d)h=1∑⌊dn⌋μ(h)×hk×i=1∑⌊dhn⌋j=1∑⌊dhn⌋(i+j)k
前面两个和式里面显然能算,考虑怎么对于 x 算 ∑i=1x∑j=1x(i+j)k。考虑对其差分:
(i=1∑x+1j=1∑x+1(i+j)k)−(i=1∑xj=1∑x(i+j)k)=i=1∑xj=1∑x+1(i+j)k+i=1∑x+1(x+1+i)k−i=1∑xj=1∑x(i+j)k=i=1∑x(j=1∑x+1(i+j)k−j=1∑x(i+j)k)+i=1∑x+1(x+1+i)k=i=1∑x(x+1+i)k+i=1∑x+1(x+1+i)k
然后滚个前缀和就可以算了。
#include<bits/stdc++.h>
typedef long long LL;
const int MOD=998244353;
int norm( LL x ) {
if( x<0 ) {
x+=MOD;
}
if( x>=MOD ) {
x%=MOD;
}
return x;
}
int n,k,ans;
int qpow( int bas,int fur ) {
int res=1;
while( fur ) {
if( fur&1 ) {
res=norm( LL( res )*bas );
}
bas=norm( LL( bas )*bas );
fur>>=1;
}
return norm( res+MOD );
}
std::tuple<std::vector<int>,std::vector<int>> makePrime( int n ) {
std::vector<int> prime,tag( n+1 ),mu( n+1 ),pw( n+1 );
pw[0]=1;
mu[1]=pw[1]=1;
for( int i=2;i<=n;++i ) {
if( !tag[i] ) {
mu[i]=norm( -1 );
prime.emplace_back( i );
pw[i]=qpow( i,k );
}
for( int j=0;j<int( prime.size() ) && i*prime[j]<=n;++j ) {
tag[i*prime[j]]=1;
pw[i*prime[j]]=norm( LL( pw[i] )*pw[prime[j]] );
if( i%prime[j]==0 ) {
mu[i*prime[j]]=0;
break;
} else {
mu[i*prime[j]]=norm( -mu[i] );
}
}
}
return std::tie( mu,pw );
}
int main() {
LL tmp;
scanf( "%d %lld",&n,&tmp );
k=tmp%( MOD-1 );
std::vector<int> mu,pw,prt( n+1 ),exprt( n+1 ),preSum( n+1 );
// prt: i^(k+1)*mu^2(i)
// exprt: mu(i)*i^k
// preSum sum sum (i+j)^k
std::tie( mu,pw )=makePrime( n<<1|1 );
for( int i=1;i<=n;++i ) {
prt[i]=norm( prt[i-1]+norm( LL( norm( LL( norm( LL( mu[i] )*mu[i] ) )*pw[i] ) )*i ) );
exprt[i]=norm( exprt[i-1]+norm( LL( mu[i] )*pw[i] ) );
}
for( int i=1;i<=( n<<1 );++i ) {
pw[i]=norm( pw[i]+pw[i-1] );
}
for( int i=1;i<=n;++i ) {
preSum[i]=norm( norm( preSum[i-1]+norm( pw[i<<1]-pw[i] ) )+norm( pw[(i<<1)-1]-pw[i] ) );
}
for( int l=1,r;l<=n;l=r+1 ) {
r=n/( n/l );
int tmp=0;
for( int exl=1,exr,m=n/l;exl<=m;exl=exr+1 ) {
exr=m/( m/exl );
tmp=norm( tmp+norm( LL( norm( exprt[exr]-exprt[exl-1] ) )*preSum[m/exl] ) );
}
ans=norm( ans+LL( norm( prt[r]-prt[l-1] ) )*tmp );
}
printf("%d\n",ans);
return 0;
}