1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
// Problem-wide state. Every per-vertex array is indexed by vertex id; N also
// fixes the size of the factorial tables in namespace binom (filled up to N - 5).
const int N = 2000005;
int n, m, k;              // n and k are read in main; m is not used in the code shown
int a[N];                 // not used in the code shown
int vis[N];               // main sets vis[x] = 1 for the root and every vertex on the path from s up to it
int d[N];                 // depth from dfs1: d[x] = d[parent] + 1, so the root 1 has depth 1 (d[0] == 0)
int f[N];                 // parent from dfs1; f[1] == 0
int sz[N];                // subtree size from dfs1
std::vector<int> vec[N];  // adjacency lists; main inserts each edge in both directions
const ll mod = 998244353;
// Returns x^r modulo `mod` using binary (square-and-multiply) exponentiation.
// The base is first reduced into [0, mod), so every product below stays under
// mod^2 and fits in a 64-bit integer even when the caller passes an unreduced
// or negative x. The exponent is expected to be non-negative; a negative r
// (which previously looped forever, since `>>=` keeps -1 at -1) yields 1.
ll fpow(ll x, ll r) {
    ll result = 1;
    x %= mod;
    if (x < 0) x += mod;
    for (; r > 0; r >>= 1) {
        if (r & 1) result = result * x % mod;
        x = x * x % mod;
    }
    return result;
}
// Factorials, inverse factorials and binomial helpers modulo `mod`.
// The tables cover arguments 0 .. N - 5; entries above that are never filled.
namespace binom {
ll fac[N], ifac[N];

// Fills fac[i] = i! and ifac[i] = (i!)^{-1} (mod `mod`) during dynamic
// initialization of this global. Named without a double underscore: such
// identifiers (the previous name was `__`) are reserved for the implementation.
int binom_tables_ready = [] {
    const int lim = N - 5;
    fac[0] = 1;
    for (int i = 1; i <= lim; i++) fac[i] = fac[i - 1] * i % mod;
    // One modular inversion at the top (Fermat, mod is prime), then walk down
    // using (i-1)!^{-1} = i!^{-1} * i.
    ifac[lim] = fpow(fac[lim], mod - 2);
    for (int i = lim; i; i--) ifac[i - 1] = ifac[i] * i % mod;
    return 0;
}();

// Binomial coefficient "n choose m" mod `mod`; 0 when m < 0 or m > n.
// n must not exceed N - 5 (the precomputed range).
inline ll C(int n, int m) {
    if (n < m || m < 0) return 0;
    return fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}

// Falling factorial n! / (n - m)! mod `mod`; 0 when m < 0 or m > n.
// n must not exceed N - 5 (the precomputed range).
inline ll A(int n, int m) {
    if (n < m || m < 0) return 0;
    return fac[n] * ifac[n - m] % mod;
}
}  // namespace binom
using namespace binom;

// Running answer modulo `mod`; seeded in main and corrected by dfs2.
ll ans;
// Roots the part of the tree entered at u (whose parent is fa) and fills, for
// every vertex x reached: f[x] = parent, d[x] = d[parent] + 1, sz[x] = size of
// x's subtree. Uses an explicit stack instead of recursion: a path-shaped tree
// with up to ~N vertices would otherwise recurse about n levels deep, which can
// overflow the default call stack.
void dfs1(int u, int fa) {
    std::vector<std::pair<int, int>> stk;  // (vertex, parent) still to visit
    std::vector<int> order;                // vertices in visit (pre-)order
    stk.emplace_back(u, fa);
    while (!stk.empty()) {
        const int x = stk.back().first;
        const int p = stk.back().second;
        stk.pop_back();
        f[x] = p;
        d[x] = d[p] + 1;  // parent was visited earlier, so d[p] is final
        sz[x] = 1;
        order.push_back(x);
        for (int y : vec[x])
            if (y != p) stk.emplace_back(y, x);
    }
    // Every vertex appears after its parent in `order`, so a backwards sweep
    // completes each subtree before folding its size into the parent. The start
    // vertex u is skipped so that sz[fa] is left untouched, as before.
    for (auto it = order.rbegin(); it != order.rend(); ++it)
        if (*it != u) sz[f[*it]] += sz[*it];
}
// Modular inverse of C(n - 1, k); assigned in main before dfs2 runs, and used
// by dfs2 to divide its binomial differences by C(n - 1, k).
ll inv;
void dfs2(int u, int ancient) { if (u != 1) ans = (ans - (C(n - d[u] + d[ancient], k) - C(n - 1 - d[u] + d[2], k)) * inv % mod * 2 * sz[u] % mod + mod) % mod; for (auto v:vec[u]) { if (v == f[u])continue; if (vis[u]) dfs2(v, u); else dfs2(v, ancient); } }
int main() { int p, q, u, v, w, x, y, z, T; int s; cin >> n >> k >> s; inv = fpow(C(n - 1, k), mod - 2); for (int i = 1; i < n; i++) scanf("%d%d", &u, &v), vec[u].emplace_back(v), vec[v].emplace_back(u); dfs1(1, 0); vis[1] = 1; for (int u = s; u != 1; u = f[u]) vis[u] = 1, ans += 2 * sz[u] - 1; ans %= mod; dfs2(1, 0); cout << ans; return 0; }
|