CSP认证 备考(python)

第一题

1.python的接收输入

import sys

line = sys.stdin.readline() # 假设输入"hello"后回车

print(repr(line)) # 输出: 'hello\n' (能看到\n)

循环读取接下来的n行,每一行是一个点的坐标 for _ in range(n): x, y = map(float, sys.stdin.readline().split())

2.十位数,个位数,十分位,百分位

个位:a = x%10

十位:b = (x//10)%10

十分位:c = int(x*10)%10

百分位:col = int(absZ * 100) % 10

3.最大共约数

def gcd(a, b):

while b:

a, b = b, a % b

return a

直接 import math

g = math.gca(a,b)

多个数的最大公约数:

4.二分查找
复制代码
def binary_search(nums, target):
    left, right = 0, len(nums)-1
    
    while left <= right:  # 包含等号
        mid = left + (right - left) // 2
        
        if nums[mid] == target:
            return mid
        elif nums[mid] < target:
            left = mid + 1  # 明确排除mid
        else:  # nums[mid] > target
            right = mid - 1  # 明确排除mid
    
    return -1
复制代码
while x < y:  # 当 x==y 时,只剩一个元素,就是峰值

第二题

def sliding_window_matrix_optimized(grid, k):

"""

计算 k*k 滑动窗口在矩阵上的和。

时间复杂度: O(m*n)

空间复杂度: O(m*n) (用于存储中间结果,可优化至 O(n))

"""

if not grid or not grid0:

return \[\]

m, n = len(grid), len(grid0)

if m < k or n < k:

return \[\]

Step 1: 计算每一行的水平滑动窗口和

horizontal_sums 的大小为 m x (n - k + 1)

horizontal_sums = \[0 * (n - k + 1) for _ in range(m)]

for i in range(m):

初始化当前行第一个窗口的和

current_sum = sum(gridi:k)

horizontal_sumsi0 = current_sum

滑动窗口:减去左边离开的,加上右边进入的

for j in range(1, n - k + 1):

current_sum = current_sum - gridij - 1 + gridij + k - 1

horizontal_sumsij = current_sum

Step 2: 在 horizontal_sums 的基础上,计算垂直滑动窗口和

result 的大小为 (m - k + 1) x (n - k + 1)

result = \[\]

只需要遍历列的宽度 (n - k + 1)

width = n - k + 1

for j in range(width):

提取这一列所有的水平和

col_data = horizontal_sums\[ij for i in range(m)]

初始化这一列第一个垂直窗口的和

current_col_sum = sum(col_data:k)

col_result = current_col_sum

垂直滑动

for i in range(1, m - k + 1):

current_col_sum = current_col_sum - col_datai - 1 + col_datai + k - 1

col_result.append(current_col_sum)

将这一列的结果暂时存入,注意这里得到的是转置的结构,或者是列表的列表

这种写法为了代码清晰分了两步,实际为了 result 结构正确,通常会按行遍历

if j == 0:

result = \[x for x in col_result]

else:

for i in range(len(col_result)):

resulti.append(col_resulti)

return result

def max_sliding_window(nums, k):

"""单调队列实现滑动窗口最大值"""

from collections import deque

n = len(nums)

if n * k == 0:

return \[\]

if k == 1:

return nums

def clean_deque(i):

移除不在窗口中的元素

if dq and dq0 == i - k:

dq.popleft()

移除小于当前元素的元素

while dq and numsi > numsdq\[-1]:

dq.pop()

dq = deque()

max_idx = 0

for i in range(k):

clean_deque(i)

dq.append(i)

if numsi > numsmax_idx:

max_idx = i

output = nums\[max_idx]

for i in range(k, n):

clean_deque(i)

dq.append(i)

output.append(numsdq\[0])

return output

2.递归公式构建

from functools import lru_cache

from typing import List

class RecursionTemplate:

"""递归+记忆化通用模板"""

def fibonacci(self, n: int) -> int:

"""斐波那契数列:f(n) = f(n-1) + f(n-2)"""

@lru_cache(None)

def dfs(x):

if x <= 1:

return x

return dfs(x-1) + dfs(x-2)

return dfs(n)

def combination_sum(self, nums: Listint, target: int) -> int:

"""组合求和:从nums中选取元素和为target的方案数"""

@lru_cache(None)

def dfs(remaining):

if remaining == 0:

return 1

if remaining < 0:

return 0

total = 0

for num in nums:

total += dfs(remaining - num)

return total

return dfs(target)

def dfs_with_params(self, grid: ListList\[int]) -> int:

"""带参数的DFS模板"""

m, n = len(grid), len(grid0)

directions = (0, 1), (1, 0), (0, -1), (-1, 0)

@lru_cache(None)

def dfs(x, y, visited_mask):

"""visited_mask可以用位运算表示访问状态"""

if x < 0 or x >= m or y < 0 or y >= n:

return 0

pos_mask = 1 << (x * n + y)

if visited_mask & pos_mask:

return 0

标记访问

new_mask = visited_mask | pos_mask

递归探索

best = 1 # 当前格子

for dx, dy in directions:

best = max(best, 1 + dfs(x+dx, y+dy, new_mask))

return best

return dfs(0, 0, 0)

3.分治

class DivideConquer:

"""分治算法模板"""

def merge_sort(self, nums: Listint) -> Listint:

"""归并排序"""

if len(nums) <= 1:

return nums

mid = len(nums) // 2

left = self.merge_sort(nums:mid)

right = self.merge_sort(numsmid:)

return self._merge(left, right)

def _merge(self, left: Listint, right: Listint) -> Listint:

result = \[\]

i = j = 0

while i < len(left) and j < len(right):

if lefti <= rightj:

result.append(lefti)

i += 1

else:

result.append(rightj)

j += 1

result.extend(lefti:)

result.extend(rightj:)

return result

def max_subarray(self, nums: Listint) -> int:

"""最大子数组和(分治版本)"""

def divide_conquer(l, r):

if l == r:

return numsl

mid = (l + r) // 2

分治递归

left_max = divide_conquer(l, mid)

right_max = divide_conquer(mid + 1, r)

计算跨越中点的最大子数组和

从中点向左

left_cross = numsmid

curr = numsmid

for i in range(mid - 1, l - 1, -1):

curr += numsi

left_cross = max(left_cross, curr)

从中点向右

right_cross = numsmid + 1 if mid + 1 <= r else float('-inf')

curr = numsmid + 1 if mid + 1 <= r else 0

for i in range(mid + 2, r + 1):

curr += numsi

right_cross = max(right_cross, curr)

cross_max = left_cross + right_cross

return max(left_max, right_max, cross_max)

return divide_conquer(0, len(nums) - 1) if nums else 0

def closest_pair(self, points: ListList\[int]) -> float:

"""最近点对问题"""

import math

def distance(p1, p2):

return math.sqrt((p10 - p20)**2 + (p11 - p21)**2)

def brute_force(pts):

min_dist = float('inf')

for i in range(len(pts)):

for j in range(i + 1, len(pts)):

min_dist = min(min_dist, distance(ptsi, ptsj))

return min_dist

def strip_closest(strip, d):

min_dist = d

strip.sort(key=lambda p: p1)

for i in range(len(strip)):

j = i + 1

只需检查y坐标差小于d的点

while j < len(strip) and (stripj1 - stripi1) < min_dist:

min_dist = min(min_dist, distance(stripi, stripj))

j += 1

return min_dist

def divide_conquer(pts):

n = len(pts)

if n <= 3:

return brute_force(pts)

mid = n // 2

mid_point = ptsmid

dl = divide_conquer(pts:mid)

dr = divide_conquer(ptsmid:)

d = min(dl, dr)

检查跨越分割线的点对

strip = \[\]

for p in pts:

if abs(p0 - mid_point0) < d:

strip.append(p)

return min(d, strip_closest(strip, d))

points.sort(key=lambda p: p0)

return divide_conquer(points)

4.最长回文子串
复制代码
class Solution:
    def longestPalindrome(self, s: str) -> str:
        def expand(l, r):
            while l >= 0 and r < len(s) and s[l] == s[r]:
                l -= 1
                r += 1
            return s[l + 1:r]
        
        res = ""
        for i in range(len(s)):
            # 奇数中心:i 是中心字符
            odd = expand(i, i)
            # 偶数中心:i 和 i+1 是中心两侧
            even = expand(i, i + 1)
            # 取最长
            res = max(res, odd, even, key=len)
        
        return res
复制代码
class Solution:
    def longestPalindrome(self, s: str) -> str:
        n = len(s)
        if n < 2:
            return s
        
        # dp[i][j] 表示 s[i:j+1] 是否是回文
        dp = [[False] * n for _ in range(n)]
        start = 0
        max_len = 1
        
        # 初始化:所有单个字符都是回文
        for i in range(n):
            dp[i][i] = True
        
        # 先枚举右边界 j
        for j in range(1, n):
            # 再枚举左边界 i (从 0 到 j-1)
            for i in range(j):
                # 核心判断:首尾字符是否相等
                if s[i] == s[j]:
                    # 如果子串长度 <= 3,一定是回文
                    if j - i < 3:
                        dp[i][j] = True
                    else:
                        # 否则取决于去掉首尾后的子串
                        dp[i][j] = dp[i + 1][j - 1]
                else:
                    dp[i][j] = False
                
                # 如果是回文且更长,更新结果
                if dp[i][j] and (j - i + 1) > max_len:
                    max_len = j - i + 1
                    start = i
        
        return s[start:start + max_len]
5.KMP算法

最短回文串

def shortestPalindrome_optimized(s: str) -> str:

"""优化版本,减少内存使用"""

if not s:

return ""

rev_s = s::-1

寻找s + "#" + rev_s的前缀函数

pattern = s + "#" + rev_s

只计算前缀函数

n = len(pattern)

next_arr = 0 * n

for i in range(1, n):

j = next_arri - 1

while j > 0 and patterni != patternj:

j = next_arrj - 1

if patterni == patternj:

j += 1

next_arri = j

最长回文前缀长度

palindrome_len = next_arr-1

需要添加的前缀

add_front = spalindrome_len:::-1

return add_front + s

class KMP:

"""KMP算法完整模板"""

def build_next(self, pattern: str) -> Listint:

"""构建next数组(前缀函数)"""

n = len(pattern)

next_arr = 0 * n

j = 0

for i in range(1, n):

while j > 0 and patterni != patternj:

j = next_arrj - 1

if patterni == patternj:

j += 1

next_arri = j

return next_arr

def search(self, text: str, pattern: str) -> Listint:

"""在text中搜索pattern的所有出现位置"""

if not pattern:

return 0

next_arr = self.build_next(pattern)

m, n = len(text), len(pattern)

result = \[\]

j = 0

for i in range(m):

while j > 0 and texti != patternj:

j = next_arrj - 1

if texti == patternj:

j += 1

if j == n:

result.append(i - n + 1)

j = next_arrj - 1

return result

def repeated_substring_pattern(self, s: str) -> bool:

"""判断字符串是否由重复子串构成"""

n = len(s)

if n <= 1:

return False

next_arr = self.build_next(s)

如果存在重复子串,则n % (n - next_arr-1) == 0

return next_arr-1 != 0 and n % (n - next_arr-1) == 0

6.二叉树最近公共祖先

(包含节点不在树中的情况)

Definition for a binary tree node.

class TreeNode:

def init(self, x):

self.val = x

self.left = None

self.right = None

#哈希表法

class Solution:

def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':

if not root:

return None

parents = {root:None}

stack = root

while stack and (p not in parents or q not in parents):

node = stack.pop()

if node.left:

parentsnode.left = node

stack.append(node.left)

if node.right:

parentsnode.right = node

stack.append(node.right)

if p not in parents or q not in parents:

return None

ancestor = set()

while p:

ancestor.add(p)

p = parentsp

while q not in ancestor:

q = parentsq

return q

#递归法

复制代码
class Solution:
    def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
        self.find_p = False
        self.find_q = False

        def f(root):
            if not root:
                return None

            left = f(root.left)
            right = f(root.right)

            if root==p:
                self.find_p=True
                return root
            if root==q:
                self.find_q=True
                return root


            if left and right:
                return root
            
            return left if left else right

        res = f(root)
        return res if self.find_p and self.find_q else None
深度优先搜索

class Solution:

def pathSum(self, root: OptionalTreeNode, targetSum: int) -> ListList\[int]:

result = \[\]

path = \[\]

def dfs(node: OptionalTreeNode, target: int):

if not node: return

path.append(node.val)

if not node.left and not node.right:

if target == node.val:

result.append(path:)

if node.left:

dfs(node.left, target - node.val)

if node.right:

dfs(node.right, target - node.val)

path.pop()

dfs(root, targetSum)

return result

第三题

1.队列/栈

from collections import deque

class StackQueueTemplates:

"""栈和队列常用模板"""

def monotonic_stack(self, nums: Listint) -> Listint:

"""单调栈:下一个更大元素"""

n = len(nums)

result = -1 * n

stack = \[\] # 存储索引

for i in range(n):

while stack and numsi > numsstack\[-1]:

idx = stack.pop()

resultidx = numsi

stack.append(i)

return result

def sliding_window_queue(self, nums: Listint, k: int) -> Listint:

"""单调队列实现滑动窗口最大值"""

n = len(nums)

if n == 0:

return \[\]

dq = deque()

result = \[\]

for i in range(n):

移除不在窗口中的元素

if dq and dq0 == i - k:

dq.popleft()

维护单调递减队列

while dq and numsdq\[-1] < numsi:

dq.pop()

dq.append(i)

当窗口形成时记录结果

if i >= k - 1:

result.append(numsdq\[0])

return result

def daily_temperatures(self, temperatures: Listint) -> Listint:

"""每日温度(下一个更高温度的天数)"""

n = len(temperatures)

result = 0 * n

stack = \[\] # 存储(温度, 索引)

for i, temp in enumerate(temperatures):

while stack and temp > stack-10:

_, idx = stack.pop()

resultidx = i - idx

stack.append((temp, i))

return result

2.树构建

中序遍历与后续遍历还原树

class Solution:

def buildTree(self, inorder: Listint, postorder: Listint) -> OptionalTreeNode:

dic = {val:index for index,val in enumerate(inorder)}

def build(in_l,in_r,post_l,post_r):

if in_l>in_r or post_l>post_r:

return None

value = postorderpost_r

index = dicvalue

length = index-in_l

root = TreeNode(value)

root.left = build(in_l,index-1,post_l,post_l+length-1)

root.right = build(index+1,in_r,post_l+length,post_r-1)

return root

return build(0,len(inorder)-1,0,len(inorder)-1)

3.散列/哈希表

class HashTemplates:

"""哈希表相关模板"""

def two_sum(self, nums: Listint, target: int) -> Listint:

"""两数之和"""

hashmap = {}

for i, num in enumerate(nums):

complement = target - num

if complement in hashmap:

return hashmap\[complement, i]

hashmapnum = i

return \[\]

def subarray_sum(self, nums: Listint, k: int) -> int:

"""和为K的子数组个数"""

prefix_sum = 0

count = 0

hashmap = {0: 1} # 前缀和为0出现了1次

for num in nums:

prefix_sum += num

查找是否存在prefix_sum - k的前缀和

if prefix_sum - k in hashmap:

count += hashmapprefix_sum - k

更新当前前缀和的出现次数

hashmapprefix_sum = hashmap.get(prefix_sum, 0) + 1

return count

def longest_consecutive(self, nums: Listint) -> int:

"""最长连续序列"""

if not nums:

return 0

num_set = set(nums)

longest = 0

for num in num_set:

只从序列的起点开始计算

if num - 1 not in num_set:

current_num = num

current_streak = 1

while current_num + 1 in num_set:

current_num += 1

current_streak += 1

longest = max(longest, current_streak)

return longest

第四题

1.最短路径-树

import heapq

from typing import List, Tuple

class GraphShortestPath:

"""图的最短路径模板"""

def dijkstra(self, n: int, edges: ListList\[int], start: int) -> Listint:

"""

Dijkstra算法 - 邻接表实现

edges: \[u, v, w, ...] u->v权重w

"""

构建邻接表

graph = \[ for _ in range(n)]

for u, v, w in edges:

graphu.append((v, w))

graphv.append((u, w)) # 无向图

初始化距离数组

dist = float('inf') * n

diststart = 0

优先队列 (距离, 节点)

pq = (0, start)

while pq:

current_dist, u = heapq.heappop(pq)

如果当前距离大于记录的距离,跳过

if current_dist > distu:

continue

遍历邻居

for v, w in graphu:

new_dist = current_dist + w

if new_dist < distv:

distv = new_dist

heapq.heappush(pq, (new_dist, v))

return dist

def floyd_warshall(self, n: int, edges: ListList\[int]) -> ListList\[int]:

"""Floyd-Warshall算法"""

初始化距离矩阵

dist = \[float('inf') * n for _ in range(n)]

for i in range(n):

distii = 0

for u, v, w in edges:

distuv = w

distvu = w # 无向图

动态规划

for k in range(n):

for i in range(n):

for j in range(n):

if distik + distkj < distij:

distij = distik + distkj

return dist

def kruskal_mst(self, n: int, edges: ListList\[int]) -> ListList\[int]:

"""Kruskal算法求最小生成树"""

按权重排序

edges.sort(key=lambda x: x2)

parent = list(range(n))

def find(x):

if parentx != x:

parentx = find(parentx)

return parentx

def union(x, y):

root_x = find(x)

root_y = find(y)

if root_x != root_y:

parentroot_x = root_y

return True

return False

mst = \[\]

total_weight = 0

for u, v, w in edges:

if union(u, v):

mst.append(u, v, w)

total_weight += w

if len(mst) == n - 1:

break

return mst, total_weight

2.树上的公共节点

class TreeLCA:

"""树上最近公共祖先模板"""

def init(self, n: int, edges: ListList\[int], root: int = 0):

self.n = n

self.log = (n).bit_length()

self.parent = \[-1 * n for _ in range(self.log)]

self.depth = 0 * n

构建邻接表

self.adj = \[ for _ in range(n)]

for u, v in edges:

self.adju.append(v)

self.adjv.append(u)

BFS初始化深度和父节点

stack = root

visited = False * n

visitedroot = True

while stack:

u = stack.pop()

for v in self.adju:

if not visitedv:

visitedv = True

self.parent0v = u

self.depthv = self.depthu + 1

stack.append(v)

倍增预处理

for k in range(1, self.log):

for v in range(n):

if self.parentk-1v != -1:

self.parentkv = self.parentk-1self.parent\[k-1v]

def lca(self, u: int, v: int) -> int:

"""查询u和v的最近公共祖先"""

if self.depthu < self.depthv:

u, v = v, u

将u提到和v同一深度

diff = self.depthu - self.depthv

for k in range(self.log):

if diff & (1 << k):

u = self.parentku

if u == v:

return u

同时向上跳

for k in range(self.log - 1, -1, -1):

if self.parentku != self.parentkv:

u = self.parentku

v = self.parentkv

return self.parent0u

def distance(self, u: int, v: int) -> int:

"""计算树上两点间距离"""

lca_node = self.lca(u, v)

return self.depthu + self.depthv - 2 * self.depthlca_node

Definition for a binary tree node.

class TreeNode:

def init(self, x):

self.val = x

self.left = None

self.right = None

class Solution:

def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':

if not root:

return None

parents = {root:None}

stack = root

while stack and (p not in parents or q not in parents):

node = stack.pop()

if node.left:

parentsnode.left = node

stack.append(node.left)

if node.right:

parentsnode.right = node

stack.append(node.right)

if p not in parents or q not in parents:

return None

ancestor = set()

while p:

ancestor.add(p)

p = parentsp

while q not in ancestor:

q = parentsq

return q

class Solution:

def lowestCommonAncestor(self, root: TreeNode, p: TreeNode, q: TreeNode) -> TreeNode:

find_p = False

find_q = False

def f(root):

nonlocal find_p

nonlocal find_q

if not root:

return None

left = f(root.left)

right = f(root.right)

if root==p:

find_p = True

return root

if root==q:

find_q = True

return root

if left and right:

return root

return left if left else right

res = f(root)

return res if find_p and find_q else None

3.二分排序

class BinarySearch:

"""二分查找通用模板"""

def binary_search(self, nums: Listint, target: int) -> int:

"""标准二分查找"""

left, right = 0, len(nums) - 1

while left <= right:

mid = left + (right - left) // 2

if numsmid == target:

return mid

elif numsmid < target:

left = mid + 1

else:

right = mid - 1

return -1

def lower_bound(self, nums: Listint, target: int) -> int:

"""第一个大于等于target的位置"""

left, right = 0, len(nums)

while left < right:

mid = left + (right - left) // 2

if numsmid >= target:

right = mid

else:

left = mid + 1

return left

def upper_bound(self, nums: Listint, target: int) -> int:

"""第一个大于target的位置"""

left, right = 0, len(nums)

while left < right:

mid = left + (right - left) // 2

if numsmid > target:

right = mid

else:

left = mid + 1

return left

def search_range(self, nums: Listint, target: int) -> Listint:

"""在排序数组中查找元素的第一个和最后一个位置"""

start = self.lower_bound(nums, target)

if start == len(nums) or numsstart != target:

return -1, -1

end = self.upper_bound(nums, target) - 1

return start, end

def binary_search_answer(self, nums: Listint, k: int) -> int:

"""

二分答案模板

问题:将数组分成k个子数组,最小化最大子数组和

"""

def can_split(max_sum):

"""检查是否能在最大子数组和为max_sum的情况下分成k个子数组"""

current_sum = 0

count = 1

for num in nums:

if current_sum + num > max_sum:

count += 1

current_sum = num

if count > k:

return False

else:

current_sum += num

return True

left, right = max(nums), sum(nums)

while left < right:

mid = left + (right - left) // 2

if can_split(mid):

right = mid

else:

left = mid + 1

return left

4.容斥原理

class InclusionExclusion:

"""容斥原理模板"""

def count_divisible(self, n: int, divisors: Listint) -> int:

"""

计算1到n中能被至少一个除数整除的数的个数

"""

m = len(divisors)

total = 0

遍历所有非空子集

for mask in range(1, 1 << m):

lcm_val = 1

bits = 0

for i in range(m):

if mask & (1 << i):

bits += 1

lcm_val = self.lcm(lcm_val, divisorsi)

if lcm_val > n: # 超过范围,可以提前终止

break

计算交集大小

count = n // lcm_val

根据子集大小添加或减去

if bits % 2 == 1:

total += count

else:

total -= count

return total

def lcm(self, a: int, b: int) -> int:

"""最小公倍数"""

return a * b // self.gcd(a, b)

def gcd(self, a: int, b: int) -> int:

"""最大公约数"""

while b:

a, b = b, a % b

return a

def count_coprime(self, n: int, m: int) -> int:

"""计算1到m中与n互质的数的个数(欧拉函数)"""

质因数分解

temp = n

primes = \[\]

p = 2

while p * p <= temp:

if temp % p == 0:

primes.append(p)

while temp % p == 0:

temp //= p

p += 1

if temp > 1:

primes.append(temp)

容斥原理计算

total = 0

k = len(primes)

for mask in range(1, 1 << k):

product = 1

bits = 0

for i in range(k):

if mask & (1 << i):

bits += 1

product *= primesi

count = m // product

if bits % 2 == 1:

total += count

else:

total -= count

return m - total

5.状态机DP

LeetCode1397. 找到所有好字符串,1931. 用三种颜色给网格涂色,1220. 统计元音字母序列

LeetCode123.买卖股票的最佳时机 III,714.买卖股票的最佳时机含手续费,309.最佳买卖股票时机含冷冻期,72.编辑距离

第五题

动态规划

1.背包问题(完全/不完全)

class Knapsack:

"""背包问题完整模板"""

def zero_one_knapsack(self, weights: Listint, values: Listint, capacity: int) -> int:

"""01背包问题"""

n = len(weights)

dp = 0 * (capacity + 1)

for i in range(n):

逆向遍历,确保每个物品只选一次

for w in range(capacity, weightsi - 1, -1):

dpw = max(dpw, dpw - weights\[i] + valuesi)

return dpcapacity

def unbounded_knapsack(self, weights: Listint, values: Listint, capacity: int) -> int:

"""完全背包问题"""

n = len(weights)

dp = 0 * (capacity + 1)

for i in range(n):

正向遍历,允许重复选择

for w in range(weightsi, capacity + 1):

dpw = max(dpw, dpw - weights\[i] + valuesi)

return dpcapacity

def multi_knapsack(self, weights: Listint, values: Listint, counts: Listint, capacity: int) -> int:

"""多重背包问题(二进制优化)"""

n = len(weights)

二进制拆分

new_weights = \[\]

new_values = \[\]

for i in range(n):

k = 1

remaining = countsi

while remaining >= k:

new_weights.append(weightsi * k)

new_values.append(valuesi * k)

remaining -= k

k <<= 1

if remaining > 0:

new_weights.append(weightsi * remaining)

new_values.append(valuesi * remaining)

01背包

return self.zero_one_knapsack(new_weights, new_values, capacity)

def knapsack_scheme(self, weights: Listint, values: Listint, capacity: int) -> Listint:

"""输出具体方案"""

n = len(weights)

dp = \[0 * (capacity + 1) for _ in range(n + 1)]

for i in range(1, n + 1):

for w in range(capacity + 1):

if weightsi-1 <= w:

dpiw = max(dpi-1w, dpi-1w - weights\[i-1] + valuesi-1)

else:

dpiw = dpi-1w

回溯找方案

res = \[\]

w = capacity

for i in range(n, 0, -1):

if dpiw != dpi-1w:

res.append(i-1)

w -= weightsi-1

return res::-1 # 返回选择的物品索引

编辑距离
2.反悔贪心

LeetCode 630

import heapq

class RegretGreedy:

"""反悔贪心模板"""

def schedule_course(self, courses: ListList\[int]) -> int:

"""

课程安排III:选择最多的课程

coursesi = duration, lastDay

"""

按截止时间排序

courses.sort(key=lambda x: x1)

max_heap = \[\] # 最大堆,存储已选课程的持续时间

current_time = 0

for duration, last_day in courses:

if current_time + duration <= last_day:

heapq.heappush(max_heap, -duration)

current_time += duration

elif max_heap and -max_heap0 > duration:

反悔:替换掉持续时间最长的课程

longest = -heapq.heappop(max_heap)

current_time = current_time - longest + duration

heapq.heappush(max_heap, -duration)

return len(max_heap)

def max_profit_jobs(self, startTime: Listint, endTime: Listint, profit: Listint) -> int:

"""最大收益工作安排"""

jobs = sorted(zip(startTime, endTime, profit), key=lambda x: x1)

n = len(jobs)

dp = 0 * (n + 1)

for i in range(1, n + 1):

s, e, p = jobsi-1

找到结束时间不超过s的最后一个工作

j = i - 1

while j >= 1 and jobsj-11 > s:

j -= 1

选择当前工作或不选

dpi = max(dpi-1, dpj + p)

return dpn

3.最大堆

class HeapTemplates:

"""堆操作模板"""

def median_finder(self):

"""数据流的中位数(双堆法)"""

import heapq

class MedianFinder:

def init(self):

self.small = \[\] # 最大堆(用负数实现)

self.large = \[\] # 最小堆

def addNum(self, num: int) -> None:

if len(self.small) == len(self.large):

heapq.heappush(self.large, -heapq.heappushpop(self.small, -num))

else:

heapq.heappush(self.small, -heapq.heappushpop(self.large, num))

def findMedian(self) -> float:

if len(self.small) == len(self.large):

return (-self.small0 + self.large0) / 2

else:

return self.large0

return MedianFinder()

def kth_largest(self, nums: Listint, k: int) -> int:

"""第K大的元素"""

import heapq

min_heap = \[\]

for num in nums:

heapq.heappush(min_heap, num)

if len(min_heap) > k:

heapq.heappop(min_heap)

return min_heap0 if min_heap else -1

def merge_k_sorted(self, lists: ListList\[int]) -> Listint:

"""合并K个有序链表/数组"""

import heapq

heap = \[\]

初始化堆,存储(值, 列表索引, 元素索引)

for i, lst in enumerate(lists):

if lst:

heapq.heappush(heap, (lst0, i, 0))

result = \[\]

while heap:

val, list_idx, elem_idx = heapq.heappop(heap)

result.append(val)

如果当前列表还有下一个元素

if elem_idx + 1 < len(listslist_idx):

next_val = listslist_idxelem_idx + 1

heapq.heappush(heap, (next_val, list_idx, elem_idx + 1))

return result

4.如何找树的割点

class ArticulationPoints:

"""寻找无向图的割点(Tarjan算法)"""

def find_cut_points(self, n: int, edges: ListList\[int]) -> Listint:

"""

寻找图中的所有割点

割点:移除该点后,图的连通分量数增加

"""

构建邻接表

graph = \[ for _ in range(n)]

for u, v in edges:

graphu.append(v)

graphv.append(u)

visited = False * n

disc = 0 * n # 发现时间

low = 0 * n # 可回溯到的最早发现时间

parent = -1 * n

time = 0

articulation = False * n

def dfs(u):

nonlocal time

children = 0

visitedu = True

discu = lowu = time

time += 1

for v in graphu:

if not visitedv:

children += 1

parentv = u

dfs(v)

更新low值

lowu = min(lowu, lowv)

判断是否为割点

1. 根节点且有两个以上子节点

if parentu == -1 and children > 1:

articulationu = True

2. 非根节点且lowv >= discu

if parentu != -1 and lowv >= discu:

articulationu = True

elif v != parentu: # 回退边

lowu = min(lowu, discv)

for i in range(n):

if not visitedi:

dfs(i)

return i for i in range(n) if articulation\[i]

5.树删除节点后如何划分

树上子树区间(Euler 序)与前缀和

class TreeSubtree:

"""子树统计与划分"""

def init(self, n: int, edges: ListList\[int]):

self.n = n

self.adj = \[ for _ in range(n)]

for u, v in edges:

self.adju.append(v)

self.adjv.append(u)

欧拉序

self.euler_in = 0 * n

self.euler_out = 0 * n

self.euler_path = \[\]

self.time = 0

self.dfs_euler(0, -1)

子树大小

self.subtree_size = 0 * n

self.dfs_size(0, -1)

def dfs_euler(self, u: int, parent: int):

"""DFS求欧拉序"""

self.euler_inu = self.time

self.euler_path.append(u)

self.time += 1

for v in self.adju:

if v != parent:

self.dfs_euler(v, u)

self.euler_outu = self.time - 1

def dfs_size(self, u: int, parent: int) -> int:

"""计算子树大小"""

size = 1

for v in self.adju:

if v != parent:

size += self.dfs_size(v, u)

self.subtree_sizeu = size

return size

def is_ancestor(self, u: int, v: int) -> bool:

"""判断u是否是v的祖先(欧拉序)"""

return self.euler_inu <= self.euler_inv <= self.euler_outu

def subtree_range(self, u: int) -> Tupleint, int:

"""返回节点u的子树在欧拉序中的范围"""

return (self.euler_inu, self.euler_outu)

def remove_node_partition(self, u: int) -> Listint:

"""

删除节点u后,划分成的各个连通块的大小

注意:删除节点后,除了子树外,还有父节点所在的连通块

"""

sizes = \[\]

父节点所在的连通块大小

parent_component = self.n - self.subtree_sizeu

if parent_component > 0:

sizes.append(parent_component)

每个子节点的子树大小

for v in self.adju:

if self.is_ancestor(v, u): # v是u的子节点

sizes.append(self.subtree_sizev)

return sizes

def find_centroid(self) -> Listint:

"""寻找树的重心(删除后最大子树最小的节点)"""

centroids = \[\]

min_max_subtree = self.n

def dfs_centroid(u: int, parent: int):

nonlocal min_max_subtree, centroids

max_subtree = 0

total = 1

for v in self.adju:

if v != parent:

dfs_centroid(v, u)

subtree_size = self.subtree_sizev

max_subtree = max(max_subtree, subtree_size)

total += subtree_size

父节点所在的连通块

parent_subtree = self.n - total

max_subtree = max(max_subtree, parent_subtree)

if max_subtree < min_max_subtree:

min_max_subtree = max_subtree

centroids = u

elif max_subtree == min_max_subtree:

centroids.append(u)

return total

dfs_centroid(0, -1)

return centroids

6.动态树/查询子树

class DynamicTree:

"""支持子树更新的动态树"""

def init(self, n: int, values: Listint, edges: ListList\[int]):

self.n = n

self.values = values

构建邻接表

self.adj = \[ for _ in range(n)]

for u, v in edges:

self.adju.append(v)

self.adjv.append(u)

欧拉序

self.in_time = 0 * n

self.out_time = 0 * n

self.euler = \[\]

self.time = 0

self.dfs_euler(0, -1)

树状数组(Fenwick Tree)

self.bit = 0 * (n + 2)

初始化树状数组

for i, node in enumerate(self.euler):

self._add(i + 1, self.valuesnode)

def dfs_euler(self, u: int, parent: int):

"""DFS求欧拉序"""

self.in_timeu = self.time

self.euler.append(u)

self.time += 1

for v in self.adju:

if v != parent:

self.dfs_euler(v, u)

self.out_timeu = self.time - 1

def _add(self, idx: int, delta: int):

"""树状数组更新"""

while idx <= self.n:

self.bitidx += delta

idx += idx & -idx

def _sum(self, idx: int) -> int:

"""树状数组前缀和"""

res = 0

while idx > 0:

res += self.bitidx

idx -= idx & -idx

return res

def update_subtree(self, u: int, delta: int):

"""更新节点u的整个子树"""

l = self.in_timeu + 1 # 树状数组从1开始

r = self.out_timeu + 1

self._add(l, delta)

if r + 1 <= self.n:

self._add(r + 1, -delta)

def query_subtree(self, u: int) -> int:

"""查询节点u的子树和"""

idx = self.in_timeu + 1

return self._sum(idx)

def update_node(self, u: int, new_val: int):

"""更新单个节点的值"""

old_val = self.valuesu

delta = new_val - old_val

self.valuesu = new_val

self.update_subtree(u, delta)

def query_path(self, u: int, v: int) -> int:

"""查询u到v路径上的和(需要LCA)"""

实现路径查询需要LCA,这里给出简化版本

实际应用中可能需要树链剖分

pass

一些细节问题

1.count = defaultdict(int)指定默认值类型

class Utils:

"""常用工具函数"""

@staticmethod

def quick_sort(nums: Listint) -> Listint:

"""快速排序"""

if len(nums) <= 1:

return nums

pivot = numslen(nums) // 2

left = x for x in nums if x \< pivot

middle = x for x in nums if x == pivot

right = x for x in nums if x \> pivot

return Utils.quick_sort(left) + middle + Utils.quick_sort(right)

@staticmethod

def union_find(n: int):

"""并查集模板"""

parent = list(range(n))

size = 1 * n

def find(x: int) -> int:

if parentx != x:

parentx = find(parentx)

return parentx

def union(x: int, y: int) -> bool:

root_x = find(x)

root_y = find(y)

if root_x == root_y:

return False

按秩合并

if sizeroot_x < sizeroot_y:

root_x, root_y = root_y, root_x

parentroot_y = root_x

sizeroot_x += sizeroot_y

return True

def connected(x: int, y: int) -> bool:

return find(x) == find(y)

return find, union, connected

@staticmethod

def topological_sort(n: int, edges: ListList\[int]) -> Listint:

"""拓扑排序"""

from collections import deque

graph = \[ for _ in range(n)]

indegree = 0 * n

for u, v in edges:

graphu.append(v)

indegreev += 1

queue = deque(i for i in range(n) if indegree\[i == 0])

result = \[\]

while queue:

u = queue.popleft()

result.append(u)

for v in graphu:

indegreev -= 1

if indegreev == 0:

queue.append(v)

return result if len(result) == n else \[\] # 有环则返回空

相关推荐
8Qi83 小时前
回文子串(Palindromic Substrings)—— 题解
算法·leetcode·职场和发展·动态规划
珺毅同学6 小时前
YOLO生成预测json标签迁移问题
python·yolo·json
骑士雄师6 小时前
18.4 长期记忆可修改版
python
~小先生~6 小时前
Python从入门到放弃(一)
开发语言·python
天佑木枫7 小时前
第2天:变量与数据类型 —— 让程序记住信息
python
小宋加油啊7 小时前
机械臂抓取物体 PVN3D算法调研学习
学习·算法·3d
lqqjuly7 小时前
前沿算法深度解析(一)
算法
Dust-Chasing8 小时前
Claude Code源码剖析 - Claude Code 上下文压缩机制
人工智能·python·ai
小欣加油8 小时前
leetcode1926 迷宫中离入口最近的出口
数据结构·c++·算法·leetcode·职场和发展