CF1234D – Distinct Characters Queries(线段树:单点修改,区间查询不同个数/set)1600

2022年7月8日 下午3:00 数据结构 , ,

题意

给你一个字符串s,有q次询问,每次询问格式有两种:
1) 1 pos c:将字符串中位置为pos的字符改为c
2) 2 x y:询问区间[x,y]中有多少个不一样的字符

数据范围:
|s|\leq 10^5,q\leq 10^5

思路

我是先会做了另外一道线段树的题(线段树-单点修改+区间求最大连续子段和 - CarryNotKarry)才会秒这个题的。

这个题很明显是单点修改,然后区间查询,这里设置一个结构体,有普通的存当前结点的l,r以及如果是叶子节点的字符ch,也存上一个num[26]数组,分别代表这个结点内包含的26个字母出现的个数。

struct node
{
    int num[26];
    int l,r;
    char ch;
}tr[N<<2];

为什么我不封装呢?因为我不像封装那样,我需要的不只是一个值,而是一个结点的所有信息,最后这里不难想象,我查询的时候,得到的是一个结点,node res = query(1,x,y);然后在对这个res里面的num[26]经行统计输出。(这里不需要pushdown,没有lazytag

那对于一个结点,我的pushup就只需要维护num[26]即可,也就是把左右儿子的num对应加起来即可:

inline void pushup(node &o,node &l,node &r)//注意这个需要传引用
{
    for (int i=0;i<26;i++)
        o.num[i] = l.num[i] + r.num[i];
}
inline void pushup(int o)
{
    pushup(tr[o],tr[lson],tr[rson]);
}

build的话也很简单,只需要在叶子节点的时候,把当前的字符存进去,因为一开始都是$0$,所以也符合我们所需要的nummodify也与build很像,叶子节点的时候修改即可。

inline void build(int o,int l,int r)
{
    tr[o].l = l, tr[o].r = r;
    if (l==r)
    {
        tr[o].ch = c[l];
        tr[o].num[c[l]-'a'] = 1;
        return;
    }
    build(lson,l,mid);build(rson,mid+1,r);
    pushup(o);
}
inline void modify(int o,int x,char c)
{
    int l = tr[o].l, r = tr[o].r;
    if (l==r)
    {
        tr[o].num[tr[o].ch-'a']--;//减去原来的
        tr[o].ch = c;//更新
        tr[o].num[c-'a']++;//填入新的
        return;
    }
    if (x<=mid) modify(lson,x,c);
    else modify(rson,x,c);//寻找x位置
    pushup(o);
}

然后就是查询,查询我的想法也是先开并且初始化两个lc,rc,然后维护res返回即可

inline node query(int o,int ql,int qr)
{
    int l = tr[o].l, r = tr[o].r;
    if (ql<=l&&r<=qr) return tr[o];
    node lc,rc,res;
    for (int i=0;i<26;i++)
        lc.num[i] = rc.num[i] = 0;
    if (ql<=mid) lc = query(lson,ql,qr);
    if (qr>mid) rc = query(rson,ql,qr);
    pushup(res,lc,rc);
    return res;
}

整体下来的代码跟线段树-单点修改+区间求最大连续子段和 - CarryNotKarry很像。

代码

char c[N];
struct node
{
    int num[26];
    int l,r;
    char ch;
}tr[N<<2];
#define mid ((l+r)>>1)
#define lson (o<<1)
#define rson (o<<1|1)
inline void pushup(node &o,node &l,node &r)
{
    for (int i=0;i<26;i++)
        o.num[i] = l.num[i] + r.num[i];
}
inline void pushup(int o)
{
    pushup(tr[o],tr[lson],tr[rson]);
}
inline void build(int o,int l,int r)
{
    tr[o].l = l, tr[o].r = r;
    if (l==r)
    {
        tr[o].ch = c[l];
        tr[o].num[c[l]-'a'] = 1;
        return;
    }
    build(lson,l,mid);build(rson,mid+1,r);
    pushup(o);
}
inline void modify(int o,int x,char c)
{
    int l = tr[o].l, r = tr[o].r;
    if (l==r)
    {
        tr[o].num[tr[o].ch-'a']--;
        tr[o].ch = c;
        tr[o].num[c-'a']++;
        return;
    }
    if (x<=mid) modify(lson,x,c);
    else modify(rson,x,c);
    pushup(o);
}
inline node query(int o,int ql,int qr)
{
    int l = tr[o].l, r = tr[o].r;
    if (ql<=l&&r<=qr) return tr[o];
    node lc,rc,res;
    for (int i=0;i<26;i++)
        lc.num[i] = rc.num[i] = 0;
    if (ql<=mid) lc = query(lson,ql,qr);
    if (qr>mid) rc = query(rson,ql,qr);
    pushup(res,lc,rc);
    return res;
}
inline void Case_Test()
{
    cin>>c;
    int n = strlen(c);
    for (int i=n;i>=1;i--)
        c[i] = c[i-1];
    build(1,1,n);
    int q;
    cin>>q;
    while (q--)
    {
        int op,x;
        cin>>op;
        if (op==1)
        {
            char c;
            cin>>x>>c;
            modify(1,x,c);
        }
        else
        {
            int y;
            cin>>x>>y;
            node res = query(1,x,y);
            int ans = 0;
            for (int i=0;i<26;i++)
                if (res.num[i]) ans++;
            cout<<ans<<endl;
        }
    }
}

这时间大概是300ms,有更简便的set方法,不过需要800~900ms,用26个set来表示每个字符的位置,每次查找用二分去判断每个字符的那个区间是否有数字。

set<int> st[26];
inline void Case_Test()
{
    cin>>s>>q;
    n = strlen(s);
    for (int i = 0; i < n; i++)
    {
        st[s[i]-'a'].insert(i+1);
    }
    while (q--)
    {
        int op;
        cin>>op;
        if (op==1)
        {
            cin>>x>>c;
            st[s[x-1]-'a'].erase(x);
            s[x-1] = c;
            st[c-'a'].insert(x);
        }
        else
        {
            cin>>x>>y;
            ans = 0;
            for (int i=0;i<26;i++)
            {
                auto l = st[i].lower_bound(x);
                auto r = st[i].upper_bound(y);
                if (l!=r) ans++;//如果不相等,说明区间[l,r]内有
            }
            cout<<ans<<endl;
        }
    }
}

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注