【线段树】2569. 更新数组后处理求和查询

本文涉及知识点

C++线段树

LeetCode2569. 更新数组后处理求和查询

给你两个下标从 0 开始的数组 nums1 和 nums2 ,和一个二维数组 queries 表示一些操作。总共有 3 种类型的操作:

操作类型 1 为 queries[i] = [1, l, r] 。你需要将 nums1 从下标 l 到下标 r 的所有 0 反转成 1 并且所有 1 反转成 0 。l 和 r 下标都从 0 开始。

操作类型 2 为 queries[i] = [2, p, 0] 。对于 0 <= i < n 中的所有下标,令 nums2[i] = nums2[i] + nums1[i] * p 。

操作类型 3 为 queries[i] = [3, 0, 0] 。求 nums2 中所有元素的和。

请你返回一个 数组,包含 所有第三种操作类型 的答案。

示例 1:

输入:nums1 = [1,0,1], nums2 = [0,0,0], queries = [[1,1,1],[2,1,0],[3,0,0]]

输出:[3]

解释:第一个操作后 nums1 变为 [1,1,1] 。第二个操作后,nums2 变成 [1,1,1] ,所以第三个操作的答案为 3 。所以返回 [3] 。

示例 2:

输入:nums1 = [1], nums2 = [5], queries = [[2,0,0],[3,0,0]]

输出:[5]

解释:第一个操作后,nums2 保持不变为 [5] ,所以第二个操作的答案是 5 。所以返回 [5] 。

提示:

1 <= nums1.length, nums2.length <= $10^5$

nums1.length = nums2.length

1 <= queries.length <= $10^5$

queries[i].length = 3

0 <= l <= r <= nums1.length - 1

0 <= p <= $10^6$

0 <= nums1[i] <= 1

0 <= nums2[i] <= $10^9$

线段树

lineTree 表示 nums1:每个节点记录其区间内 1 的个数。

操作一:将区间 [l, r] 翻转,区间内 1 的个数变为"区间长度 − 原个数",用懒标记实现区间更新。

操作二:$\text{sum} = \text{sum} + p \times \text{lineTree.QueryAll()}$,即 nums2 的总和增加 p 乘以 nums1 中 1 的个数。

操作三:将当前的 sum 加入答案数组。
注意:不要忘记初始化 lineTree(用 nums1 的初始值)和 sum(nums2 的初始总和)。

代码

核心代码

cpp
/// Abstract hook interface for a segment tree that supports single-point updates.
///
/// A concrete tree drives the traversal and calls these hooks; a subclass decides
/// what each node stores (TSave) and what an update carries (TRecord).
template<class TSave, class TRecord >
class CSingeUpdateLineTree
{
public:
	// Polymorphic base: deleting a derived tree through a base pointer must run
	// the derived destructor.
	virtual ~CSingeUpdateLineTree() = default;
protected:
	/// Initialise the leaf `save` that holds element index `iSave`.
	virtual void OnInit(TSave& save, int iSave) = 0;
	/// Called for each node whose range lies fully inside a query range.
	virtual void OnQuery(TSave& save) = 0;
	/// Apply `update` to the leaf that holds element index `iSave`.
	virtual void OnUpdate(TSave& save, int iSave, const TRecord& update) = 0;
	/// Recompute `par` from its two children; `par` covers [iSaveLeft, iSaveRight].
	virtual void OnUpdateParent(TSave& par, const TSave& left, const TSave& r, int iSaveLeft, int iSaveRight) = 0;
};

