思路
这道题我本来是用爆搜(DFS)写的,然而 TLE 了 3 个点。我只好去向机房巨佬 lornd 求教。在他神奇的操作之下,这道题的时间复杂度从 \(O(n^3)\) 降到了 \(O(n)\)。
我们先来梳理一下流程。存图之后,我们依次遍历每个点 \(i\)(共 \(n\) 个),取出所有与点 \(i\) 相连的点。由于它们之间都经过点 \(i\),这些点两两之间的距离必定为 2。而在树中,每一对距离为 2 的点恰好对应唯一的中间点,所以每一对只会在它的中间点处被统计一次。接下来分别处理两问。
第一问
遍历点 \(i\) 的邻居时,顺便记录这些邻居权值中的最大值和次大值。处理完所有邻居后,把最大值与次大值相乘,就得到以 \(i\) 为中间点的所有点对中最大的联合权值;再与 ans_max 取最大值,便可解决第一问。
第二问
设当前中间点所有邻居的权值为 \(a_1, a_2, \dots, a_k\)(\(k\) 为该点的度数),先写出公式:\[\left(\sum_{p=1}^{k} a_p\right)^2 = \sum_{p=1}^{k} a_p^2 + 2\sum_{1 \le p < q \le k} a_p a_q\]
第二问要求的其实就是等式右边的最后一项。题目统计的是有序点对,每个无序对 \((p, q)\) 会以 \((p, q)\) 和 \((q, p)\) 各贡献一次 \(a_p a_q\),正好对应系数 2:\[2\sum_{1 \le p < q \le k} a_p a_q\]
即 \(2a_1a_2 + 2a_1a_3 + 2a_2a_3 + \dots\)。所以在代码中只需算出邻居权值之和的平方(左边)与邻居权值的平方和(右边第一项),两者相减即可得到该中间点的贡献。把所有中间点的贡献累加并对 10007 取模,就是第二问的答案。
代码
// P1351.cpp
#include <iostream>
#include <algorithm>
#include <cstring>
#include <queue>
using namespace std;
// Alias for the 64-bit unsigned type used for weights and their products
// (a type alias rather than a macro, so it obeys normal scoping rules).
using ull = unsigned long long;
// Capacity of the vertex and edge arrays. Each undirected edge occupies two
// slots, so 2 * (n - 1) edge slots are used for a tree with n vertices.
const int maxn = 2000000;
// Modulus applied to the sum of joint weights (second answer).
const ull MOD = 10007;

int n;
// W[v]: weight of vertex v, 1-indexed.
ull W[maxn];

// The graph data structure: a forward-star (linked) adjacency list.
//   point[v] - index of the most recently added edge leaving v, or -1 if none;
//   nxt[e]   - index of the next edge leaving the same vertex, or -1 at the end;
//   to[e]    - destination vertex of edge e.
// The array is named `nxt`, not `next`: with `using namespace std;` in effect an
// unqualified `next` is ambiguous with std::next (C++11+) and fails to compile.
int to[maxn];
int point[maxn];
int nxt[maxn];
// Number of edge slots used so far.
int current = 0;

// answers: the maximum joint weight, and the sum of joint weights modulo MOD.
ull ans_max = 0;
ull ans_tot = 0;

// Pushes the directed edge from -> dest onto the front of from's adjacency list.
static void add_edge(int from, int dest)
{
    nxt[current] = point[from];
    to[current] = dest;
    point[from] = current;
    current++;
}

// Reads n, then n - 1 undirected edges "x y", then the n vertex weights from
// standard input, and builds the adjacency list. The edge storage is reset first,
// so calling it again starts from an empty graph.
void init()
{
    std::memset(point, -1, sizeof(point));
    current = 0;
    std::cin >> n;
    for (int i = 0; i < n - 1; i++)
    {
        int x, y;
        std::cin >> x >> y;
        add_edge(x, y);
        add_edge(y, x);
    }
    for (int i = 1; i <= n; i++)
        std::cin >> W[i];
}

// Returns a * a (unsigned, so it wraps modulo 2^64 on overflow).
ull secondary(ull a)
{
    return a * a;
}

