题目链接:Travel Plan
题目大意:\(n\) 个点的完全二叉树,每个点可以分配 \(1 \sim m\) 的点权,定义路径价值为路径中最大的点权,求所有路径的价值和。
对于任意长度(这里主要指包括几个节点)的路径 \(t\),最大点权不超过 \(k\) 的方案数有 \(k^t\) 个, 因此最大点权恰好为 \(k\) 的方案数有 \(k^t – (k-1)^t\)。所以,对于任意一条长度为 \(t\) 的路径,不考虑不在路径上其他点的影响时,其对于答案的贡献为:
\[\begin{aligned}\text{path contribution}_t &= \sum_{k=1}^m (k^t – (k-1)^t) \cdot k \\ &= \sum_{k=1}^m \left( k^{t+1} – (k-1)^{t+1} – (k-1)^t \right) \\ &= m^{t+1} – \sum_{k=1}^{m-1} k^t\end{aligned}\]
由于路径长度不会超过 \(2 \log n\),因此求出全部长度路径分别对于答案的贡献时间复杂度为 \(O(m \log \log n)\)。
事实上,对于上面式子的第二项,可以用 Lagrange 插值、伯努利数、多项式等方法可以优化到 \(O(\log^2 n)\)。
下一步,问题转化为求出路径长度为 \(t\) 的个数分别是多少,然后乘一下即可。
第一种方法是点分治,显然复杂度是不够的,因为有 \(O(n \log n)\)。
第二种方法是题解做法。
首先,在这个完全二叉树中,不同形状的子二叉树共有 \(O(\log n)\) 个,设叶子个数为 \(leaf_i\),那么其中包括两种类型:
- \(leaf_i = 2^{p-1}\) 时(\(p\) 是这个子二叉树的最大深度),那么以 \(i\) 为根的子树是一个完全二叉树,显然有 \(O(\log n)\) 个。
- \(leaf_i \not = 2^{p-1}\) 时,节点 \(i\) 的左右儿子必有一个满足其为 \(2\) 的幂次,而另一个不满足,以这样的点为根的子树中的根可以脑补为一条链的形状,因此也有 \(O(\log n)\) 个。
不妨设 \(dp_{i,j}\) 表示以 \(i\) 为根的子树中长度为 \(j\) 的路径个数,\(f_{i,j}\) 表示以 \(i\) 为根的子树中,以 \(i\) 为结束端点长度为 \(j\) 的路径个数。满二叉树时,转移方程应该为:
\[\begin{aligned}f_{i,1} &= 1 \\f_{i,j} &= f_{lson(i), j-1} + f_{rson(i), j-1} (j \geq 2) \\dp_{i,1} &= size_i \\dp_{i,j} &= dp_{lson(i),j} + dp_{rson(i), j} + \sum_{k=0}^{j-1} f_{lson(i), k} \times f_{rson(i), j – 1 – k} (j \geq 2) \\\end{aligned}\]
具体实现的时候,事实上一共 \(O(\log n)\) 个点,因此第二部分的算法复杂度为 \(O(\log^3 n)\),这里也可以用 FFT 优化这个式子做到 \(O(\log^2 n)\)。
不过官方题解说可以做到。
最后一步,由于第一步中没有考虑不在路径上的其他点的方案影响,因此需要乘上去。
\[ans = \sum_{t=1}^{\text{max path length}} dp_{1, t} \times \text{path contribution}_t \times m^{n-t}\]
这个题,本质上还是很妙的。我们很容易思考到这是一个转化成各个部分对于总的答案的贡献这一思路,然而这个题目中固定路径长度 \(t\),然后计算分成不同长度路径对于答案贡献这一方式还是相当难想到的。
#includeusing namespace std;typedef long long ll;typedef double db;typedef long double ld;#define IL inline#define fi first#define se second#define mk make_pair#define pb push_back#define SZ(x) (int)(x).size()#define ALL(x) (x).begin(), (x).end()#define dbg1(x) cout << #x << " = " << x << ", "#define dbg2(x) cout << #x << " = " << x << endltemplate IL void read(Tp &x) { x=0; int f=1; char ch=getchar(); while(!isdigit(ch)) {if(ch == '-') f=-1; ch=getchar();} while(isdigit(ch)) { x=x*10+ch-'0'; ch=getchar();} x *= f;}int buf[42];template IL void write(Tp x) { int p = 0; if(x >= 1ll; } return ret;}pair depl(ll u) { if ((u < n) { return mk(1, u); } auto p = depl(u << 1ll); return mk(p.fi + 1, p.se);}int depr(ll u) { if ((u < n) { return 1; } return depr(u << 1ll | 1ll) + 1;}bool fulltree(ll u) { return ((u < n && (u < n) || (depl(u << 1ll).fi == depr(u << 1ll | 1ll));}ll getsz(ll u) { if ((u < n) return 1; if ((u < n) return 2; auto p = depl(u); int dr = depr(u); // dbg1(u); dbg1(p.fi); dbg1(p.se); dbg1(dr); dbg1((1ll << (1ll * dr)) - 1); dbg2((1ll << (1ll * dr)) - 1 + (n - p.se + 1)); if (p.fi == dr) return (1ll << (1ll * dr)) - 1; else { return (1ll << (1ll * dr)) - 1 + (n - p.se + 1); }}unordered_map dpid;void dfs(ll u) { int uid; ll szu = getsz(u); if (dpid.count(szu) == 0) dpid[szu] = uid = ++dpid_cnt; else return; f[uid][0] = f[uid][1] = 1; dp[uid][0] = 1; dp[uid][1] = szu % mod; if ((u < n) return; else if((u < n) { dfs(u << 1ll); f[uid][2] = dp[uid][2] = 1; return; } dfs(u << 1ll); dfs(u << 1ll | 1ll); int lid = dpid[getsz(u << 1ll)], rid = dpid[getsz(u << 1ll | 1ll)]; for (int j = 2; j <= 2 * LOGN; j++) { f[uid][j] = (f[lid][j-1] + f[rid][j-1]) % mod; dp[uid][j] = (dp[lid][j] + dp[rid][j]) % mod; for (int k = 0; k < j; k++) { dp[uid][j] = (dp[uid][j] + 1ll * f[lid][k] * f[rid][j - 1 - k]) % mod; } }}void solve() { dpid_cnt = 0; dpid.clear(); memset(pathcon, 0, sizeof(pathcon)); memset(f, 0, sizeof(f)); memset(dp, 0, sizeof(dp)); read(n); read(m); for (int t = 0; t <= (LOGN << 1); t++) { pathcon[t] = ksm(m, t + 1); for (int k = 1; k < m; k++) { pathcon[t] = (1ll * pathcon[t] - ksm(k, t) + mod) % mod; } } dfs(1); int ans = 0; for (int t = 1; t <= min(n, 2ll * LOGN); t++) { if (dp[1][t] == 0) break; ans = (ans + 1ll * dp[1][t] * pathcon[t] % mod * ksm(m, n - t)) % mod; } write(ans); putchar(10);}int main() {#ifdef LOCAL freopen("test.in", "r", stdin); // freopen("test.out", "w", stdout);#endif int T = 1; read(T); while(T--) solve(); return 0;}