/// Segment tree with single-point updates backed by a flat array.
///
/// Nodes use 1-based heap numbering (node k has children 2k and 2k+1) and the
/// tree covers element indices [0, iEleSize). Per-node behaviour is delegated to
/// the CSingeUpdateLineTree hooks implemented by a subclass.
template<class TSave, class TRecord >
class CVectorSingUpdateLineTree : public CSingeUpdateLineTree<TSave, TRecord>
{
public:
	/// @param iEleSize number of elements (0 yields an empty tree).
	/// @param tDefault initial value of every node.
	CVectorSingUpdateLineTree(int iEleSize, TSave tDefault) :m_iEleSize(iEleSize),
		// 4*n slots cover the recursive heap layout. At least 4 are allocated so
		// that QueryAll() on an empty tree still reads a valid slot (tDefault).
		m_save((iEleSize > 0 ? iEleSize : 1) * 4, tDefault) {
	}
	/// Applies `update` to the element at `index`. Out-of-range indices are
	/// ignored; previously they silently updated the nearest boundary leaf.
	void Update(int index, TRecord update) {
		if ((index < 0) || (index >= m_iEleSize)) {
			return;
		}
		Update(1, 0, m_iEleSize - 1, index, update);
	}
	/// Calls OnQuery for each maximal node lying inside [leftIndex, rightIndex].
	void Query(int leftIndex, int rightIndex) {
		if (m_iEleSize <= 0) {
			return;
		}
		Query(1, 0, m_iEleSize - 1, leftIndex, rightIndex);
	}
	/// Builds all nodes: leaves via OnInit, inner nodes via OnUpdateParent.
	void Init() {
		if (m_iEleSize <= 0) {
			return;
		}
		Init(1, 0, m_iEleSize - 1);
	}
	/// Aggregate of the whole range (the root node); tDefault for an empty tree.
	TSave QueryAll() {
		return m_save[1];
	}
protected:
	int m_iEleSize;
	/// Recursively builds node iNodeNO covering [iSaveLeft, iSaveRight].
	void Init(int iNodeNO, int iSaveLeft, int iSaveRight)
	{
		if (iSaveLeft == iSaveRight) {
			this->OnInit(m_save[iNodeNO], iSaveLeft);
			return;
		}
		const int mid = iSaveLeft + (iSaveRight - iSaveLeft) / 2;
		Init(iNodeNO * 2, iSaveLeft, mid);
		Init(iNodeNO * 2 + 1, mid + 1, iSaveRight);
		this->OnUpdateParent(m_save[iNodeNO], m_save[iNodeNO * 2], m_save[iNodeNO * 2 + 1], iSaveLeft, iSaveRight);
	}
	/// Visits node iNodeNO; reports it if fully covered, otherwise recurses into
	/// the children that overlap [iQueryLeft, iQueryRight].
	void Query(int iNodeNO, int iSaveLeft, int iSaveRight, int iQueryLeft, int iQueryRight) {
		if ((iSaveLeft >= iQueryLeft) && (iSaveRight <= iQueryRight)) {
			this->OnQuery(m_save[iNodeNO]);
			return;
		}
		if (iSaveLeft == iSaveRight) {// leaf: no children
			return;
		}
		const int mid = iSaveLeft + (iSaveRight - iSaveLeft) / 2;
		if (mid >= iQueryLeft) {
			Query(iNodeNO * 2, iSaveLeft, mid, iQueryLeft, iQueryRight);
		}
		if (mid + 1 <= iQueryRight) {
			Query(iNodeNO * 2 + 1, mid + 1, iSaveRight, iQueryLeft, iQueryRight);
		}
	}
	/// Descends to the leaf for iUpdateNO, updates it, then refreshes ancestors.
	void Update(int iNodeNO, int iSaveLeft, int iSaveRight, int iUpdateNO, TRecord update) {
		if (iSaveLeft == iSaveRight)
		{
			this->OnUpdate(m_save[iNodeNO], iSaveLeft, update);
			return;
		}
		const int mid = iSaveLeft + (iSaveRight - iSaveLeft) / 2;
		if (iUpdateNO <= mid) {
			Update(iNodeNO * 2, iSaveLeft, mid, iUpdateNO, update);
		}
		else {
			Update(iNodeNO * 2 + 1, mid + 1, iSaveRight, iUpdateNO, update);
		}
		this->OnUpdateParent(m_save[iNodeNO], m_save[iNodeNO * 2], m_save[iNodeNO * 2 + 1], iSaveLeft, iSaveRight);
	}
	std::vector<TSave> m_save;
};