// Computes both answers by treating every vertex i as the middle vertex of a
// length-2 path: any two distinct neighbours of i are at distance 2 via i.
//   ans_max - max over i of (largest neighbour weight) * (second largest);
//   ans_tot - sum over i of (S^2 - Q) mod MOD, where S is the sum and Q the sum
//             of squares of i's neighbour weights. S^2 - Q = 2 * sum_{p<q} a_p*a_q,
//             i.e. the total over ordered pairs (u, v) and (v, u).
// The answers are overwritten, not accumulated onto earlier values.
void solve()
{
    ans_max = 0;
    ans_tot = 0;
    for (int i = 1; i <= n; i++)
    {
        // Largest and second-largest neighbour weight of i.
        ull firmax = 0;
        ull secmax = 0;
        // Neighbour weight sum and sum of squares, both kept reduced modulo MOD,
        // so these accumulators cannot overflow whatever the degree or weights.
        ull sum_mod = 0;
        ull sq_mod = 0;
        for (int e = point[i]; e != -1; e = nxt[e])
        {
            const ull w = W[to[e]];
            if (w > firmax)
            {
                secmax = firmax;
                firmax = w;
            }
            else if (w > secmax)
            {
                secmax = w;
            }
            sum_mod = (sum_mod + w % MOD) % MOD;
            sq_mod = (sq_mod + secondary(w % MOD)) % MOD;
        }
        // Adding MOD before subtracting keeps the unsigned difference non-negative.
        ans_tot = (ans_tot + secondary(sum_mod) % MOD + MOD - sq_mod) % MOD;
        // A vertex with fewer than two neighbours leaves secmax == 0 and adds 0.
        // The product is exact while both weights fit in 32 bits.
        ans_max = std::max(ans_max, firmax * secmax);
    }
}
// Entry point: build the tree from stdin, solve, and print "ans_max ans_tot".
int main()
{
    // The input holds about 3n integers; untied, unsynchronised iostreams read
    // it much faster than the default C-stdio-synchronised mode.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    init();
    solve();
    std::cout << ans_max << " " << ans_tot;
    return 0;
}
// P1351.cpp
#include <iostream>
#include <algorithm>
#include <cstring>
#include <queue>
using namespace std;
// Alias for the 64-bit unsigned type used for weights and their products
// (a type alias rather than a macro, so it obeys normal scoping rules).
using ull = unsigned long long;
// Capacity of the vertex and edge arrays. Each undirected edge occupies two
// slots, so 2 * (n - 1) edge slots are used for a tree with n vertices.
const int maxn = 2000000;
// Modulus applied to the sum of joint weights (second answer).
const ull MOD = 10007;

int n;
// W[v]: weight of vertex v, 1-indexed.
ull W[maxn];

// The graph data structure: a forward-star (linked) adjacency list.
//   point[v] - index of the most recently added edge leaving v, or -1 if none;
//   nxt[e]   - index of the next edge leaving the same vertex, or -1 at the end;
//   to[e]    - destination vertex of edge e.
// The array is named `nxt`, not `next`: with `using namespace std;` in effect an
// unqualified `next` is ambiguous with std::next (C++11+) and fails to compile.
int to[maxn];
int point[maxn];
int nxt[maxn];
// Number of edge slots used so far.
int current = 0;

// answers: the maximum joint weight, and the sum of joint weights modulo MOD.
ull ans_max = 0;
ull ans_tot = 0;

// Pushes the directed edge from -> dest onto the front of from's adjacency list.
static void add_edge(int from, int dest)
{
    nxt[current] = point[from];
    to[current] = dest;
    point[from] = current;
    current++;
}

// Reads n, then n - 1 undirected edges "x y", then the n vertex weights from
// standard input, and builds the adjacency list. The edge storage is reset first,
// so calling it again starts from an empty graph.
void init()
{
    std::memset(point, -1, sizeof(point));
    current = 0;
    std::cin >> n;
    for (int i = 0; i < n - 1; i++)
    {
        int x, y;
        std::cin >> x >> y;
        add_edge(x, y);
        add_edge(y, x);
    }
    for (int i = 1; i <= n; i++)
        std::cin >> W[i];
}

// Returns a * a (unsigned, so it wraps modulo 2^64 on overflow).
ull secondary(ull a)
{
    return a * a;
}

