树状数组与线段树入门

树状数组和线段树都是用于处理动态区间问题的数据结构。

树状数组:支持区间加法的同时区间查询区间和,以及最值;

线段树:支持区间加法的同时区间乘法的同时区间查询区间和,以及最值。

线段树的适用范围相较于树状数组更加广泛,但树状数组相对于线段树简洁很多,且常数极小。

树状数组

树状数组的空间复杂度为O(n),时间上单次查询为O(logn)。

首先我们需要知道这个定义:

lowbit(x):表示x在二进制下从低位往高位数的第一个1(最低位有效1)

int lowbit(int x){return x&-x;}

树状数组听起来是一个树形结构,但是实际上用一个长度与操作数组同样为n的数组存储就行了

我们令一个树状数组为t[n],则t[x]表示的是从数组下标为x-lowbit(x)+1到x之间数组元素的和

∴对于单点更新: 

void update(int pos,int val){
    while(pos<=n)
        t[pos]+=val,pos+=lowbit(pos);
}

那么如何求区间和呢?

我们先设sum初始化为0,用于记录数组下标为1~pos的元素之和

树状数组最需要我们理解的一点,就是t[x]表示a[x-lowbit(x)+1]至a[x]的和

那么如果我们要求数组下标为1~pos的元素之和,是不是每次让sum加上t[pos]后,pos变为(pos-lowbit(pos)+1)-1,即pos=pos-lowbit(pos)就可以了?

因为t[x]表示t[x-lowbit(x)+1]到t[pos]的和,每次让pos跳转到这个区间的前一位((x-lowbit(x)+1)-1=x-lowbit(x))就可以了?

计算l~r的和的话,简单的容斥一下:query(r)-query(l-1)   就可以了。

int query(int pos){
    int sum=0;
    while(pos)
        sum+=t[pos],pos-=lowbit(pos);
    return sum;
}

例题 树状数组1

题目描述

如题,已知一个数列,你需要进行下面两种操作:

输入格式

第一行包含两个正整数 n,m,分别表示该数列数字的个数和操作的总个数。

第二行包含 n 个用空格分隔的整数,其中第 i 个数字表示数列第 i 项的初始值。

接下来 m 行每行包含 3 个整数,表示一个操作,具体如下:

  • 1 x k 含义:将第 x 个数加上 k

  • 2 x y 含义:输出区间 [x,y] 内每个数的和

输出格式

输出包含若干行整数,即为所有操作 2 的结果。

代码如下:

#include<cstdio>
using namespace std;
typedef long long ll;
const int maxn=500001;
ll n,m,x,a,b,t[maxn];
ll lowbit(ll x){return x&-x;}
void update(ll pos,ll val){
	while(pos<=n)
		t[pos]+=val,pos+=lowbit(pos);
}
ll query(ll pos){
	ll sum=0;
	while(pos)
		sum+=t[pos],pos-=lowbit(pos);
	return sum;
}
int main()
{
	scanf("%lld%lld",&n,&m);
	for(ll i=1;i<=n;++i)
	{
		scanf("%lld",&x);
		update(i,x);
	}
	while(m--){
		scanf("%lld%lld%lld",&x,&a,&b);
		switch(x){
			case 1:
				update(a,b);
				break;
			default:
				printf("%lld\n",query(b)-query(a-1));
				break;
		}
	}
	return 0;
}

线段树

线段树的空间复杂度>=2*n,通常需要4*n避免RE,单次查询时间复杂度为O(logn)

第一次听起来非常容易理解:这是一棵树:树上任意一点都有t[x].l与t[x].r,此节点表示数组下标为t[x].l到t[x].r的元素之和

                                                       (1,8)

                               (1,4)                                         (5,8)

                   (1,2)              (3,4)                    (5,6)                 (7,8)

             (1,1)   (2,2)     (3,3)  (4,4)          (5,5)   (6,6)      (7,7)    (8,8)

也就是他的左儿子的值加上他的右儿子的值

const int maxn=100001;
struct node{
	ll l,r,x,upd;
}t[maxn<<2];//maxn<<2等同于maxn*2*2 (maxn*4)