/// Segment tree with single-point updates backed by heap-allocated nodes.
///
/// Children are created lazily (CreateChilds), so only visited ranges allocate.
/// The tree owns every node and frees them all in its destructor.
template<class TSave, class TRecord >
class CTreeSingeLineTree : public CSingeUpdateLineTree<TSave, TRecord>
{
protected:
	struct CTreeNode
	{
		/// Number of element indices covered by this node.
		int Cnt()const { return m_iMaxIndex - m_iMinIndex + 1; }
		int m_iMinIndex;
		int m_iMaxIndex;
		TSave data;
		CTreeNode* m_lChild = nullptr, * m_rChild = nullptr;
	};
	CTreeNode* m_root;
	TSave m_tDefault;
public:
	/// Covers element indices [iMinIndex, iMaxIndex]; every node starts as tDefault.
	CTreeSingeLineTree(int iMinIndex, int iMaxIndex, TSave tDefault) {
		// Assigned in the body, not the init list: m_root is declared before
		// m_tDefault, yet CreateNode reads m_tDefault.
		m_tDefault = tDefault;
		m_root = CreateNode(iMinIndex, iMaxIndex);
	}
	~CTreeSingeLineTree() {
		Destroy(m_root);
	}
	// The tree owns raw node pointers; a member-wise copy would double-free them.
	CTreeSingeLineTree(const CTreeSingeLineTree&) = delete;
	CTreeSingeLineTree& operator=(const CTreeSingeLineTree&) = delete;
	/// Builds all nodes: leaves via OnInit, inner nodes via OnUpdateParent.
	void Init() {
		if (IsEmpty()) { return; }
		Init(m_root);
	}
	/// Applies `update` to the element at `index` (ignored if out of range).
	void Update(int index, TRecord update) {
		if (IsEmpty()) { return; }
		Update(m_root, index, update);
	}
	/// Aggregate of the whole range (the root node).
	TSave QueryAll() {
		return m_root->data;
	}
	/// Calls OnQuery for each maximal node lying inside [leftIndex, rightIndex].
	void Query(int leftIndex, int rightIndex) {
		if (IsEmpty()) { return; }
		Query(m_root, leftIndex, rightIndex);
	}
protected:
	/// True when the covered range is empty (iMinIndex > iMaxIndex); recursing
	/// into such a node would never reach a single-element leaf.
	bool IsEmpty() const { return m_root->Cnt() <= 0; }
	void Query(CTreeNode* node, int iQueryLeft, int iQueryRight) {
		if ((node->m_iMinIndex >= iQueryLeft) && (node->m_iMaxIndex <= iQueryRight)) {
			this->OnQuery(node->data);
			return;
		}
		if (1 == node->Cnt()) {// leaf: no children
			return;
		}
		CreateChilds(node);
		const int mid = node->m_iMinIndex + (node->m_iMaxIndex - node->m_iMinIndex) / 2;
		if (mid >= iQueryLeft) {
			Query(node->m_lChild, iQueryLeft, iQueryRight);
		}
		if (mid + 1 <= iQueryRight) {
			Query(node->m_rChild, iQueryLeft, iQueryRight);
		}
	}
	void Init(CTreeNode* node)
	{
		if (1 == node->Cnt()) {
			this->OnInit(node->data, node->m_iMinIndex);
			return;
		}
		CreateChilds(node);
		Init(node->m_lChild);
		Init(node->m_rChild);
		this->OnUpdateParent(node->data, node->m_lChild->data, node->m_rChild->data, node->m_iMinIndex, node->m_iMaxIndex);
	}
	void Update(CTreeNode* node, int iUpdateNO, TRecord update) {
		if ((iUpdateNO < node->m_iMinIndex) || (iUpdateNO > node->m_iMaxIndex)) {
			return;
		}
		if (1 == node->Cnt()) {
			this->OnUpdate(node->data, node->m_iMinIndex, update);
			return;
		}
		CreateChilds(node);
		Update(node->m_lChild, iUpdateNO, update);
		Update(node->m_rChild, iUpdateNO, update);
		this->OnUpdateParent(node->data, node->m_lChild->data, node->m_rChild->data, node->m_iMinIndex, node->m_iMaxIndex);
	}
	/// Allocates both children of `node` on first use; no-op if they exist.
	void CreateChilds(CTreeNode* node) {
		if (nullptr != node->m_lChild) { return; }
		const int iSaveLeft = node->m_iMinIndex;
		const int iSaveRight = node->m_iMaxIndex;
		const int mid = iSaveLeft + (iSaveRight - iSaveLeft) / 2;
		node->m_lChild = CreateNode(iSaveLeft, mid);
		node->m_rChild = CreateNode(mid + 1, iSaveRight);
	}
	CTreeNode* CreateNode(int iMinIndex, int iMaxIndex) {
		CTreeNode* node = new CTreeNode;
		node->m_iMinIndex = iMinIndex;
		node->m_iMaxIndex = iMaxIndex;
		node->data = m_tDefault;
		return node;
	}
	/// Frees `node` and all of its descendants (recursion depth is the tree height).
	static void Destroy(CTreeNode* node) {
		if (nullptr == node) { return; }
		Destroy(node->m_lChild);
		Destroy(node->m_rChild);
		delete node;
	}
};

