Implementing a Map with an AVL Tree and a Sorted List: A Word-Frequency Counter
Project Overview
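This is a data-structures assignment that counts how often each word appears in a text file. It uses three interchangeable Map (dictionary) structures: Python's built-in dict plus two hand-written classes, AVLMap and ListMap. The project shows inheritance and polymorphism at work in real code.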
Core Components
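1. AVL tree base class (avl.py)
A complete implementation of the AVL self-balancing binary search tree, including:
- Left and right rotations
- Node height maintenance
- Balance factor computation
- Automatic rebalancing after insertion and deletion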
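To confirm that rebalancing really preserves the AVL invariants, you can write a small validator that recomputes every height bottom-up. The following is a sketch and is not part of avl.py. The `Node` dataclass is a hypothetical stand-in with the same key/left/right/height fields, and `check_avl` is a hypothetical helper name. The sketch uses the same height convention as avl.py: an empty tree has height -1 and a leaf has height 0.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    # Hypothetical stand-in mirroring the fields of avl.py's Node
    key: int
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    height: int = 0


def check_avl(node: Optional[Node], lo=None, hi=None) -> int:
    """Return the height of the subtree rooted at node (-1 if empty).

    Raises ValueError on the first violated invariant: BST ordering,
    a stale stored height, or a balance factor outside [-1, 1].
    """
    if node is None:
        return -1
    # BST ordering: each key must lie strictly between the bounds set by its ancestors
    if (lo is not None and node.key <= lo) or (hi is not None and node.key >= hi):
        raise ValueError(f"BST order violated at key {node.key!r}")
    left_h = check_avl(node.left, lo, node.key)
    right_h = check_avl(node.right, node.key, hi)
    # The stored height must equal the height recomputed from the children
    if node.height != 1 + max(left_h, right_h):
        raise ValueError(f"stale height at key {node.key!r}")
    # AVL condition: subtree heights differ by at most one
    if abs(left_h - right_h) > 1:
        raise ValueError(f"AVL balance violated at key {node.key!r}")
    return node.height
```

Calling `check_avl(tree.root)` after each insert or remove in a test makes a rotation bug fail immediately. Without the check, the bug might only show up later as a slow or lopsided tree.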
2. AVLMap (avl_map.py)
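Inherits from AVLTree and stores a [key, value] list in each node's key field:
- `_insert_or_update()`: inserts a key-value pair, or updates the value if the key already exists
- `__getitem__` / `__setitem__`: support the `map[key]` syntax
- `inorder_generator()`: a generator that walks the tree in order, yielding pairs sorted by key
- Because the tree stays balanced, lookup, insertion, and deletion remain O(log n)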
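The recursive `_inorder_generator()` is concise, but it has a depth problem with `do_balance=False` (the `bst` mode). A tree built from already-sorted keys degenerates into a chain as deep as the number of keys, and deeply nested recursion may run into Python's recursion limit. The recursive `_insert_or_update()` has the same limitation.

An explicit stack avoids this. The sketch below is an alternative and is not the project's code. `Node` is a hypothetical stand-in whose key field holds a [key, value] list, as in AVLMap, and `inorder_iterative` is a hypothetical name.

```python
from dataclasses import dataclass
from typing import Iterator, Optional


@dataclass
class Node:
    # Hypothetical stand-in: as in AVLMap, key holds a [key, value] list
    key: list
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def inorder_iterative(root: Optional[Node]) -> Iterator[list]:
    """Yield each node's [key, value] pair in ascending key order.

    An explicit stack replaces recursion, so tree depth is limited only by memory.
    """
    stack: list[Node] = []
    node = root
    while stack or node is not None:
        # Walk as far left as possible, remembering the path
        while node is not None:
            stack.append(node)
            node = node.left
        # The most recently saved node is the next one in key order
        node = stack.pop()
        yield node.key
        # Then continue with its right subtree
        node = node.right
```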
3. ListMap (list_map.py)
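A Map backed by a list of [key, value] pairs kept sorted by key:
- `_binary_search()`: the core binary search routine
- Lookup is O(log n), but insertion and deletion are O(n) because the elements after the insertion point must shift
- Suited to small data sets or read-heavy, write-light workloads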
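Python's standard `bisect` module already provides the binary search that ListMap hand-writes. The following minimal sketch shows the same sorted-list design built on `bisect.bisect_left`. `BisectListMap` is a hypothetical name and not one of the project's classes.

The sketch decides "found" by comparing the key at the insertion point instead of testing whether the value is None. As a result, None is a valid stored value.

```python
import bisect
from typing import Any, Iterator


class BisectListMap:
    """Hypothetical ListMap variant: binary search via the bisect module,
    with keys and values kept in two parallel lists sorted by key."""

    def __init__(self) -> None:
        self._keys: list = []
        self._values: list = []

    def _locate(self, key) -> tuple[int, bool]:
        # bisect_left gives the leftmost index where key is, or would be inserted
        i = bisect.bisect_left(self._keys, key)
        # "Found" is decided by comparing keys, never by looking at the value
        return i, i < len(self._keys) and self._keys[i] == key

    def __contains__(self, key) -> bool:
        return self._locate(key)[1]

    def __getitem__(self, key) -> Any:
        i, found = self._locate(key)
        if not found:
            raise KeyError(key)
        return self._values[i]

    def __setitem__(self, key, value) -> None:
        i, found = self._locate(key)
        if found:
            self._values[i] = value        # update: only the O(log n) search
        else:
            self._keys.insert(i, key)      # insert: O(n), later items shift right
            self._values.insert(i, value)

    def items(self) -> Iterator[tuple[Any, Any]]:
        # Pairs come out in ascending key order
        return zip(self._keys, self._values)
```

Parallel lists keep each key comparison a plain key-to-key comparison. The price is that every insert has to shift two lists instead of one.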
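4. Word-count program (word_count.py)
Handles every Map type through the same interface:
- Splits each line into words with the regular expression `re.split(r'[\W_]+', line)`
- Lowercases each word and discards tokens that are not purely alphabetic
- Applies exactly the same operations to dict, AVLMap, and ListMap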
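The splitting and filtering rules above can be checked on a single line of text. `tokenize` is a hypothetical helper that applies the same rules as word_count.py: split, lowercase, and keep only purely alphabetic tokens.

```python
import re


def tokenize(line: str) -> list[str]:
    """Apply word_count.py's rules to one line: split on runs of
    non-word characters or underscores, lowercase, keep alphabetic tokens."""
    words = []
    for word in re.split(r'[\W_]+', line):
        word = word.lower()
        # Drops the empty strings produced at the ends and tokens with digits
        if word and word.isalpha():
            words.append(word)
    return words


line = "Don't stop_believing: 42 times!\n"
print(re.split(r'[\W_]+', line))  # ['Don', 't', 'stop', 'believing', '42', 'times', '']
print(tokenize(line))             # ['don', 't', 'stop', 'believing', 'times']
```

Two side effects of these rules are worth knowing:
- An apostrophe splits a contraction, so "don't" is counted as "don" plus "t".
- Because `\W` is Unicode-aware for str patterns in Python 3, an accented word such as "café" stays in one piece.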
Technical Highlights
The power of polymorphism: `count_words()` does not care which kind of Map it receives. Any object that supports `in`, `[]`, and `[]=` works.
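This duck typing can be shown with a deliberately naive map that implements only those three operations. `CountingMap` and `count_into` are hypothetical names, not project code. `count_into` follows the same update pattern as `count_words()` but takes a list of words instead of a filename.

```python
from typing import Iterable


class CountingMap:
    """Hypothetical minimal map: supports only `in`, `[]` and `[]=`,
    storing unsorted [key, value] pairs (linear-time lookups)."""

    def __init__(self) -> None:
        self._pairs: list[list] = []

    def _find(self, key):
        for pair in self._pairs:
            if pair[0] == key:
                return pair
        return None

    def __contains__(self, key) -> bool:
        return self._find(key) is not None

    def __getitem__(self, key):
        pair = self._find(key)
        if pair is None:
            raise KeyError(key)
        return pair[1]

    def __setitem__(self, key, value) -> None:
        pair = self._find(key)
        if pair is None:
            self._pairs.append([key, value])
        else:
            pair[1] = value


def count_into(words: Iterable[str], counts) -> None:
    # Only `in`, `[]` and `[]=` are used, so any conforming map works
    for word in words:
        if word in counts:
            counts[word] += 1   # calls __getitem__, then __setitem__
        else:
            counts[word] = 1


words = "the cat and the hat".split()
d, m = {}, CountingMap()
count_into(words, d)
count_into(words, m)
print(d["the"], m["the"])  # 2 2
```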
Performance Comparison
| Structure | Lookup | Insert | Observed speed |
|---|---|---|---|
| dict (hash table) | O(1) average | O(1) average | Fastest |
| AVLMap | O(log n) | O(log n) | Medium |
| ListMap | O(log n) | O(n) | Slower |
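To see the insertion gap from the table on your own machine, you can run a rough timing sketch. It compares a dict against a list kept sorted with `bisect`, as a stand-in for dict versus ListMap. It does not benchmark the project's classes.

The workload of random 8-letter words is hypothetical, and absolute timings depend on the machine, so no numbers are quoted here.

```python
import bisect
import random
import string
import timeit


def random_words(n: int, seed: int = 0) -> list[str]:
    # Hypothetical workload: n random 8-letter "words", almost all distinct
    rng = random.Random(seed)
    return ["".join(rng.choices(string.ascii_lowercase, k=8)) for _ in range(n)]


def fill_dict(words: list[str]) -> dict[str, int]:
    counts: dict[str, int] = {}
    for w in words:
        counts[w] = counts.get(w, 0) + 1      # O(1) on average
    return counts


def fill_sorted_list(words: list[str]) -> tuple[list[str], list[int]]:
    keys: list[str] = []
    values: list[int] = []
    for w in words:
        i = bisect.bisect_left(keys, w)       # O(log n) search
        if i < len(keys) and keys[i] == w:
            values[i] += 1
        else:
            keys.insert(i, w)                 # O(n): later elements shift
            values.insert(i, 1)
    return keys, values


def compare(n: int) -> dict[str, float]:
    words = random_words(n)
    return {
        "dict": timeit.timeit(lambda: fill_dict(words), number=1),
        "sorted list": timeit.timeit(lambda: fill_sorted_list(words), number=1),
    }


if __name__ == "__main__":
    for n in (5_000, 50_000):
        print(n, compare(n))
```

As n grows tenfold, the dict time should grow roughly tenfold. The sorted-list time tends to grow faster once n is large enough that shifting elements dominates.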
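Debugging Notes
- The `pop()` pitfall: in AVLMap every node's key is the full [key, value] list. The inherited `AVLTree.remove()` must therefore receive that list rather than the bare string key, because comparing a string with a list raises `TypeError`.
- Broken pipe: `head -n5` closes the pipe before the program finishes writing, so the program needs to handle `BrokenPipeError`.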
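The pop() pitfall can be reproduced in a few lines of plain Python, without any project code. The example shows why passing the stored [key, value] list works and why passing the bare key fails inside `AVLTree._remove()`, which compares its argument against `node.key` with `<` and `>`.

```python
# In AVLMap, node.key holds the whole [key, value] list.
stored = ["apple", 3]

# Lists compare element by element, so [key, value] pairs are ordered by key first
print(["apple", 3] < ["banana", 1])  # True
print(stored == ["apple", 3])        # True

# Passing the bare key instead makes _remove() compare a str with a list
try:
    _ = "apple" < stored
    pitfall = None
except TypeError as exc:
    pitfall = type(exc).__name__
print(pitfall)                       # TypeError
```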
Running the Tests
```bash
# Unit tests
python3 avl_test.py
python3 avl_map_test.py
python3 list_map_test.py
python3 word_count_test.py

# Count words (ulysses.txt is a very large input file that you need to create yourself)
./word_count.py dict fruit.txt
./word_count.py dict quotes.txt | head -n10
time ./word_count.py dict ulysses.txt | sort -rn | head -n5

./word_count.py avl fruit.txt
./word_count.py avl quotes.txt | head -n10
time ./word_count.py avl ulysses.txt | sort -rn | head -n5

./word_count.py bst fruit.txt
./word_count.py bst quotes.txt | head -n10
time ./word_count.py bst ulysses.txt | sort -rn | head -n5

./word_count.py list fruit.txt
./word_count.py list quotes.txt | head -n10
time ./word_count.py list ulysses.txt | sort -rn | head -n5
```
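What the Project Demonstrates
This project shows:
- How inheritance lets you reuse existing code
- How to design a common interface so that different data structures are interchangeable
- How an AVL tree stays balanced to keep its operations efficient
- The trade-offs of a sorted list plus binary search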
avl.py
```python
from collections import deque

from tree_print import pretty_tree


class Node:
    def __init__(self, key, left=None, right=None):
        self.key = key
        self.left = left
        self.right = right
        # A newly created node is a leaf, whose height is 0
        self.height = 0


class AVLTree:
    # Constructor
    # Optionally, you can initialize the tree with a root node
    # and specify whether the tree should be balanced
    def __init__(self, root=None, do_balance=True):
        self.root = root
        self.do_balance = do_balance

    # Clear the tree
    def clear(self):
        self.root = None

    # Helper function to get the height of a node for AVL balancing
    # Note: The height of a null node is -1
    def _get_height(self, root):
        if not root:
            return -1
        return root.height

    # Helper function to update the height of a node from its children's heights
    def _update_height(self, node):
        if node:
            node.height = 1 + max(self._get_height(node.left), self._get_height(node.right))

    # Helper function to get the balance factor of a node
    # (height of the left subtree minus height of the right subtree)
    # Note: The balance factor of a null node is 0
    def _balance_factor(self, root):
        if not root:
            return 0
        return self._get_height(root.left) - self._get_height(root.right)

    # Helper function to rotate left for AVL balancing
    def _rotate_left(self, z):
        r"""
        Example:
        Given the following tree:
              z
             / \
            T1  y
               / \
              T2  T3
        After _rotate_left(z), the tree becomes:
                y
               / \
              z   T3
             / \
            T1  T2
        """
        # Do the rotation
        y = z.right
        T2 = y.left
        y.left = z
        z.right = T2
        # Update the heights: z first, because it is now a child of y
        self._update_height(z)
        self._update_height(y)
        # Return the new root of this subtree
        return y

    # Helper function to rotate right for AVL balancing
    def _rotate_right(self, z):
        r"""
        Example:
        Given the following tree:
                z
               / \
              y   T3
             / \
            T1  T2
        After _rotate_right(z), the tree becomes:
              y
             / \
            T1  z
               / \
              T2  T3
        """
        # Do the rotation
        y = z.left
        T2 = y.right
        y.right = z
        z.left = T2
        # Update the heights: z first, because it is now a child of y
        self._update_height(z)
        self._update_height(y)
        # Return the new root of this subtree
        return y

    # Helper function to rebalance a subtree after insertion or removal
    def _balance(self, node):
        # Update the height of this node first
        self._update_height(node)
        # Left heavy
        if self._balance_factor(node) > 1:
            # Left-right case: the left child is right heavy
            if self._balance_factor(node.left) < 0:
                node.left = self._rotate_left(node.left)
            return self._rotate_right(node)
        # Right heavy
        if self._balance_factor(node) < -1:
            # Right-left case: the right child is left heavy
            if self._balance_factor(node.right) > 0:
                node.right = self._rotate_right(node.right)
            return self._rotate_left(node)
        return node

    # Insert a Node with a given key into the tree
    def insert(self, key):
        self.root = self._insert(self.root, key)

    # Helper function for insert; returns the new root of the subtree
    def _insert(self, root, key):
        # Regular BST insertion
        if root is None:
            return Node(key)
        # Duplicate keys are ignored
        if key == root.key:
            return root
        if key < root.key:
            root.left = self._insert(root.left, key)
        else:
            root.right = self._insert(root.right, key)
        if self.do_balance:
            # Rebalance the tree
            return self._balance(root)
        else:
            return root

    # Remove a Node with a given key from the tree
    def remove(self, key):
        self.root = self._remove(self.root, key)

    # Helper function for remove; returns the new root of the subtree
    def _remove(self, root, key):
        # Regular BST removal
        if root is None:
            return root
        # Key is not yet found
        if key < root.key:
            root.left = self._remove(root.left, key)
        elif key > root.key:
            root.right = self._remove(root.right, key)
        # Key is found
        else:
            # Node with only one child or leaf node: return the non-null child
            # If the node has no children, return None
            if root.left is None:
                return root.right
            if root.right is None:
                return root.left
            # Node with two children: copy in the key of the inorder successor
            # (the smallest key in the right subtree)
            root.key = self._min_value_node(root.right)
            # Delete the inorder successor
            root.right = self._remove(root.right, root.key)
        if self.do_balance:
            # Rebalance the tree
            return self._balance(root)
        else:
            return root

    # Helper function to find the minimum key in a subtree
    # (despite its name, it returns the key, not the node)
    def _min_value_node(self, root):
        current = root
        while current.left is not None:
            current = current.left
        return current.key

    # Write the BFS traversal of the tree to a list
    def write_bfs(self):
        # If the tree is empty, return an empty list
        if self.root is None:
            return []
        # Push the root node to the queue
        queue = deque([self.root])
        # List to store the BFS traversal results
        bfs = []
        # While there are nodes to process
        while queue:
            # Pop the first node from the queue (O(1) with a deque)
            node = queue.popleft()
            # If the node is None (missing child), append None to the BFS list
            if node is None:
                bfs.append(None)
            # Otherwise, append its key to the results and push its children to the queue
            else:
                bfs.append(node.key)
                queue.append(node.left)
                queue.append(node.right)
        # Remove trailing None values
        while bfs and bfs[-1] is None:
            bfs.pop()
        # Return the BFS traversal list
        return bfs

    # Magic method: string representation of the tree
    # Support for the print() function
    def __str__(self):
        return pretty_tree(self)
```
avl_map.py
```python
from typing import Any, Iterator

from avl import AVLTree, Node


class AVLMap(AVLTree):
    def __init__(self, root: Node | None = None, do_balance: bool = True) -> None:
        """ AVLMap is a subclass of AVLTree
        The constructor calls the constructor of the superclass
        The constructor takes an optional argument do_balance
        to enable or disable balancing
        Input: root (Node) - the root of the AVL tree
               do_balance (bool) - whether to balance the tree after insertion or deletion
        Output: None
        """
        super().__init__(root, do_balance)

    def _insert_or_update(self, root: Node | None, key, value) -> Node:
        """ Helper method to insert or update a [key, value] pair
        It is similar to the insert method in AVLTree, except
        that it updates the value if the key is already in the tree
        Input: root (Node) - the root of the AVL tree
               key (comparable) - the key to insert or update
               value (any) - the value to insert or update
        Output: Node - the root of the AVL tree
        """
        # Empty subtree: create a node whose key field holds the [key, value] pair
        if root is None:
            return Node([key, value])
        # Key already present: update the value in place (the tree shape does not change)
        if key == root.key[0]:
            root.key[1] = value
            return root
        elif key < root.key[0]:
            root.left = self._insert_or_update(root.left, key, value)
        else:
            root.right = self._insert_or_update(root.right, key, value)
        if self.do_balance:
            return self._balance(root)
        else:
            return root

    def __setitem__(self, key, value) -> None:
        """ Magic method to support assignment using the [] operator
        Input: key (comparable) - the key to insert or update
               value (any) - the value to insert or update
        Output: None
        """
        self.root = self._insert_or_update(self.root, key, value)

    def _get_node(self, root: Node | None, key) -> Node | None:
        """ Helper method to get the node with the key
        Used by __getitem__, __contains__, get and pop
        Input: root (Node) - the root of the AVL tree
               key (comparable) - the key to search for
        Output: Node - the node with the key, or None if not found
        """
        if root is None:
            return None
        if key == root.key[0]:
            return root
        elif key < root.key[0]:
            return self._get_node(root.left, key)
        else:
            return self._get_node(root.right, key)

    def __getitem__(self, key) -> Any:
        """ Magic method to support retrieval using the [] operator
        Raise a KeyError if the key is not found, consistent with the behavior
        of the __getitem__ method of the dict class
        Walks down the tree from the root, going left or right by key comparison
        Input: key (comparable) - the key to search for
        Output: any - the value of the key
        """
        node = self._get_node(self.root, key)
        if node is None:
            raise KeyError(key)
        return node.key[1]

    def get(self, key, default=None) -> Any:
        """ Get the value of the key, or return the default value if the key is not found,
        consistent with the behavior of the get method of the dict class
        Input: key (comparable) - the key to search for
               default (any) - the default value to return if the key is not found
        Output: any - the value of the key or the default value
        """
        node = self._get_node(self.root, key)
        if node is None:
            return default
        return node.key[1]

    def __contains__(self, key) -> bool:
        """ Magic method to support the in operator
        Input: key (comparable) - the key to search for
        Output: bool - True if the key is found, False otherwise
        """
        return self._get_node(self.root, key) is not None

    def pop(self, key) -> Any:
        """ Remove a [key, value] pair from the map and return the value,
        consistent with the pop method in the dict class
        Raise a KeyError if the key is not found
        Input: key (comparable) - the key to remove
        Output: any - the value of the key
        """
        node = self._get_node(self.root, key)
        if node is None:
            raise KeyError(key)
        value = node.key[1]
        # AVLTree.remove() compares its argument with node.key, which is the whole
        # [key, value] list, so pass that list rather than the bare key
        self.remove(node.key)
        return value

    def inorder_generator(self) -> Iterator[list[Any]]:
        """ A generator that yields the [key, value] pairs in the map
        using an inorder traversal, i.e. in ascending key order
        Input: None
        Output: generator - the [key, value] pairs in the map
        """
        yield from self._inorder_generator(self.root)

    def _inorder_generator(self, root: Node | None) -> Iterator[list[Any]]:
        """ Helper recursive generator method that yields the [key, value] pairs
        in the map in an inorder traversal
        Input: root (Node) - the root of the AVL tree
        Output: generator - the [key, value] pairs in the map
        """
        if root is None:
            return
        yield from self._inorder_generator(root.left)
        yield root.key
        yield from self._inorder_generator(root.right)

    def items(self) -> Iterator[list[Any]]:
        """ A generator that yields the [key, value] pairs in the map
        Usage is consistent with the behavior of the items method
        of the dict class.
        Hint: You can just yield from the inorder_generator method
        Input: None
        Output: generator - the [key, value] pairs in the map
        """
        yield from self.inorder_generator()

    def __iter__(self) -> Iterator[Any]:
        """ Magic method to support iteration using a generator
        This yields the keys only, consistent with
        the behavior of the iterator for the dict class.
        Hint: Use the inorder_generator method
        Input: None
        Output: generator - the keys in the map
        """
        for pair in self.inorder_generator():
            yield pair[0]
```
list_map.py
```python
from typing import Any, Iterator


class ListMap:
    def __init__(self) -> None:
        """ Constructor
        The data is a list of [key, value] pairs kept sorted by key
        """
        self.data = []

    def _binary_search(self, key) -> tuple[int, Any]:
        """ Helper method to search for the index of the key and its value
        This function returns two values as a tuple:
        - the index of the key and the value if found: (index, value)
        - the index where the key should be inserted and None if not found: (index, None)
        Note that in a typical binary search, the index where the key should be inserted
        is the minimum index of the search range when the key is not found
        Input: key (comparable) - the key to search for
        Output: tuple - the index of the key and the value if found,
                or the index where the key should be inserted and None if not found
        """
        left, right = 0, len(self.data) - 1
        while left <= right:
            mid = (left + right) // 2
            mid_key = self.data[mid][0]
            if key == mid_key:
                return (mid, self.data[mid][1])
            elif key < mid_key:
                right = mid - 1
            else:
                left = mid + 1
        # Not found: left is the insertion point that keeps the list sorted
        return (left, None)

    def _has_key_at(self, index: int, key) -> bool:
        """ Helper method to tell "found" from "not found" after _binary_search
        Checking the key itself (instead of testing value is None) keeps the
        map correct even when a stored value is None
        """
        return index < len(self.data) and self.data[index][0] == key

    def __setitem__(self, key, value) -> None:
        """ Magic method to support assignment using the [] operator
        Insert or update the [key, value] pair
        Uses binary search to locate the index of the key or insertion point
        Input: key (comparable) - the key to insert or update
               value (any) - the value to insert or update
        Output: None
        """
        index, _ = self._binary_search(key)
        if self._has_key_at(index, key):
            self.data[index][1] = value
        else:
            # O(n): every element after index shifts one position to the right
            self.data.insert(index, [key, value])

    def __getitem__(self, key) -> Any:
        """ Magic method to support retrieval using the [] operator
        Raise a KeyError if the key is not found, consistent with the behavior
        of the __getitem__ method of the dict class
        Uses binary search to locate the index of the key
        Input: key (comparable) - the key to search for
        Output: any - the value of the key
        """
        index, value = self._binary_search(key)
        if not self._has_key_at(index, key):
            raise KeyError(key)
        return value

    def get(self, key, default=None) -> Any:
        """ Get the value of the key, or return the default value if the key is not found,
        consistent with the behavior of the get method of the dict class
        Uses binary search to locate the index of the key
        Input: key (comparable) - the key to search for
               default (any) - the default value to return if the key is not found
        Output: any - the value at the key or the default value
        """
        index, value = self._binary_search(key)
        if not self._has_key_at(index, key):
            return default
        return value

    def __contains__(self, key) -> bool:
        """ Magic method to support the in operator
        Input: key (comparable) - the key to search for
        Output: bool - True if the key is found, False otherwise
        """
        index, _ = self._binary_search(key)
        return self._has_key_at(index, key)

    def pop(self, key) -> Any:
        """ Remove a [key, value] pair from the map and return the value,
        consistent with the pop method in the dict class
        Raise a KeyError if the key is not found
        Input: key (comparable) - the key to remove
        Output: any - the value of the key
        """
        index, value = self._binary_search(key)
        if not self._has_key_at(index, key):
            raise KeyError(key)
        self.data.pop(index)
        return value

    def __iter__(self) -> Iterator[Any]:
        """ Magic method to support iteration using a generator
        Note that this yields the keys only, consistent with
        the behavior of the iterator for the dict class
        Input: None
        Output: generator - a generator that yields the keys
        """
        for pair in self.data:
            yield pair[0]

    def items(self) -> Iterator[list[Any]]:
        """ A generator that yields the [key, value] pairs in the map
        Usage is consistent with the behavior of the items method
        of the dict class.
        Input: None
        Output: generator - the [key, value] pairs in the map
        """
        yield from self.data
```
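word_count.py
```python
#!/usr/bin/env python3
import os
import re
import sys

from avl_map import AVLMap
from list_map import ListMap


def count_words(filename: str, word_count_map: dict | AVLMap | ListMap) -> None:
    """ Count every alphabetic word in the file into word_count_map
    Only `in`, `[]` and `[]=` are used, so dict, AVLMap and ListMap all work
    """
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            # Split on runs of non-word characters and underscores
            words = re.split(r'[\W_]+', line)
            for word in words:
                word = word.lower()
                # Skip empty strings and tokens that contain digits
                if word and word.isalpha():
                    if word in word_count_map:
                        word_count_map[word] += 1
                    else:
                        word_count_map[word] = 1


def main():
    if len(sys.argv) != 3:
        script = sys.argv[0]
        sys.stderr.write(f'Usage: {script} [dict|list|bst|avl] filename\n')
        sys.exit(1)
    map_type = sys.argv[1]
    filename = sys.argv[2]
    if map_type == 'avl':
        word_count_map = AVLMap(do_balance=True)
    elif map_type == 'bst':
        word_count_map = AVLMap(do_balance=False)
    elif map_type == 'list':
        word_count_map = ListMap()
    elif map_type == 'dict':
        word_count_map = {}
    else:
        sys.stderr.write(f'Invalid map type: {map_type}\n')
        sys.exit(1)
    count_words(filename, word_count_map)
    try:
        for word, count in word_count_map.items():
            print(f'{count} {word}')
        # Flush inside the try block so that a closed pipe is detected here
        sys.stdout.flush()
    except BrokenPipeError:
        # The reader (e.g. head -n5) exited early; point stdout at devnull so
        # the interpreter's final flush at exit does not raise again
        devnull = os.open(os.devnull, os.O_WRONLY)
        os.dup2(devnull, sys.stdout.fileno())
        sys.exit(1)


if __name__ == '__main__':
    main()
```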
Test code:
word_count_test.py
```python
#!/usr/bin/env python3
from avl_map import AVLMap
from list_map import ListMap
import unittest
import unittest.mock
import sys
import io
import word_count


class WordCountTest(unittest.TestCase):
    AssignmentTotal = 10
    Total = 5
    Points = 0

    @classmethod
    def setUpClass(cls):
        pass

    @classmethod
    def tearDownClass(cls):
        assignment_points = cls.AssignmentTotal * cls.Points / cls.Total
        print()
        print(f' Score {assignment_points:.2f} / {cls.AssignmentTotal:.2f}')
        print(f' Status {"Success" if cls.Points >= cls.Total else "Failure"}')

    def test_count_words_dict(self):
        word_count_map = {}
        word_count.count_words('fruit.txt', word_count_map)
        self.assertEqual(word_count_map['apple'], 1)
        self.assertEqual(word_count_map['banana'], 2)
        self.assertEqual(word_count_map['cherry'], 3)
        word_count_map = {}
        word_count.count_words('quotes.txt', word_count_map)
        self.assertEqual(word_count_map['the'], 9)
        self.assertEqual(word_count_map['of'], 6)
        self.assertEqual(word_count_map['and'], 6)
        self.assertEqual(word_count_map['to'], 4)
        WordCountTest.Points += 1

    def test_count_words_avl(self):
        word_count_map = AVLMap(do_balance=True)
        word_count.count_words('fruit.txt', word_count_map)
        self.assertEqual(word_count_map['apple'], 1)
        self.assertEqual(word_count_map['banana'], 2)
        self.assertEqual(word_count_map['cherry'], 3)
        word_count_map = AVLMap(do_balance=True)
        word_count.count_words('quotes.txt', word_count_map)
        self.assertEqual(word_count_map['the'], 9)
        self.assertEqual(word_count_map['of'], 6)
        self.assertEqual(word_count_map['and'], 6)
        self.assertEqual(word_count_map['to'], 4)
        WordCountTest.Points += 1

    def test_count_words_bst(self):
        word_count_map = AVLMap(do_balance=False)
        word_count.count_words('fruit.txt', word_count_map)
        self.assertEqual(word_count_map['apple'], 1)
        self.assertEqual(word_count_map['banana'], 2)
        self.assertEqual(word_count_map['cherry'], 3)
        word_count_map = AVLMap(do_balance=False)
        word_count.count_words('quotes.txt', word_count_map)
        self.assertEqual(word_count_map['the'], 9)
        self.assertEqual(word_count_map['of'], 6)
        self.assertEqual(word_count_map['and'], 6)
        self.assertEqual(word_count_map['to'], 4)
        WordCountTest.Points += 1

    def test_count_words_list(self):
        word_count_map = ListMap()
        word_count.count_words('fruit.txt', word_count_map)
        self.assertEqual(word_count_map['apple'], 1)
        self.assertEqual(word_count_map['banana'], 2)
        self.assertEqual(word_count_map['cherry'], 3)
        word_count_map = ListMap()
        word_count.count_words('quotes.txt', word_count_map)
        self.assertEqual(word_count_map['the'], 9)
        self.assertEqual(word_count_map['of'], 6)
        self.assertEqual(word_count_map['and'], 6)
        self.assertEqual(word_count_map['to'], 4)
        WordCountTest.Points += 1

    def test_main(self):
        with unittest.mock.patch('sys.stdout', new=io.StringIO()) as output:
            sys.argv = ['word_count.py', 'dict', 'fruit.txt']
            word_count.main()
            outstrings = output.getvalue().splitlines()
            self.assertEqual(sorted(outstrings), ['1 apple', '2 banana', '3 cherry'])
        with unittest.mock.patch('sys.stdout', new=io.StringIO()) as output:
            sys.argv = ['word_count.py', 'avl', 'fruit.txt']
            word_count.main()
            outstrings = output.getvalue().splitlines()
            self.assertEqual(sorted(outstrings), ['1 apple', '2 banana', '3 cherry'])
        with unittest.mock.patch('sys.stdout', new=io.StringIO()) as output:
            sys.argv = ['word_count.py', 'bst', 'fruit.txt']
            word_count.main()
            outstrings = output.getvalue().splitlines()
            self.assertEqual(sorted(outstrings), ['1 apple', '2 banana', '3 cherry'])
        with unittest.mock.patch('sys.stdout', new=io.StringIO()) as output:
            sys.argv = ['word_count.py', 'list', 'fruit.txt']
            word_count.main()
            outstrings = output.getvalue().splitlines()
            self.assertEqual(sorted(outstrings), ['1 apple', '2 banana', '3 cherry'])
        WordCountTest.Points += 1


if __name__ == '__main__':
    unittest.main()
```
avl_test.py
```python
#!/usr/bin/env python3
import unittest

from avl import AVLTree, Node


class AvlTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        pass

    @classmethod
    def tearDownClass(cls):
        pass

    def test_init(self):
        avl = AVLTree()
        self.assertIsNone(avl.root)

    def test_init_with_root(self):
        root = Node(10)
        avl = AVLTree(root)
        self.assertEqual(avl.root, root)

    def test_get_height(self):
        avl = AVLTree()
        self.assertEqual(avl._get_height(avl.root), -1)

    def test_update_height(self):
        root = Node(20)
        avl = AVLTree(root)
        avl._update_height(avl.root)
        self.assertEqual(avl._get_height(avl.root), 0)
        root.left = Node(10)
        avl._update_height(avl.root)
        self.assertEqual(avl._get_height(avl.root), 1)
        root.right = Node(30)
        avl._update_height(avl.root)
        self.assertEqual(avl._get_height(avl.root), 1)
        root.left = None
        avl._update_height(avl.root)
        self.assertEqual(avl._get_height(avl.root), 1)

    def test_balance_factor(self):
        root = Node(20)
        avl = AVLTree(root)
        avl._update_height(avl.root)
        self.assertEqual(avl._balance_factor(avl.root), 0)
        root.left = Node(10)
        avl._update_height(avl.root)
        self.assertEqual(avl._balance_factor(avl.root), 1)
        root.right = Node(30)
        avl._update_height(avl.root)
        self.assertEqual(avl._balance_factor(avl.root), 0)
        root.left = None
        avl._update_height(avl.root)
        self.assertEqual(avl._balance_factor(avl.root), -1)

    def test_rotate_left(self):
        # General case
        root = Node(4)
        root.left = Node(2)
        root.right = Node(6)
        root.right.left = Node(5)
        root.right.right = Node(7)
        avl = AVLTree(root)
        avl.root = avl._rotate_left(avl.root)
        self.assertEqual(avl.write_bfs(), [6, 4, 7, 2, 5])
        # Right-right case
        root = Node(1)
        root.right = Node(2)
        root.right.right = Node(3)
        avl = AVLTree(root)
        avl.root = avl._rotate_left(avl.root)
        self.assertEqual(avl.write_bfs(), [2, 1, 3])

    def test_rotate_right(self):
        # General case
        root = Node(4)
        root.left = Node(2)
        root.right = Node(6)
        root.left.left = Node(1)
        root.left.right = Node(3)
        avl = AVLTree(root)
        avl.root = avl._rotate_right(avl.root)
        self.assertEqual(avl.write_bfs(), [2, 1, 4, None, None, 3, 6])
        # Left-left case
        root = Node(3)
        root.left = Node(2)
        root.left.left = Node(1)
        avl = AVLTree(root)
        avl.root = avl._rotate_right(avl.root)
        self.assertEqual(avl.write_bfs(), [2, 1, 3])

    def test_balance(self):
        # Left-left case
        #       3
        #      /
        #     2
        #    /
        #   1
        root = Node(3)
        root.height = 2
        root.left = Node(2)
        root.left.height = 1
        root.left.left = Node(1)
        avl = AVLTree(root)
        avl.root = avl._balance(avl.root)
        self.assertEqual(avl.write_bfs(), [2, 1, 3])
        # Left-right case
        #     3
        #    /
        #   1
        #    \
        #     2
        root = Node(3)
        root.height = 2
        root.left = Node(1)
        root.left.height = 1
        root.left.right = Node(2)
        avl = AVLTree(root)
        avl.root = avl._balance(avl.root)
        self.assertEqual(avl.write_bfs(), [2, 1, 3])
        # Right-right case
        #   1
        #    \
        #     2
        #      \
        #       3
        root = Node(1)
        root.height = 2
        root.right = Node(2)
        root.right.height = 1
        root.right.right = Node(3)
        avl = AVLTree(root)
        avl.root = avl._balance(avl.root)
        self.assertEqual(avl.write_bfs(), [2, 1, 3])
        # Right-left case
        #   1
        #    \
        #     3
        #    /
        #   2
        root = Node(1)
        root.height = 2
        root.right = Node(3)
        root.right.height = 1
        root.right.left = Node(2)
        avl = AVLTree(root)
        avl.root = avl._balance(avl.root)
        self.assertEqual(avl.write_bfs(), [2, 1, 3])

    def test_insert(self):
        avl = AVLTree()
        avl.insert(30)
        self.assertEqual(avl.root.key, 30)
        self.assertIsNone(avl.root.left)
        self.assertIsNone(avl.root.right)
        avl.insert(30)
        self.assertEqual(avl.root.key, 30)
        self.assertIsNone(avl.root.left)
        self.assertIsNone(avl.root.right)
        avl.insert(10)
        self.assertEqual(avl.root.left.key, 10)
        avl.insert(50)
        self.assertEqual(avl.root.right.key, 50)
        avl.insert(20)
        self.assertEqual(avl.root.left.right.key, 20)
        avl.insert(40)
        self.assertEqual(avl.root.right.left.key, 40)
        avl.insert(70)
        self.assertEqual(avl.root.right.right.key, 70)
        avl.insert(60)
        self.assertEqual(avl.root.right.right.left.key, 60)
        avl.clear()
        avl.insert(1)
        avl.insert(2)
        avl.insert(3)
        self.assertEqual(avl.write_bfs(), [2, 1, 3])
        avl.insert(4)
        avl.insert(5)
        avl.insert(6)
        self.assertEqual(avl.write_bfs(), [4, 2, 5, 1, 3, None, 6])
        avl.clear()
        avl.insert(6)
        avl.insert(5)
        avl.insert(4)
        self.assertEqual(avl.write_bfs(), [5, 4, 6])
        avl.insert(3)
        avl.insert(2)
        avl.insert(1)
        self.assertEqual(avl.write_bfs(), [3, 2, 5, 1, None, 4, 6])

    def test_min_value_node(self):
        avl = AVLTree()
        avl.insert(3)
        self.assertEqual(avl._min_value_node(avl.root), 3)
        avl.insert(2)
        self.assertEqual(avl._min_value_node(avl.root), 2)
        avl.insert(1)
        self.assertEqual(avl._min_value_node(avl.root), 1)

    def test_remove(self):
        avl = AVLTree()
        avl.insert(2)
        avl.remove(2)
        self.assertEqual(avl.root, None)
        avl.insert(2)
        avl.insert(1)
        avl.insert(3)
        avl.remove(1)
        self.assertEqual(avl.root.left, None)
        self.assertEqual(avl.root.right.key, 3)
        avl.remove(3)
        self.assertEqual(avl.root.right, None)
        avl.insert(3)
        avl.remove(2)
        self.assertEqual(avl.root.key, 3)
        avl.clear()
        avl.insert(30)
        avl.insert(10)
        avl.insert(50)
        avl.insert(40)
        avl.insert(45)
        avl.remove(30)
        self.assertEqual(avl.root.key, 40)
        self.assertEqual(avl.root.right.key, 45)
        self.assertEqual(avl.root.right.left, None)
        self.assertEqual(avl.root.right.right.key, 50)
        avl.clear()
        # Check rebalancing resulting from remove
        # Case: Right-Right
        avl.insert(40)
        avl.insert(20)
        avl.insert(60)
        avl.insert(70)
        avl.remove(20)
        self.assertEqual(avl.root.key, 60)
        self.assertEqual(avl.root.left.key, 40)
        self.assertEqual(avl.root.right.key, 70)
        avl.clear()
        # Case: Right-Left
        avl.insert(40)
        avl.insert(20)
        avl.insert(60)
        avl.insert(55)
        avl.remove(20)
        self.assertEqual(avl.root.key, 55)
        self.assertEqual(avl.root.left.key, 40)
        self.assertEqual(avl.root.right.key, 60)

    def test_write_bfs(self):
        avl = AVLTree()
        avl.insert(30)
        self.assertEqual(avl.write_bfs(), [30])
        avl.insert(50)
        self.assertEqual(avl.write_bfs(), [30, None, 50])
        avl.insert(10)
        self.assertEqual(avl.write_bfs(), [30, 10, 50])
        avl.insert(20)
        avl.insert(40)
        avl.insert(70)
        self.assertEqual(avl.write_bfs(), [30, 10, 50, None, 20, 40, 70])
        avl.insert(60)
        self.assertEqual(avl.write_bfs(), [30, 10, 50, None, 20, 40, 70, None, None, None, None, 60])


if __name__ == '__main__':
    unittest.main()
```
bash
list_map_test.py
#!/usr/bin/env python3
import unittest

from list_map import ListMap


class ListMapTest(unittest.TestCase):
    AssignmentTotal = 45
    Total = 9
    Points = 0

    @classmethod
    def setUpClass(cls):
        pass

    @classmethod
    def tearDownClass(cls):
        # Scale the number of passed tests (Points) to the assignment total
        assignment_points = cls.AssignmentTotal * cls.Points / cls.Total
        print()
        print(f' Score {assignment_points:.2f} / {cls.AssignmentTotal:.2f}')
        print(f' Status {"Success" if cls.Points >= cls.Total else "Failure"}')

    def test_binary_search(self):
        list_map = ListMap()
        list_map.data = [['a', 1], ['b', 2], ['c', 3], ['d', 4], ['e', 5]]
        self.assertEqual(list_map._binary_search('a'), (0, 1))
        self.assertEqual(list_map._binary_search('b'), (1, 2))
        self.assertEqual(list_map._binary_search('c'), (2, 3))
        self.assertEqual(list_map._binary_search('d'), (3, 4))
        self.assertEqual(list_map._binary_search('e'), (4, 5))
        # A miss returns the insertion index and None
        self.assertEqual(list_map._binary_search('f'), (5, None))
        ListMapTest.Points += 1

    def test_setitem(self):
        list_map = ListMap()
        list_map['a'] = 1
        list_map['b'] = 2
        list_map['c'] = 3
        list_map['d'] = 4
        list_map['e'] = 5
        self.assertEqual(list_map.data, [['a', 1], ['b', 2], ['c', 3], ['d', 4], ['e', 5]])
        list_map['c'] = 33
        self.assertEqual(list_map.data, [['a', 1], ['b', 2], ['c', 33], ['d', 4], ['e', 5]])
        ListMapTest.Points += 1

    def test_getitem(self):
        list_map = ListMap()
        list_map['a'] = 1
        list_map['b'] = 2
        list_map['c'] = 3
        list_map['d'] = 4
        list_map['e'] = 5
        self.assertEqual(list_map['a'], 1)
        self.assertEqual(list_map['b'], 2)
        self.assertEqual(list_map['c'], 3)
        self.assertEqual(list_map['d'], 4)
        self.assertEqual(list_map['e'], 5)
        with self.assertRaises(KeyError):
            list_map['f']
        ListMapTest.Points += 1

    def test_get(self):
        list_map = ListMap()
        list_map['a'] = 1
        list_map['b'] = 2
        list_map['c'] = 3
        list_map['d'] = 4
        list_map['e'] = 5
        self.assertEqual(list_map.get('a'), 1)
        self.assertEqual(list_map.get('b'), 2)
        self.assertEqual(list_map.get('c'), 3)
        self.assertEqual(list_map.get('d'), 4)
        self.assertEqual(list_map.get('e'), 5)
        self.assertEqual(list_map.get('f'), None)
        self.assertEqual(list_map.get('f', 'default'), 'default')
        ListMapTest.Points += 1

    def test_contains(self):
        list_map = ListMap()
        list_map['a'] = 1
        list_map['b'] = 2
        list_map['c'] = 3
        list_map['d'] = 4
        list_map['e'] = 5
        self.assertTrue('a' in list_map)
        self.assertTrue('b' in list_map)
        self.assertTrue('c' in list_map)
        self.assertTrue('d' in list_map)
        self.assertTrue('e' in list_map)
        self.assertFalse('f' in list_map)
        ListMapTest.Points += 1

    def test_pop(self):
        list_map = ListMap()
        list_map['a'] = 1
        list_map['b'] = 2
        list_map['c'] = 3
        list_map['d'] = 4
        list_map['e'] = 5
        self.assertEqual(list_map.pop('a'), 1)
        self.assertEqual(list_map.pop('c'), 3)
        self.assertEqual(list_map.pop('e'), 5)
        self.assertEqual(list_map.pop('b'), 2)
        self.assertEqual(list_map.pop('d'), 4)
        with self.assertRaises(KeyError):
            list_map.pop('f')
        ListMapTest.Points += 1

    def test_iter(self):
        list_map = ListMap()
        list_map['a'] = 1
        list_map['b'] = 2
        list_map['c'] = 3
        list_map['d'] = 4
        list_map['e'] = 5
        result = []
        for key in list_map:
            result.append(key)
        self.assertEqual(result, ['a', 'b', 'c', 'd', 'e'])
        self.assertEqual(list(list_map), ['a', 'b', 'c', 'd', 'e'])
        ListMapTest.Points += 1

    def test_items(self):
        list_map = ListMap()
        list_map['a'] = 1
        list_map['b'] = 2
        list_map['c'] = 3
        list_map['d'] = 4
        list_map['e'] = 5
        self.assertEqual(list(list_map.items()), [['a', 1], ['b', 2], ['c', 3], ['d', 4], ['e', 5]])
        ListMapTest.Points += 1

    def test_next(self):
        list_map = ListMap()
        list_map['a'] = 1
        list_map['b'] = 2
        list_map['c'] = 3
        list_map['d'] = 4
        list_map['e'] = 5
        iter_map = iter(list_map)
        self.assertEqual(next(iter_map), 'a')
        self.assertEqual(next(iter_map), 'b')
        self.assertEqual(next(iter_map), 'c')
        self.assertEqual(next(iter_map), 'd')
        self.assertEqual(next(iter_map), 'e')
        with self.assertRaises(StopIteration):
            next(iter_map)
        ListMapTest.Points += 1


if __name__ == '__main__':
    unittest.main()
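These tests pin down the interface a `ListMap` must provide: a sorted `data` list of `[key, value]` pairs, a `_binary_search` that returns `(index, value)` on a hit and `(insertion index, None)` on a miss, `KeyError` from `[]` and `pop`, and in-order iteration over keys. The sketch below is one way to satisfy them; it is not the author's list_map.py, and the helper `_has` is a hypothetical addition. It confirms a hit by comparing the key at the returned index instead of testing `value is None`, so a stored `None` value is still treated as present.

```python
from typing import Any, Iterator, Optional


class ListMap:
    """Sketch of a sorted-list map: self.data holds [key, value] pairs ordered by key."""

    def __init__(self) -> None:
        self.data: list = []

    def _binary_search(self, key: Any) -> tuple[int, Optional[Any]]:
        """Return (index, value) if key is present, else (insertion index, None)."""
        lo, hi = 0, len(self.data)
        while lo < hi:                      # lower-bound binary search, O(log n)
            mid = (lo + hi) // 2
            if self.data[mid][0] < key:
                lo = mid + 1
            else:
                hi = mid
        if self._has(lo, key):
            return lo, self.data[lo][1]
        return lo, None

    def _has(self, index: int, key: Any) -> bool:
        """Hypothetical helper: True if data[index] holds exactly this key."""
        return index < len(self.data) and self.data[index][0] == key

    def __setitem__(self, key: Any, value: Any) -> None:
        index, _ = self._binary_search(key)
        if self._has(index, key):
            self.data[index][1] = value               # update in place
        else:
            self.data.insert(index, [key, value])     # O(n): shifts the tail

    def __getitem__(self, key: Any) -> Any:
        index, value = self._binary_search(key)
        if not self._has(index, key):
            raise KeyError(key)
        return value

    def get(self, key: Any, default: Any = None) -> Any:
        index, value = self._binary_search(key)
        return value if self._has(index, key) else default

    def __contains__(self, key: Any) -> bool:
        index, _ = self._binary_search(key)
        return self._has(index, key)

    def pop(self, key: Any) -> Any:
        index, _ = self._binary_search(key)
        if not self._has(index, key):
            raise KeyError(key)
        return self.data.pop(index)[1]                # O(n): shifts the tail

    def items(self) -> Iterator[list]:
        yield from self.data

    def __iter__(self) -> Iterator[Any]:
        for pair in self.data:
            yield pair[0]
```

The search itself is O(log n), but `list.insert` and `list.pop(index)` move every element after the position, which is where the O(n) insert/delete cost of ListMap comes from.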
Test data:
fruit.txt
apple banana cherry
banana cherry
cherry
quotes.txt
We the People of the United States, in Order to form a more perfect Union,
establish Justice, insure domestic Tranquility, provide for the common defense,
promote the general Welfare, and secure the Blessings of Liberty to ourselves
and our Posterity, do ordain and establish this Constitution
for the United States of America.
We hold these truths to be self-evident, that all men are created equal,
that they are endowed by their Creator with certain unalienable Rights,
that among these are Life, Liberty and the pursuit of Happiness.
And for the support of this Declaration,
with a firm reliance on the protection of divine Providence,
we mutually pledge to each other our Lives, our Fortunes and our sacred Honor.
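The project overview describes the counting step as splitting each line with `re.split(r'[\W_]+', line)`, keeping only alphabetic tokens, lowercasing them, and touching the map only through `in`, `[]` and `[]=`. A minimal sketch of that loop follows, assuming it takes an iterable of lines plus a map object; the real `count_words()` signature in word_count.py is not shown here, so treat this one as hypothetical. It is run on the contents of fruit.txt with a plain `dict`.

```python
import re
from typing import Any, Iterable


def count_words(lines: Iterable[str], counts: Any) -> Any:
    """Hypothetical signature: count lowercase alphabetic words into any map
    that supports `in`, `[]` and `[]=` (dict, AVLMap, ListMap, ...)."""
    for line in lines:
        for word in re.split(r'[\W_]+', line):
            if not word.isalpha():      # drops '' at line ends and tokens with digits
                continue
            word = word.lower()
            if word in counts:
                counts[word] += 1       # uses __getitem__ then __setitem__
            else:
                counts[word] = 1
    return counts


fruit = ["apple banana cherry\n", "banana cherry\n", "cherry\n"]
print(count_words(fruit, {}))  # {'apple': 1, 'banana': 2, 'cherry': 3}
```

Because `-` matches `\W`, a hyphenated word in quotes.txt such as `self-evident` is counted as two words, `self` and `evident`.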
Makefile
.PHONY: test test_all test_avl_map test_list_map test_word_count

test:
	@$(MAKE) -sk test_all

test_all: test_avl_map test_list_map test_word_count

test_avl_map: avl_map_test.py
	@echo Testing avl_map ...
	@chmod +x ./avl_map_test.py
	@./avl_map_test.py -v
	@echo

test_list_map: list_map_test.py
	@echo Testing list_map ...
	@chmod +x ./list_map_test.py
	@./list_map_test.py -v
	@echo

test_word_count: word_count_test.py
	@echo Testing word_count ...
	@chmod +x ./word_count_test.py
	@./word_count_test.py -v
	@echo