他的左儿子表示x*2,右儿子表示x*2+1

则t[x*2].l=t[x].l,t[x*2].r=(t[x].l+t[x].r)/2;

t[x*2+1].l=(t[x].l+t[x].r)/2+1,t[x*2+1].r=t[x].r

建树时递归到l==r时不再递归,这个节点的值就是a[l]的值

int build(int pos,int l,int r)
{
	t[pos].l=l,t[pos].r=r;
	return t[pos].x=(l==r?a[l]:build(pos<<1,l,l+r>>1)+build(pos<<1|1,(l+r>>1)+1,r));
}//x<<1等同于x*2,x<<1|1等同于x*2+1

那么:问题来了

Q:那区间加法的时候时间复杂度不是n+logn的吗,这时应该怎么办呢?

A:Lazy_tag!

我们设有一棵线段树:

                                                       (1,8)

                               (1,4)                                         (5,8)

                   (1,2)              (3,4)                    (5,6)                 (7,8)

             (1,1)   (2,2)     (3,3)  (4,4)          (5,5)   (6,6)      (7,7)    (8,8)

如果我们要使区间1~4的全体元素集体加上一个值呢?

引进一个叫做lazy_tag的东西;

我们可以看到,从(1,8)开始更新,发现他的左右儿子都跟要更新的区间有关,继续递归下去

此时递归到了(1,4),(5,8)

这时!

我们发现区间(1,4)完全被我们要更新的区间1~4包含在内,于是我们就可以不用递归下去,直接在这层操作!

if(t[pos].l>=l&&t[pos].r<=r)
	{
		t[pos].x+=(t[pos].r-t[pos].l+1)*val,t[pos].upd+=val;
		return;
	}

t[pos].x加上(t[pos].r-t[pos].l+1)*val,是因为这个区间一共有(t[pos].r-t[pos].l+1)个元素,每个元素都加上val,即(t[pos].r-t[pos].l+1)*val

而t[pos].upd+=val表示这段区间内的每个数都加上val

我们又发现,区间(5,8)与我们要更新的1~4没有关系,故直接return

if(t[pos].l>r||t[pos].r<l)
		return;

除了完全包含和完全不包含外,还有两个区间一部分包括,一部分不重叠的情况

此时怎么办呢?

在这个时候,t[pos].upd(lazy_tag)的作用就体现出来了:
我们可以把这一层的标记细分给他的左儿子与右儿子

我们新定义一个函数push_down(下放标记):

void push(int pos){
    t[pos*2].x+=t[pos].upd*(t[pos*2].r-t[pos*2].l+1);
    t[pos*2+1].x+=t[pos].upd*(t[pos*2+1].r-t[pos*2+1].l+1);
    t[pos*2].upd+=t[pos].upd;
    t[pos*2+1].upd+=t[pos].upd;
    t[pos].upd=0;
}

update内函数新增:

push(pos);
update(pos<<1,l,r,val);
update(pos<<1|1,l,r,val);

最后我们更新了这个点的左右儿子,不要忘了再更新一下它自己:

t[pos].x=t[pos*2].x+t[pos*2+1].x;

完整的区间加法代码:

void update(ll pos,ll l,ll r,ll val)
{
	if(t[pos].l>=l&&t[pos].r<=r)
	{
		t[pos].x+=(t[pos].r-t[pos].l+1)*val,t[pos].upd+=val;
		return;
	}
	if(t[pos].l>r||t[pos].r<l)
		return;
	push(pos);
	update(pos<<1,l,r,val);
	update(pos<<1|1,l,r,val);
	t[pos].x=t[pos<<1].x+t[pos<<1|1].x;
}

区间查询也一样:往下搜索搜索到完全被需要查找的区间覆盖的区间就直接返回这个区间的值,不相关就返回0,否则下放标记,继续往下查询

int search(int pos,int l,int r)
{
	if(t[pos].l>=l&&t[pos].r<=r)
		return t[pos].x;
	if(t[pos].r<l||t[pos].l>r)
		return 0;
	push(pos);
	return search(pos<<1,l,r)+search(pos<<1|1,l,r);
}