/// Abstract hook interface for a segment tree with range updates and lazy
/// propagation. TSave is the per-node aggregate; TRecord is a pending update.
template<class TSave, class TRecord >
class CRangUpdateLineTree
{
public:
	// Polymorphic base: deleting a derived tree through a base pointer must run
	// the derived destructor.
	virtual ~CRangUpdateLineTree() = default;
protected:
	/// Called for each node fully inside the query range; it covers [iSaveLeft, iSaveRight].
	virtual void OnQuery(const TSave& save, const int& iSaveLeft, const int& iSaveRight) = 0;
	/// Apply `update` to the aggregate of a node covering [iSaveLeft, iSaveRight].
	virtual void OnUpdate(TSave& save, const int& iSaveLeft, const int& iSaveRight, const TRecord& update) = 0;
	/// Recompute `par` from its two children; `par` covers [iSaveLeft, iSaveRight].
	virtual void OnUpdateParent(TSave& par, const TSave& left, const TSave& r, const int& iSaveLeft, const int& iSaveRight) = 0;
	/// Merge `newRecord` into a node's pending record `old`.
	virtual void OnUpdateRecord(TRecord& old, const TRecord& newRecord) = 0;
};


/// Segment tree with range updates and lazy propagation, backed by heap-allocated
/// nodes that are created on first visit (CreateChilds).
///
/// Each node stores its aggregate (data) and a pending update (record); a record
/// equal to the default record means nothing is pending. The tree owns every
/// node and frees them all in its destructor.
template<class TSave, class TRecord >
class CTreeRangeLineTree : public CRangUpdateLineTree<TSave, TRecord>
{
protected:
	struct CTreeNode
	{
		/// Number of element indices covered by this node.
		int Cnt()const { return m_iMaxIndex - m_iMinIndex + 1; }
		int m_iMinIndex;
		int m_iMaxIndex;
		TRecord record;
		TSave data;
		CTreeNode* m_lChild = nullptr, * m_rChild = nullptr;
	};
	CTreeNode* m_root;
	TSave m_tDefault;
	TRecord m_tRecordDef;
public:
	/// Covers [iMinIndex, iMaxIndex]; nodes start with tDefault data and tRecordDef record.
	CTreeRangeLineTree(int iMinIndex, int iMaxIndex, TSave tDefault, TRecord tRecordDef) {
		// Assigned in the body: m_root is declared first, but CreateNode reads
		// m_tDefault and m_tRecordDef.
		m_tDefault = tDefault;
		m_tRecordDef = tRecordDef;
		m_root = CreateNode(iMinIndex, iMaxIndex);
	}
	~CTreeRangeLineTree() {
		Destroy(m_root);
	}
	// The tree owns raw node pointers; a member-wise copy would double-free them.
	CTreeRangeLineTree(const CTreeRangeLineTree&) = delete;
	CTreeRangeLineTree& operator=(const CTreeRangeLineTree&) = delete;
	/// Applies `value` to every element in [iLeftIndex, iRightIndex].
	void Update(int iLeftIndex, int iRightIndex, TRecord value)
	{
		if (IsEmpty()) { return; }
		Update(m_root, iLeftIndex, iRightIndex, value);
	}
	/// Aggregate of the whole range (the root node).
	TSave QueryAll() {
		return m_root->data;
	}
	/// Calls OnQuery for each maximal node lying inside [leftIndex, rightIndex].
	void Query(int leftIndex, int rightIndex) {
		if (IsEmpty()) { return; }
		Query(m_root, leftIndex, rightIndex);
	}
protected:
	/// True when the covered range is empty (iMinIndex > iMaxIndex).
	bool IsEmpty() const { return m_root->Cnt() <= 0; }
	void Query(CTreeNode* node, int iQueryLeft, int iQueryRight) {
		if ((node->m_iMinIndex >= iQueryLeft) && (node->m_iMaxIndex <= iQueryRight)) {
			this->OnQuery(node->data, node->m_iMinIndex, node->m_iMaxIndex);
			return;
		}
		if (1 == node->Cnt()) {// leaf: no children
			return;
		}
		CreateChilds(node);
		Fresh(node);
		const int mid = node->m_iMinIndex + (node->m_iMaxIndex - node->m_iMinIndex) / 2;
		if (mid >= iQueryLeft) {
			Query(node->m_lChild, iQueryLeft, iQueryRight);
		}
		if (mid + 1 <= iQueryRight) {
			Query(node->m_rChild, iQueryLeft, iQueryRight);
		}
	}
	void Update(CTreeNode* node, int iOpeLeft, int iOpeRight, TRecord value)
	{
		const int& iSaveLeft = node->m_iMinIndex;
		const int& iSaveRight = node->m_iMaxIndex;
		if ((iOpeLeft <= iSaveLeft) && (iOpeRight >= iSaveRight))
		{
			// Fully covered: update in place and remember `value` for the children.
			this->OnUpdate(node->data, iSaveLeft, iSaveRight, value);
			this->OnUpdateRecord(node->record, value);
			return;
		}
		if (1 == node->Cnt()) {// leaf: no children
			return;
		}
		CreateChilds(node);
		Fresh(node);
		const int mid = node->m_iMinIndex + (node->m_iMaxIndex - node->m_iMinIndex) / 2;
		if (mid >= iOpeLeft) {
			this->Update(node->m_lChild, iOpeLeft, iOpeRight, value);
		}
		if (mid + 1 <= iOpeRight) {
			this->Update(node->m_rChild, iOpeLeft, iOpeRight, value);
		}
		// A node with descendants always has exactly two children.
		this->OnUpdateParent(node->data, node->m_lChild->data, node->m_rChild->data, node->m_iMinIndex, node->m_iMaxIndex);
	}
	/// Allocates both children of `node` on first use; no-op if they exist.
	void CreateChilds(CTreeNode* node) {
		if (nullptr != node->m_lChild) { return; }
		const int iSaveLeft = node->m_iMinIndex;
		const int iSaveRight = node->m_iMaxIndex;
		const int mid = iSaveLeft + (iSaveRight - iSaveLeft) / 2;
		node->m_lChild = CreateNode(iSaveLeft, mid);
		node->m_rChild = CreateNode(mid + 1, iSaveRight);
	}
	CTreeNode* CreateNode(int iMinIndex, int iMaxIndex) {
		CTreeNode* node = new CTreeNode;
		node->m_iMinIndex = iMinIndex;
		node->m_iMaxIndex = iMaxIndex;
		node->data = m_tDefault;
		node->record = m_tRecordDef;
		return node;
	}
	/// Pushes the pending record of `node` down to both children, then clears it.
	void Fresh(CTreeNode* node)
	{
		if (m_tRecordDef == node->record)
		{
			return;
		}
		CreateChilds(node);
		Update(node->m_lChild, node->m_lChild->m_iMinIndex, node->m_lChild->m_iMaxIndex, node->record);
		Update(node->m_rChild, node->m_rChild->m_iMinIndex, node->m_rChild->m_iMaxIndex, node->record);
		node->record = m_tRecordDef;
	}
	/// Frees `node` and all of its descendants (recursion depth is the tree height).
	static void Destroy(CTreeNode* node) {
		if (nullptr == node) { return; }
		Destroy(node->m_lChild);
		Destroy(node->m_rChild);
		delete node;
	}
};

