#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#include<vector>
#include<set>
#include<map>
#include<cmath>
#include<algorithm>
using namespace std;
// Pair type alias; no code shown here uses it as a type.
typedef pair<int, int>p;
typedef long long ll;
// maxn: capacity of the per-node arrays (node ids must stay below it).
// oo: large sentinel constant (0x3f3f3f3f).
const int maxn = 5e4 + 5, oo = 0x3f3f3f3f;
// Next free slot in edge[]; main() uses two slots per undirected tree edge.
int tot;
// n: number of nodes (ids 1..n). k: number of colors (input colors are
// presumably in 1..k; they are stored 0-based in a[]).
int n, k;
// Adjacency-list storage: to = neighbor, nxt = index of the next edge leaving
// the same node (-1 terminates the list).
struct EDGE{
int to;
int nxt;
}edge[maxn << 1];
// kind[0] = number of masks collected by cal_kind(); kind[1..kind[0]] holds them.
int kind[maxn];
// head[u] = index of u's first edge in edge[], or -1 if u has none.
int head[maxn];
// true for centroids currently active in solve(); every traversal stops at them.
bool used[maxn];
// Subtree sizes written by the most recent dfs_size() traversal.
int sub_sz[maxn];
// Running answer for the current test case (printed by main()).
ll ans;
// Per-mask counters used by count_kind(). It is indexed by masks up to `mask`,
// so this size assumes k <= 10 (1 << 10 = 1024 < 1305).
ll cnt[1305];
// a[i] = color of node i, converted to 0-based when read.
int a[maxn];
// Bit mask with all k color bits set: (1 << k) - 1.
int mask;
// Best centroid candidate found by dfs_root(), and the size of the largest
// piece left if it is removed.
int ROOT;
int ANS;
// Reset the per-test-case state before a new tree is read.
// The argument is not needed: the mask comes from the global k, and the
// adjacency heads are cleared over the whole array.
void init(int /*n*/)
{
    tot = 0;
    ans = 0;
    memset(head, -1, sizeof(head));   // -1 marks "no outgoing edge"
    mask = (1 << k) - 1;              // every one of the k color bits set
    kind[0] = 0;                      // empty mask list
}
// Size of the subtree hanging below u (coming from pa), ignoring nodes that
// are already active centroids. Records the size in sub_sz[u] and returns it.
int dfs_size(int u, int pa)
{
    int total = 1;
    for (int e = head[u]; ~e; e = edge[e].nxt) {
        const int v = edge[e].to;
        if (v != pa && !used[v])
            total += dfs_size(v, u);
    }
    sub_sz[u] = total;
    return total;
}
// Find the centroid of the component containing u.
//   u, pa : current node and the node it was reached from (pa == u at the top)
//   sz    : number of nodes in the component
// sub_sz[] must have been filled by dfs_size() on the same component with the
// same starting node, so sub_sz[u] is the size of u's subtree in this traversal.
// For each node, the largest piece left after deleting it is the larger of:
// its biggest child subtree, and the part above it, which is sz - sub_sz[u].
// The first node (post-order) with the smallest such piece is stored in ROOT,
// and that piece size is stored in ANS.
void dfs_root(int u, int pa, int sz)
{
    int heaviest = 0;
    for(int i = head[u]; i != -1; i = edge[i].nxt){
        int to = edge[i].to;
        if(to == pa || used[to]) continue;
        dfs_root(to, u, sz);
        heaviest = max(heaviest, sub_sz[to]);
    }
    // Everything outside u's subtree. This must use the real subtree size.
    // Subtracting an accumulated sum of running maxima over-counts the
    // children, which underestimates this piece and can select a
    // non-centroid, breaking the O(log n) depth of the decomposition.
    heaviest = max(heaviest, sz - sub_sz[u]);
    if(ANS > heaviest){
        ANS = heaviest;
        ROOT = u;
    }
}
// Append to kind[] (count kept in kind[0]) the color mask d of the path that
// ends at u. Then recurse into u's children, OR-ing each child's color bit
// into the mask. Active centroids (used[]) and the node we came from are skipped.
void cal_kind(int u, int parent, int d)
{
    kind[0]++;
    kind[kind[0]] = d;
    for (int e = head[u]; e != -1; e = edge[e].nxt) {
        const int v = edge[e].to;
        if (v == parent || used[v])
            continue;
        const int child_mask = d | (1 << a[v]);
        cal_kind(v, u, child_mask);
    }
}
// Collect the path masks of u's component (starting from u with mask d).
// Return the number of ordered pairs (i, j), i == j included, of collected
// masks whose bitwise OR equals the full color mask.
ll count_kind(int u, int d)
{
    kind[0] = 0;
    cal_kind(u, u, d);
    const int total = kind[0];

    memset(cnt, 0, sizeof(cnt));
    for (int i = 1; i <= total; ++i)
        ++cnt[kind[i]];

    // Superset-sum transform. Afterwards cnt[x] is the number of collected
    // masks that contain every bit of x. In one pass, only entries with `b`
    // set are read and only entries without it are written, so the visiting
    // order within a pass does not matter.
    for (int bit = 0; bit < k; ++bit) {
        const int b = 1 << bit;
        for (int sub = 0; sub <= mask; ++sub)
            if (sub & b)
                cnt[sub ^ b] += cnt[sub];
    }

    // A partner for kind[i] must contain every color kind[i] is missing.
    ll pairs = 0;
    for (int i = 1; i <= total; ++i)
        pairs += cnt[mask ^ kind[i]];
    return pairs;
}
void solve(int u)
{
dfs_size(u, u);
ANS = n;
dfs_root(u, u, sub_sz[u]);
int s = ROOT;
used[s] = true;
ans += count_kind(s, 1 << a[s]);
for(int i = head[s]; i != -1; i = edge[i].nxt){
int to = edge[i].to;
if(used[to]) continue;
solve(to);
}
for(int i = head[s]; i != -1; i = edge[i].nxt){
int to = edge[i].to;
if(used[to]) continue;
ans -= count_kind(to, (1 << a[s]) | (1 << a[to]));
}
used[s] = false;
}
// Process test cases until input ends. Each case is n and k, then n node
// colors, then n - 1 undirected edges. The answer for each case is printed
// on its own line.
int main()
{
    while (scanf("%d%d", &n, &k) != EOF) {
        init(n);
        for (int node = 1; node <= n; ++node) {
            scanf("%d", a + node);
            --a[node];                      // colors kept 0-based: bit a[x]
        }
        int x, y;
        for (int e = 1; e < n; ++e) {
            scanf("%d%d", &x, &y);
            // Store both directions: first the entry y -> x, then x -> y.
            // The insertion order determines adjacency iteration order.
            const int from[2] = {y, x};
            const int dest[2] = {x, y};
            for (int side = 0; side < 2; ++side) {
                edge[tot].to = dest[side];
                edge[tot].nxt = head[from[side]];
                head[from[side]] = tot++;
            }
        }
        solve(1);
        printf("%lld\n", ans);
    }
    return 0;
}