模板 线段树1

题目描述

如题,已知一个数列,你需要进行下面两种操作:

  1. 将某区间每一个加上 k。
  2. 求出某区间每一个数的和。

输入格式

第一行包含两个整数 n,m,分别表示该数列数字的个数和操作的总个数。

第二行包含 n 个用空格分隔的整数,其中第 i 个数字表示数列第 i 项的初始值。

接下来 m 行每行包含 3 或 4 个整数,表示一个操作,具体如下:

  1. 1 x y k:将区间 [x, y] 内每个数加上 k。
  2. 2 x y输出区间 [x, y] 内每个数的和。

输出格式

输出包含若干行整数,即为所有操作 2 的结果。

代码:

#include<cstdio>
using namespace std;
typedef long long ll;
const int maxn=100001;
ll n,m,x,l,r,k,a[maxn];
struct node{
	ll l,r,x,upd;
}t[maxn<<2];
ll build(ll pos,ll l,ll r)
{
	t[pos].l=l,t[pos].r=r;
	return t[pos].x=(l==r?a[l]:build(pos<<1,l,l+r>>1)+build(pos<<1|1,(l+r>>1)+1,r));
}
void push(ll pos){t[pos<<1].x+=t[pos].upd*(t[pos<<1].r-t[pos<<1].l+1),t[pos<<1|1].x+=t[pos].upd*(t[pos<<1|1].r-t[pos<<1|1].l+1),t[pos<<1].upd+=t[pos].upd,t[pos<<1|1].upd+=t[pos].upd,t[pos].upd=0;}
void update(ll pos,ll l,ll r,ll val)
{
	if(t[pos].l>=l&&t[pos].r<=r)
	{
		t[pos].x+=(t[pos].r-t[pos].l+1)*val,t[pos].upd+=val;
		return;
	}
	if(t[pos].l>r||t[pos].r<l)
		return;
	push(pos);
	update(pos<<1,l,r,val);
	update(pos<<1|1,l,r,val);
	t[pos].x=t[pos<<1].x+t[pos<<1|1].x;
}
ll search(ll pos,ll l,ll r)
{
	if(t[pos].l>=l&&t[pos].r<=r)
		return t[pos].x;
	if(t[pos].r<l||t[pos].l>r)
		return 0;
	push(pos);
	return search(pos<<1,l,r)+search(pos<<1|1,l,r);
}
int main()
{
	scanf("%lld%lld",&n,&m);
	for(ll i=1;i<=n;++i)
		scanf("%lld",a+i);
	build(1,1,n);
	while(m--){
		scanf("%lld%lld%lld",&x,&l,&r);
		switch(x){
			case 1:
				scanf("%lld",&k);
				update(1,l,r,k);
				break;
			default:
				printf("%lld\n",search(1,l,r));
				break;
		}
	}
	return 0;
}

也可以用树状数组过:

#include<cstdio>
using namespace std;
const int maxn=100001;
typedef long long ll;
ll n,m,x,l,r,k,t[maxn],f[maxn];
ll lowbit(ll a){return a&-a;}
void update(ll pos,ll val)
{
	for(ll i=pos;i<=n;i+=lowbit(i))
		t[i]+=val,f[i]+=pos*val;
}
ll query(ll pos)
{
	ll sum=0;
	for(ll i=pos;i;i-=lowbit(i))
		sum+=(pos+1)*t[i]-f[i];
	return sum;
}
int main()
{
	scanf("%lld%lld",&n,&m);
	for(ll i=1;i<=n;++i)
	{
		scanf("%lld",&x);
		update(i,x);
		update(i+1,-x);
	}
	while(m--){
		scanf("%lld%lld%lld",&x,&l,&r);
		switch(x){
			case 1:
				scanf("%lld",&k);
				update(l,k);
				update(r+1,-k);
				break;
			default:
				printf("%lld\n",query(r)-query(l-1));
				break;
		}
	}
	return 0;
}

模板 线段树2

题目描述

如题,已知一个数列,你需要进行下面三种操作:

输入格式