/// Segment tree with range updates and lazy propagation, backed by flat arrays.
///
/// Nodes use 1-based heap numbering (node k has children 2k and 2k+1).
/// m_save[k] is node k's aggregate; m_record[k] is its pending (not yet pushed
/// to the children) update, where m_recordNull means "nothing pending".
/// The tree covers element indices [0, iEleSize).
template<class TSave, class TRecord >
class CVectorRangeUpdateLineTree : public CRangUpdateLineTree<TSave, TRecord>
{
public:
	/// @param iEleSize number of elements (0 yields an empty tree).
	/// @param tDefault initial aggregate of every node.
	/// @param tRecordNull the "no pending update" record value.
	CVectorRangeUpdateLineTree(int iEleSize, TSave tDefault, TRecord tRecordNull)
		// Initialisers follow declaration order. At least 4 slots are allocated so
		// QueryAll() on an empty tree still reads a valid slot (tDefault).
		: m_save((iEleSize > 0 ? iEleSize : 1) * 4, tDefault)
		, m_record((iEleSize > 0 ? iEleSize : 1) * 4, tRecordNull)
		, m_recordNull(tRecordNull)
		, m_iEleSize(iEleSize) {
	}
	/// Applies `value` to every element in [iLeftIndex, iRightIndex].
	void Update(int iLeftIndex, int iRightIndex, TRecord value)
	{
		if (m_iEleSize <= 0) {
			return;
		}
		Update(1, 0, m_iEleSize - 1, iLeftIndex, iRightIndex, value);
	}
	/// Calls OnQuery for each maximal node lying inside [leftIndex, rightIndex].
	void Query(int leftIndex, int rightIndex) {
		if (m_iEleSize <= 0) {
			return;
		}
		Query(1, 0, m_iEleSize - 1, leftIndex, rightIndex);
	}
	/// Aggregate of the whole range (the root node); tDefault for an empty tree.
	TSave QueryAll() {
		return m_save[1];
	}
	/// Exchanges contents with `other`, which must cover the same number of elements.
	void swap(CVectorRangeUpdateLineTree<TSave, TRecord>& other) {
		// m_iEleSize is const and cannot be exchanged; check before mutating anything.
		assert(m_iEleSize == other.m_iEleSize);
		m_save.swap(other.m_save);
		m_record.swap(other.m_record);
		std::swap(m_recordNull, other.m_recordNull);
	}
protected:
	/// Visits node iNodeNO: reports it if fully inside the query range, otherwise
	/// pushes its pending record down and recurses into overlapping children.
	void Query(int iNodeNO, int iSaveLeft, int iSaveRight, int iQueryLeft, int iQueryRight) {
		if ((iSaveLeft >= iQueryLeft) && (iSaveRight <= iQueryRight)) {
			this->OnQuery(m_save[iNodeNO], iSaveLeft, iSaveRight);
			return;
		}
		if (iSaveLeft == iSaveRight) {// leaf: no children
			return;
		}
		Fresh(iNodeNO, iSaveLeft, iSaveRight);
		const int mid = iSaveLeft + (iSaveRight - iSaveLeft) / 2;
		if (mid >= iQueryLeft) {
			Query(iNodeNO * 2, iSaveLeft, mid, iQueryLeft, iQueryRight);
		}
		if (mid + 1 <= iQueryRight) {
			Query(iNodeNO * 2 + 1, mid + 1, iSaveRight, iQueryLeft, iQueryRight);
		}
	}
	/// Applies `value` to the part of [iOpeLeft, iOpeRight] covered by node iNode.
	/// Fully covered nodes are updated in place and keep `value` as a pending record.
	void Update(int iNode, int iSaveLeft, int iSaveRight, int iOpeLeft, int iOpeRight, TRecord value)
	{
		if ((iOpeLeft <= iSaveLeft) && (iOpeRight >= iSaveRight))
		{
			this->OnUpdate(m_save[iNode], iSaveLeft, iSaveRight, value);
			this->OnUpdateRecord(m_record[iNode], value);
			return;
		}
		if (iSaveLeft == iSaveRight) {
			return;// leaf: no children
		}
		Fresh(iNode, iSaveLeft, iSaveRight);
		const int iMid = iSaveLeft + (iSaveRight - iSaveLeft) / 2;
		if (iMid >= iOpeLeft)
		{
			Update(iNode * 2, iSaveLeft, iMid, iOpeLeft, iOpeRight, value);
		}
		if (iMid + 1 <= iOpeRight)
		{
			Update(iNode * 2 + 1, iMid + 1, iSaveRight, iOpeLeft, iOpeRight, value);
		}
		// A node with descendants always has exactly two children.
		this->OnUpdateParent(m_save[iNode], m_save[iNode * 2], m_save[iNode * 2 + 1], iSaveLeft, iSaveRight);
	}
	/// Pushes node iNode's pending record down to both children, then clears it.
	void Fresh(int iNode, int iDataLeft, int iDataRight)
	{
		if (m_recordNull == m_record[iNode])
		{
			return;
		}
		const int iMid = iDataLeft + (iDataRight - iDataLeft) / 2;
		Update(iNode * 2, iDataLeft, iMid, iDataLeft, iMid, m_record[iNode]);
		Update(iNode * 2 + 1, iMid + 1, iDataRight, iMid + 1, iDataRight, m_record[iNode]);
		m_record[iNode] = m_recordNull;
	}
	std::vector<TSave> m_save;
	std::vector<TRecord> m_record;
	TRecord m_recordNull;
	const int m_iEleSize;
};
/// Segment tree over a 0/1 array: each node stores how many 1s its range holds.
/// An update value v flips every bit in the range when v is odd; even values are
/// no-ops. This matches OnUpdateRecord, which keeps only the parity of pending updates.
class CMyLineTree : public CVectorRangeUpdateLineTree<int, int> {
	using TSave = int;
	using TRecord = int;
	using CVectorRangeUpdateLineTree<int, int>::CVectorRangeUpdateLineTree;
	// Hooks required by CVectorRangeUpdateLineTree.
	virtual void OnQuery(const TSave& save, const int& iSaveLeft, const int& iSaveRight) override
	{
		// No per-node work: callers in this file only use QueryAll().
	}
	virtual void OnUpdate(TSave& save, const int& iSaveLeft, const int& iSaveRight, const TRecord& update) override
	{
		// Flipping a range turns its count of 1s into (length - count). Previously
		// the flip happened regardless of `update`, so an even update desynchronised
		// node data from the parity-based lazy record.
		if (update % 2 != 0) {
			save = (iSaveRight - iSaveLeft + 1) - save;
		}
	}
	virtual void OnUpdateParent(TSave& par, const TSave& left, const TSave& r, const int& iSaveLeft, const int& iSaveRight) override
	{
		par = left + r;
	}
	virtual void OnUpdateRecord(TRecord& old, const TRecord& newRecord) override
	{
		// Two pending flips cancel out, so only the parity is kept.
		old = (old + newRecord) % 2;
	}
};
class Solution {
public:
	/// Processes the queries in order and returns the nums2 total recorded at every type-3 query.
	/// Type 1 flips nums1 on [l, r]. Type 2 adds p * (number of 1s in nums1) to the total.
	vector<long long> handleQuery(vector<int>& nums1, vector<int>& nums2, vector<vector<int>>& queries) {
		const int n = static_cast<int>(nums1.size());
		// Segment tree over nums1 whose root holds the count of 1s.
		CMyLineTree ones(n, 0, 0);
		for (int idx = 0; idx < n; ++idx) {
			if (nums1[idx] != 0) {
				ones.Update(idx, idx, 1);
			}
		}
		// Only the total of nums2 is ever reported, so just keep a running sum.
		long long total = 0LL;
		for (const int value : nums2) {
			total += value;
		}
		vector<long long> answers;
		for (const auto& q : queries) {
			switch (q[0]) {
			case 1:
				ones.Update(q[1], q[2], 1);
				break;
			case 2:
				total += static_cast<long long>(q[1]) * ones.QueryAll();
				break;
			default:
				answers.push_back(total);
				break;
			}
		}
		return answers;
	}
};