// Computes both answers by treating every vertex i as the middle vertex of a
// length-2 path: any two distinct neighbours of i are at distance 2 via i.
//   ans_max - max over i of (largest neighbour weight) * (second largest);
//   ans_tot - sum over i of (S^2 - Q) mod MOD, where S is the sum and Q the sum
//             of squares of i's neighbour weights. S^2 - Q = 2 * sum_{p<q} a_p*a_q,
//             i.e. the total over ordered pairs (u, v) and (v, u).
// The answers are overwritten, not accumulated onto earlier values.
void solve()
{
    ans_max = 0;
    ans_tot = 0;
    for (int i = 1; i <= n; i++)
    {
        // Largest and second-largest neighbour weight of i.
        ull firmax = 0;
        ull secmax = 0;
        // Neighbour weight sum and sum of squares, both kept reduced modulo MOD,
        // so these accumulators cannot overflow whatever the degree or weights.
        ull sum_mod = 0;
        ull sq_mod = 0;
        for (int e = point[i]; e != -1; e = nxt[e])
        {
            const ull w = W[to[e]];
            if (w > firmax)
            {
                secmax = firmax;
                firmax = w;
            }
            else if (w > secmax)
            {
                secmax = w;
            }
            sum_mod = (sum_mod + w % MOD) % MOD;
            sq_mod = (sq_mod + secondary(w % MOD)) % MOD;
        }
        // Adding MOD before subtracting keeps the unsigned difference non-negative.
        ans_tot = (ans_tot + secondary(sum_mod) % MOD + MOD - sq_mod) % MOD;
        // A vertex with fewer than two neighbours leaves secmax == 0 and adds 0.
        // The product is exact while both weights fit in 32 bits.
        ans_max = std::max(ans_max, firmax * secmax);
    }
}
// Entry point: build the tree from stdin, solve, and print "ans_max ans_tot".
int main()
{
    // The input holds about 3n integers; untied, unsynchronised iostreams read
    // it much faster than the default C-stdio-synchronised mode.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    init();
    solve();
    std::cout << ans_max << " " << ans_tot;
    return 0;
}
// P1351.cpp
#include <iostream>
#include <algorithm>
#include <cstring>
#include <queue>
using namespace std;

// Alias for the 64-bit unsigned type used for weights and their products.
using ull = unsigned long long;
// Capacity of the vertex and edge arrays (2 * (n - 1) edge slots are used).
const int maxn = 2000000;
// Modulus applied to the sum of joint weights (second answer).
const ull MOD = 10007;

int n;
// W[v]: weight of vertex v, 1-indexed.
ull W[maxn];

// Forward-star adjacency list: point[v] is v's first edge (-1 if none),
// nxt[e] the next edge from the same vertex (-1 at the end), to[e] its endpoint.
// Named `nxt` because an unqualified `next` is ambiguous with std::next here.
int to[maxn];
int point[maxn];
int nxt[maxn];
// Number of edge slots used so far.
int current = 0;

// answers: the maximum joint weight, and the sum of joint weights modulo MOD.
ull ans_max = 0;
ull ans_tot = 0;

// Pushes the directed edge from -> dest onto the front of from's adjacency list.
static void add_edge(int from, int dest)
{
    nxt[current] = point[from];
    to[current] = dest;
    point[from] = current;
    current++;
}

// Reads n, the n - 1 undirected edges and the n vertex weights from stdin.
void init()
{
    std::memset(point, -1, sizeof(point));
    current = 0;
    std::cin >> n;
    for (int i = 0; i < n - 1; i++)
    {
        int x, y;
        std::cin >> x >> y;
        add_edge(x, y);
        add_edge(y, x);
    }
    for (int i = 1; i <= n; i++)
        std::cin >> W[i];
}

// Returns a * a (unsigned, so it wraps modulo 2^64 on overflow).
ull secondary(ull a)
{
    return a * a;
}

// Treats every vertex i as the middle of a length-2 path. ans_max takes the
// product of i's two largest neighbour weights; ans_tot adds (S^2 - Q) mod MOD,
// where S / Q are the sum / sum of squares of i's neighbour weights. That
// difference equals the total joint weight over ordered pairs through i.
void solve()
{
    ans_max = 0;
    ans_tot = 0;
    for (int i = 1; i <= n; i++)
    {
        ull firmax = 0;
        ull secmax = 0;
        // Kept reduced modulo MOD so the accumulators cannot overflow.
        ull sum_mod = 0;
        ull sq_mod = 0;
        for (int e = point[i]; e != -1; e = nxt[e])
        {
            const ull w = W[to[e]];
            if (w > firmax)
            {
                secmax = firmax;
                firmax = w;
            }
            else if (w > secmax)
            {
                secmax = w;
            }
            sum_mod = (sum_mod + w % MOD) % MOD;
            sq_mod = (sq_mod + secondary(w % MOD)) % MOD;
        }
        // Adding MOD before subtracting keeps the unsigned difference non-negative.
        ans_tot = (ans_tot + secondary(sum_mod) % MOD + MOD - sq_mod) % MOD;
        ans_max = std::max(ans_max, firmax * secmax);
    }
}

// Entry point: build the tree from stdin, solve, and print "ans_max ans_tot".
int main()
{
    // Untied, unsynchronised iostreams for the ~3n-integer input.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    init();
    solve();
    std::cout << ans_max << " " << ans_tot;
    return 0;
}
lornd tql%%%