第一行包含三个整数n,m,p,分别表示该数列数字的个数、操作的总个数和模数。

第二行包含 n 个用空格分隔的整数,其中第 i 个数字表示数列第 i 项的初始值。

接下来 m 行每行包含若干个整数,表示一个操作,具体如下:

操作 1: 格式:1 x y k 含义:将区间 [x,y] 内每个数乘上 k

操作 2: 格式:2 x y k 含义:将区间 [x,y] 内每个数加上 k

操作 3: 格式:3 x y 含义:输出区间 [x,y] 内每个数的和对 p 取模所得的结果

输出格式

输出包含若干行整数,即为所有操作 33 的结果。

涉及到区间乘法,但依旧是维护lazy_tag换汤不换药

代码:

#include<cstdio>
#define ls pos<<1
#define rs pos<<1|1
#define mid (l+r>>1)
using namespace std;
typedef long long ll;
const int maxn=100001;
ll n,m,p,x,l,r,k,a[maxn];
struct node{
	ll l,r,x,upd,mul;
}t[maxn<<2];
ll build(ll pos,ll l,ll r){
	t[pos].l=l,t[pos].r=r,t[pos].mul=1;
	return t[pos].x=(l==r?a[l]:build(ls,l,mid)+build(rs,mid+1,r))%p;
}
void push(ll pos){t[ls].x=(t[ls].x*t[pos].mul+(t[ls].r-t[ls].l+1)*t[pos].upd)%p,t[rs].x=(t[rs].x*t[pos].mul+(t[rs].r-t[rs].l+1)*t[pos].upd)%p,t[ls].upd=(t[ls].upd*t[pos].mul+t[pos].upd)%p,t[rs].upd=(t[rs].upd*t[pos].mul+t[pos].upd)%p,t[ls].mul=t[ls].mul*t[pos].mul%p,t[rs].mul=t[rs].mul*t[pos].mul%p,t[pos].upd=0,t[pos].mul=1;}
void update(ll pos,ll l,ll r,ll val)
{
	if(t[pos].l>=l&&t[pos].r<=r)
	{
		t[pos].x=(t[pos].x+(t[pos].r-t[pos].l+1)*val)%p,t[pos].upd=(t[pos].upd+val)%p;
		return;
	}
	if(t[pos].r<l||t[pos].l>r)
		return;
	push(pos);
	update(ls,l,r,val);
	update(rs,l,r,val);
	t[pos].x=(t[ls].x+t[rs].x)%p;
}
void mul(ll pos,ll l,ll r,ll val)
{
	if(t[pos].l>=l&&t[pos].r<=r)
	{
		t[pos].x=t[pos].x*val%p,t[pos].upd=t[pos].upd*val%p,t[pos].mul=t[pos].mul*val%p;
		return;
	}
	if(t[pos].r<l||t[pos].l>r)
		return;
	push(pos);
	mul(ls,l,r,val);
	mul(rs,l,r,val);
	t[pos].x=(t[ls].x+t[rs].x)%p;
}
ll search(ll pos,ll l,ll r)
{
	if(t[pos].l>=l&&t[pos].r<=r)
		return t[pos].x;
	if(t[pos].l>r||t[pos].r<l)
		return 0;
	push(pos);
	return (search(ls,l,r)+search(rs,l,r))%p;
}
int main()
{
	scanf("%lld%lld%lld",&n,&m,&p);
	for(ll i=1;i<=n;++i)
		scanf("%lld",a+i);
	build(1,1,n);
	while(m--){
		scanf("%lld%lld%lld",&x,&l,&r);
		switch(x){
			case 1:
				scanf("%lld",&k);
				mul(1,l,r,k);
				break;
			case 2:
				scanf("%lld",&k);
				update(1,l,r,k);
				break;
			default:
				printf("%lld\n",search(1,l,r));
				break;
		}
	}
	return 0;
}

区间最值

题目描述

给定一个长度为 N 的数列,和 M 次询问,求出每一次询问的区间内数字的最大值。

输入格式

第一行包含两个整数 N,M,分别表示数列的长度和询问的个数。