单元测试

cpp
/// Generic equality assertion forwarding to the test framework's Assert::AreEqual.
template<class T1, class T2>
void AssertEx(const T1& t1, const T2& t2)
{
	Assert::AreEqual(t1, t2);
}
/// Floating-point overload: passes when |t1 - t2| < 1e-5; the failure message shows both values.
void AssertEx( double t1,  double t2)
{
	auto str = std::to_wstring(t1) + std::wstring(1,32) + std::to_wstring(t2);
	// Qualified std::abs guarantees the double overload. An unqualified abs can bind
	// to C's int abs() and truncate the difference.
	Assert::IsTrue(std::abs(t1 - t2) < 1e-5,str.c_str() );
}

/// Element-wise vector comparison: sizes first, then each pair of elements in order.
template<class T>
void AssertEx(const vector<T>& v1, const vector<T>& v2)
{
	Assert::AreEqual(v1.size(), v2.size());
	for (size_t i = 0; i < v1.size(); i++)
	{
		Assert::AreEqual(v1[i], v2[i]);
	}
}

/// Order-insensitive comparison of two lists of vectors (both copies are sorted first).
template<class T>
void AssertV2(vector<vector<T>> vv1, vector<vector<T>> vv2)
{
	sort(vv1.begin(), vv1.end());
	sort(vv2.begin(), vv2.end());
	Assert::AreEqual(vv1.size(), vv2.size());
	for (size_t i = 0; i < vv1.size(); i++)
	{
		AssertEx(vv1[i], vv2[i]);
	}
}

