1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
#include <cstdio>
#include <cstdlib>
using namespace std;
// N bounds the factorial tables; MOD is the (prime) modulus for all arithmetic.
const int N = 1e6 + 5, MOD = 998244353;
// n, k: input sizes; sum: running sum of the input weights mod MOD.
// fac[i] = i! mod MOD and ifac[i] = (i!)^{-1} mod MOD; main() fills indices 0..n.
int n, k, sum, fac[N], ifac[N];

// Returns a^b mod MOD by binary exponentiation.
// Precondition: b >= 0 (for a negative b, `b >>= 1` never reaches 0).
int fpow(int a, int b) {
    int res = 1;
    for (; b; b >>= 1, a = a * 1ll * a % MOD)
        if (b & 1) res = res * 1ll * a % MOD;
    return res;
}

// Binomial coefficient C(n, k) mod MOD, or 0 when k < 0 or k > n.
// Requires fac/ifac to be filled up to index n.
int C(int n, int k) {
    if (k < 0 || k > n) return 0;  // also prevents reading ifac[negative]
    return fac[n] * 1ll * ifac[k] % MOD * 1ll * ifac[n - k] % MOD;
}

// Stirling number of the second kind S(n, k) mod MOD, by inclusion-exclusion:
//   S(n, k) = (1 / k!) * sum_{i=0..k} (-1)^i * C(k, i) * (k - i)^n.
// Returns 0 for k < 0 or k > n (no such partitions). This early exit also keeps the
// fac/ifac lookups within [0, n] and guarantees fpow gets a non-negative exponent.
// Cost: O(k log n).
int S(int n, int k) {
    if (k < 0 || k > n) return 0;
    long long res = 0;
    for (int i = 0; i <= k; i++) {
        long long term = fpow(k - i, n) * 1ll * C(k, i) % MOD;
        // Odd i terms are subtracted; MOD - term keeps the running sum non-negative.
        res = (res + ((i & 1) ? MOD - term : term)) % MOD;
    }
    return static_cast<int>(res * ifac[k] % MOD);
}
int main() { freopen("ichigo.in", "r", stdin); freopen("ichigo.out", "w", stdout); scanf("%d %d", &n, &k); for (int i = 1, x; i <= n; i++) scanf("%d", &x), sum = (sum + x) % MOD; fac[0] = 1; for (int i = 1; i <= n; i++) fac[i] = fac[i - 1] * 1ll * i % MOD; ifac[n] = fpow(fac[n], MOD - 2); for (int i = n - 1; i >= 0; i--) ifac[i] = ifac[i + 1] * 1ll * (i + 1) % MOD; printf("%lld\n", sum * 1ll * (S(n, k) * 1ll + (n - 1) * 1ll * S(n - 1, k) % MOD) % MOD); return 0; }
|