第二行包含 N 个整数(记为 ai​),依次表示数列的第 i 项。

接下来 M 行,每行包含两个整数 li​,ri​,表示查询的区间为 [li​,ri​]。

输出格式

输出包含 M 行,每行一个整数,依次表示每一次询问的结果。

树状数组:

#include<cstdio>
using namespace std;
typedef long long ll;
const int maxn=100001;
inline ll read(){
	ll x=0,f=1; char ch=getchar();
	while(ch<48||ch>57){if(ch=='-'){f=-1;}ch=getchar();}
	while(ch>=48&&ch<=57){x=(x<<1)+(x<<3)+ch-48;ch=getchar();}
	return x*f;
}
ll max(ll a,ll b){return a>b?a:b;}
ll n,m,x,l,r,a[maxn],t[maxn],as[maxn];
ll lowbit(ll x){return x&-x;}
void update(ll pos){
	while(pos<=n){
		t[pos]=a[pos];
		for(ll i=1;i<lowbit(pos);i<<=1)
			t[pos]=max(t[pos],t[pos-i]);
		pos+=lowbit(pos);
	}
}
ll search(ll l,ll r){
	ll sum=0;
	while(r>=l){
		sum=max(sum,a[r]),--r;
		for(;r-lowbit(r)>=l;r-=lowbit(r))
			sum=max(sum,t[r]);
	}
	return sum;
}
int main()
{
	n=read(),m=read();
	for(ll i=1;i<=n;++i)
	{
		a[i]=read();
		as[i]=max(as[i-1],a[i]);
		update(i);
	}
	while(m--){
		l=read(),r=read();
		printf("%lld\n",as[r]!=as[l]?as[r]:search(l,r));
	}
	return 0;
}

线段树:

#include<cstdio>
#define INF 1145141919810
using namespace std;
typedef long long ll;
const int maxn=100001;
ll max(ll a,ll b){return a>b?a:b;}
inline ll read(){
	ll x=0,f=1; char ch=getchar();
	while(ch<48||ch>57){if(ch=='-'){f=-1;}ch=getchar();}
	while(ch>=48&&ch<=57){x=(x<<1)+(x<<3)+ch-48; ch=getchar();}
	return x*f;
}
ll n,m,sum,l,r,a[maxn],af[maxn];
struct node{
	ll l,r,x;
}t[maxn<<2];
ll build(ll x,ll l,ll r)
{
	t[x].l=l,t[x].r=r;
	return t[x].x=(l==r?a[l]:max(build(x<<1,l,l+r>>1),build(x<<1|1,(l+r>>1)+1,r)));
}
void update(ll x)
{
	if(!x)
		return;
	t[x].x=max(t[x<<1].x,t[x<<1|1].x);
	update(x>>1);
}
void search(ll x,ll l,ll r)
{
	if(t[x].l>=l&&t[x].r<=r)
	{
		sum=max(sum,t[x].x);
		return;
	}
	if(t[x].l>r||t[x].r<l)
		return;
	search(x<<1,l,r);
	search(x<<1|1,l,r);
}
int main()
{
	n=read(),m=read();
	for(ll i=1;i<=n;++i)
		a[i]=read(),af[i]=max(af[i-1],a[i]);
	build(1,1,n);
	while(m--){
		sum=-INF,l=read(),r=read();
		if(af[r]!=af[l-1])
		{
			printf("%lld\n",af[r]);
			continue;
		}
		search(1,l,r);
		printf("%lld\n",sum);
	}
	return 0;
}

还有几道例题:树状数组2

单调队列

线段树/树状数组套二分模板

蒟蒻的第一篇,哪里有不对望指正%%%:)))

完结撒花~

相关文章

显卡天梯图2024最新版,显卡是电脑进行图形处理的重要设备,...
初始化电脑时出现问题怎么办,可以使用win系统的安装介质,连...
todesk远程开机怎么设置,两台电脑要在同一局域网内,然后需...
油猴谷歌插件怎么安装,可以通过谷歌应用商店进行安装,需要...
虚拟内存这个名词想必很多人都听说过,我们在使用电脑的时候...