// Unit tests for Solution::handleQuery. TEST_CLASS / TEST_METHOD / Assert appear to
// come from the Microsoft C++ unit test framework (NOTE(review): confirm).
namespace UnitTest
{
	// Shared inputs; each test method assigns them before calling handleQuery.
	vector<int> nums1, nums2;
	vector<vector<int>> queries;
	TEST_CLASS(UnitTest)
	{
	public:
		// Example 1 from the problem statement: expected output [3].
		TEST_METHOD(TestMethod00)
		{
			nums1 = { 1,0,1 }, nums2 = { 0,0,0 }, queries = { {1,1,1},{2,1,0},{3,0,0} };
			auto res = Solution().handleQuery(nums1, nums2, queries);
			AssertEx(vector<long long>{3}, res);
		}
		// Example 2 from the problem statement: expected output [5].
		TEST_METHOD(TestMethod01)
		{
			nums1 = { 1 }, nums2 = { 5 }, queries = { {2,0,0},{3,0,0} };
			auto res = Solution().handleQuery(nums1, nums2, queries);
			AssertEx(vector<long long>{5}, res);
		}
	};
}
相关推荐
励志成为嵌入式工程师2 小时前
c语言简单编程练习9
c语言·开发语言·算法·vim
捕鲸叉2 小时前
创建线程时传递参数给线程
开发语言·c++·算法
A charmer2 小时前
【C++】vector 类深度解析:探索动态数组的奥秘
开发语言·c++·算法
Peter_chq2 小时前
【操作系统】基于环形队列的生产消费模型
linux·c语言·开发语言·c++·后端
wheeldown3 小时前
【数据结构】选择排序
数据结构·算法·排序算法
青花瓷4 小时前
C++__XCode工程中Debug版本库向Release版本库的切换
c++·xcode
观音山保我别报错4 小时前
C语言扫雷小游戏
c语言·开发语言·算法
幺零九零零5 小时前
【C++】socket套接字编程
linux·服务器·网络·c++
TangKenny5 小时前
计算网络信号
java·算法·华为
景鹤5 小时前
【算法】递归+深搜:814.二叉树剪枝
算法