11.18 自定义Pandas扩展开发指南:打造你的专属数据分析武器库


文章目录

  • 前言
  • 一、 为什么要开发自定义扩展?
    • 1.1 现实痛点与解决方案
  • 二、 Pandas扩展类型基础
    • 2.1 扩展数组(ExtensionArray)
    • 2.2 更实用的业务扩展类型
  • 三、 DataFrame扩展方法
    • 3.1 自定义DataFrame访问器(Accessor)
    • 3.2 Series扩展访问器
  • 四、 完整业务扩展实战:电商数据分析套件
  • 五、性能优化扩展:高效数据处理工具
  • 六、打包与分发扩展
    • 6.1 创建可安装的扩展包
    • 6.2 最佳实践与部署指南

前言

你是否曾想过:如果Pandas能直接支持我特定业务领域的数据操作该多好?今天,我将带你走进Pandas扩展开发的世界,教你如何打造专属的数据分析工具,显著提升团队的数据处理效率。


一、 为什么要开发自定义扩展?

1.1 现实痛点与解决方案

python
import pandas as pd
import numpy as np
from typing import Any, Dict, List, Optional, Union
import warnings
warnings.filterwarnings('ignore')

print("🎯 Pandas版本: ", pd.__version__)

# 常见业务场景痛点示例
class BusinessDataProcessor:
    """传统业务数据处理类"""
    
    @staticmethod
    def calculate_financial_ratios(df):
        """计算财务比率 - 传统方法"""
        ratios = {}
        
        # 净利率
        if 'net_profit' in df.columns and 'revenue' in df.columns:
            ratios['net_margin'] = df['net_profit'] / df['revenue']
        
        # 资产负债率
        if 'total_liabilities' in df.columns and 'total_assets' in df.columns:
            ratios['debt_ratio'] = df['total_liabilities'] / df['total_assets']
        
        # ROE
        if 'net_profit' in df.columns and 'equity' in df.columns:
            ratios['roe'] = df['net_profit'] / df['equity']
        
        return pd.DataFrame(ratios)
    
    @staticmethod  
    def clean_chinese_text(df, column):
        """清洗中文文本 - 传统方法"""
        import re
        
        def clean_text(text):
            if pd.isna(text):
                return text
            
            # 移除多余空格
            text = re.sub(r'\s+', ' ', str(text))
            # 移除特殊字符但保留中文
            text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s]', '', text)
            return text.strip()
        
        return df[column].apply(clean_text)

# 使用传统方法
print("传统方法痛点演示:")
sample_data = pd.DataFrame({
    'net_profit': [100, 200, 150],
    'revenue': [1000, 1500, 1200],
    'total_liabilities': [500, 600, 550],
    'total_assets': [1000, 1200, 1100],
    'equity': [500, 600, 550],
    'chinese_text': ['  你好,世界!  ', '测试-数据  ', '  重要:信息']
})

processor = BusinessDataProcessor()
ratios = processor.calculate_financial_ratios(sample_data)
print("财务比率计算:")
print(ratios)

cleaned_text = processor.clean_chinese_text(sample_data, 'chinese_text')
print("\n中文文本清洗:")
print(cleaned_text)

print("\n❌ 传统方法的问题:")
print("1. 代码重复:每个项目都要重写")
print("2. 不易维护:业务逻辑分散")
print("3. 性能不佳:逐行apply操作")
print("4. 接口不统一:每个团队有自己的实现")

二、 Pandas扩展类型基础

2.1 扩展数组(ExtensionArray)

python
# 创建一个简单的扩展数组
from pandas.api.extensions import ExtensionArray, ExtensionDtype
import pandas as pd
import numpy as np

class PercentageDtype(ExtensionDtype):
    """百分比数据类型"""
    
    name = 'percentage'
    type = float  # 底层Python类型
    kind = 'f'  # numpy dtype kind
    
    @classmethod
    def construct_array_type(cls):
        return PercentageArray
    
    @classmethod
    def construct_from_string(cls, string):
        if string == cls.name:
            return cls()
        else:
            raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")

class PercentageArray(ExtensionArray):
    """百分比数组实现"""
    
    def __init__(self, values):
        """初始化百分比数组
        
        Parameters
        ----------
        values : array-like
            百分比值,可以是0-1的小数或0-100的百分比
        """
        # 确保是numpy数组
        self._data = np.asarray(values, dtype=np.float64)
        
        # 自动检测并转换:如果最大非空值>1,假设传入的是0-100的百分比格式
        # 用 size 判断和 nanmax,避免空数组或含缺失值的数组在取最大值时出问题
        if self._data.size and np.nanmax(self._data) > 1:
            self._data = self._data / 100
    
    def __len__(self):
        return len(self._data)
    
    def __getitem__(self, item):
        if isinstance(item, int):
            return self._data[item]
        elif isinstance(item, slice):
            return type(self)(self._data[item])
        else:
            # 处理其他索引类型
            return type(self)(self._data[item])
    
    @classmethod
    def _from_sequence(cls, scalars, dtype=None, copy=False):
        """从序列创建实例"""
        return cls(scalars)
    
    @classmethod
    def _from_factorized(cls, values, original):
        """从因式分解值创建实例"""
        return cls(values)
    
    def _values_for_factorize(self):
        """返回用于因式分解的值"""
        return self._data, np.nan
    
    def __setitem__(self, key, value):
        """设置项:整数、切片或掩码统一交给底层numpy索引处理"""
        self._data[key] = value
    
    def isna(self):
        """检测缺失值"""
        return pd.isna(self._data)
    
    def take(self, indices, allow_fill=False, fill_value=None):
        """取元素"""
        from pandas.api.extensions import take  # 公开API,避免依赖内部模块
        
        data = self._data
        if allow_fill and fill_value is None:
            fill_value = self.dtype.na_value
        
        result = take(
            data, 
            indices, 
            allow_fill=allow_fill, 
            fill_value=fill_value
        )
        
        return self._from_sequence(result)
    
    def copy(self):
        """复制数组"""
        return type(self)(self._data.copy())
    
    @classmethod
    def _concat_same_type(cls, to_concat):
        """连接相同类型的数组"""
        data = np.concatenate([x._data for x in to_concat])
        return cls(data)
    
    def _format_values(self):
        """格式化显示值"""
        return [f"{x:.1%}" if not pd.isna(x) else "NaN" for x in self._data]
    
    def __repr__(self):
        return f"<PercentageArray>\n{self._format_values()}"
    
    @property
    def dtype(self):
        return PercentageDtype()
    
    @property
    def nbytes(self):
        """返回字节大小"""
        return self._data.nbytes

# 注册扩展类型
pd.api.extensions.register_extension_dtype(PercentageDtype)

# 测试扩展类型
print("🔧 测试百分比扩展类型:")
percentages = PercentageArray([0.1, 0.25, 0.5, 0.75, 1.0])
print("百分比数组:")
print(percentages)
print(f"数据类型: {percentages.dtype}")
print(f"长度: {len(percentages)}")
print(f"内存使用: {percentages.nbytes} 字节")

# 创建包含扩展类型的Series
s = pd.Series(percentages, name="利润率")
print("\nPandas Series:")
print(s)

# 测试缺失值处理
percentages_with_na = PercentageArray([0.1, np.nan, 0.5, None, 1.0])
print("\n带缺失值的百分比数组:")
print(percentages_with_na)
print(f"缺失值检测: {percentages_with_na.isna()}")

2.2 更实用的业务扩展类型

python
class CurrencyDtype(ExtensionDtype):
    """货币数据类型(支持不同货币单位)"""
    
    _metadata = ('currency_code',)  # 参数化dtype用于相等比较/哈希的字段
    
    def __init__(self, currency_code='CNY'):
        self.currency_code = currency_code
    
    @property
    def name(self):
        # 基类中 name 是只读属性,必须用 property 覆盖,不能在 __init__ 里直接赋值
        return f'currency[{self.currency_code}]'
    
    @property
    def type(self):
        return float
    
    @property
    def kind(self):
        return 'f'
    
    @classmethod
    def construct_array_type(cls):
        return CurrencyArray
    
    @classmethod
    def construct_from_string(cls, string):
        if string.startswith('currency[') and string.endswith(']'):
            currency_code = string[9:-1]
            return cls(currency_code)
        elif string == 'currency':
            return cls()
        else:
            raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")
    
    @property
    def na_value(self):
        return np.nan

class CurrencyArray(ExtensionArray):
    """货币数组实现"""
    
    def __init__(self, values, currency_code='CNY', exchange_rates=None):
        self._data = np.asarray(values, dtype=np.float64)
        self.currency_code = currency_code
        self.exchange_rates = exchange_rates or {}
    
    def __len__(self):
        return len(self._data)
    
    def __getitem__(self, item):
        if isinstance(item, int):
            return self._data[item]
        elif isinstance(item, slice):
            return type(self)(
                self._data[item], 
                self.currency_code, 
                self.exchange_rates
            )
        else:
            return type(self)(
                self._data[item], 
                self.currency_code, 
                self.exchange_rates
            )
    
    @classmethod
    def _from_sequence(cls, scalars, dtype=None, copy=False):
        if isinstance(dtype, CurrencyDtype):
            currency_code = dtype.currency_code
        else:
            currency_code = 'CNY'
        
        return cls(scalars, currency_code)
    
    def convert_to(self, target_currency):
        """转换到目标货币"""
        if target_currency == self.currency_code:
            return self.copy()
        
        rate_key = f"{self.currency_code}_{target_currency}"
        if rate_key in self.exchange_rates:
            rate = self.exchange_rates[rate_key]
            converted_data = self._data * rate
            return CurrencyArray(
                converted_data, 
                target_currency, 
                self.exchange_rates
            )
        else:
            raise ValueError(f"找不到汇率: {rate_key}")
    
    def _format_values(self):
        """格式化显示"""
        symbols = {
            'CNY': '¥',
            'USD': '$',
            'EUR': '€',
            'JPY': '¥'
        }
        symbol = symbols.get(self.currency_code, self.currency_code)
        
        formatted = []
        for val in self._data:
            if pd.isna(val):
                formatted.append("NaN")
            else:
                # 根据货币选择合适的小数位数
                if self.currency_code in ['JPY', 'KRW']:
                    formatted.append(f"{symbol}{val:,.0f}")
                else:
                    formatted.append(f"{symbol}{val:,.2f}")
        
        return formatted
    
    def __repr__(self):
        return f"<CurrencyArray[{self.currency_code}]>\n{self._format_values()}"
    
    # 实现必要的方法
    def isna(self):
        return pd.isna(self._data)
    
    def copy(self):
        return type(self)(
            self._data.copy(), 
            self.currency_code, 
            self.exchange_rates.copy()
        )
    
    def take(self, indices, allow_fill=False, fill_value=None):
        from pandas.api.extensions import take  # 公开API,避免依赖内部模块
        
        data = self._data
        if allow_fill and fill_value is None:
            fill_value = self.dtype.na_value
        
        result = take(
            data, 
            indices, 
            allow_fill=allow_fill, 
            fill_value=fill_value
        )
        
        return type(self)(result, self.currency_code, self.exchange_rates)
    
    @classmethod
    def _concat_same_type(cls, to_concat):
        # 检查所有数组的货币单位是否相同
        currencies = [arr.currency_code for arr in to_concat]
        if len(set(currencies)) > 1:
            raise ValueError("不能连接不同货币单位的数组")
        
        data = np.concatenate([arr._data for arr in to_concat])
        exchange_rates = to_concat[0].exchange_rates
        
        return cls(data, currencies[0], exchange_rates)
    
    @property
    def dtype(self):
        return CurrencyDtype(self.currency_code)

# 注册货币类型
pd.api.extensions.register_extension_dtype(CurrencyDtype)

# 测试货币扩展类型
print("\n💰 测试货币扩展类型:")

# 创建汇率表
exchange_rates = {
    'CNY_USD': 0.14,
    'USD_CNY': 7.14,
    'CNY_EUR': 0.13,
    'EUR_CNY': 7.69
}

# 创建人民币数组
cny_array = CurrencyArray(
    [1000, 2000, 3000, 4000, 5000], 
    'CNY', 
    exchange_rates
)
print("人民币数组:")
print(cny_array)

# 转换为美元
try:
    usd_array = cny_array.convert_to('USD')
    print("\n转换为美元:")
    print(usd_array)
except ValueError as e:
    print(f"转换错误: {e}")

# 创建Series
s_cny = pd.Series(cny_array, name="收入_CNY")
print("\n货币Series:")
print(s_cny)
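
除了正常的转换流程,这里再补一个小草图,演示上面实现中的两个防御性分支:汇率表中找不到对应汇率时 convert_to 会抛出 ValueError,_concat_same_type 则会拒绝拼接不同币种的数组。

python
# 演示两个错误分支(直接使用上文定义的对象)
try:
    cny_array.convert_to('JPY')  # 汇率表里没有 CNY_JPY
except ValueError as e:
    print(f"\n缺少汇率时的报错: {e}")

usd_array_demo = CurrencyArray([100.0, 200.0], 'USD', exchange_rates)
try:
    CurrencyArray._concat_same_type([cny_array, usd_array_demo])
except ValueError as e:
    print(f"不同币种拼接时的报错: {e}")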

三、 DataFrame扩展方法

3.1 自定义DataFrame访问器(Accessor)

python
# 金融数据分析访问器
@pd.api.extensions.register_dataframe_accessor("finance")
class FinanceAccessor:
    """金融数据分析访问器"""
    
    def __init__(self, pandas_obj):
        self._obj = pandas_obj
        self._validate()
    
    def _validate(self):
        """验证DataFrame是否包含必要的金融数据列"""
        required_cols = ['revenue', 'net_profit', 'assets']
        missing = [col for col in required_cols if col not in self._obj.columns]
        
        if missing:
            warnings.warn(
                f"缺少金融分析所需列: {missing}. "
                f"某些方法可能不可用。",
                UserWarning
            )
    
    def profitability_ratios(self):
        """计算盈利能力比率"""
        ratios = {}
        
        if 'revenue' in self._obj.columns and 'net_profit' in self._obj.columns:
            ratios['net_margin'] = self._obj['net_profit'] / self._obj['revenue']
        
        if 'assets' in self._obj.columns and 'net_profit' in self._obj.columns:
            ratios['roa'] = self._obj['net_profit'] / self._obj['assets']
        
        if 'equity' in self._obj.columns and 'net_profit' in self._obj.columns:
            ratios['roe'] = self._obj['net_profit'] / self._obj['equity']
        
        return pd.DataFrame(ratios, index=self._obj.index)
    
    def liquidity_ratios(self):
        """计算流动性比率"""
        ratios = {}
        
        if all(col in self._obj.columns for col in ['current_assets', 'current_liabilities']):
            ratios['current_ratio'] = (
                self._obj['current_assets'] / self._obj['current_liabilities']
            )
        
        if all(col in self._obj.columns for col in ['cash', 'current_liabilities']):
            ratios['cash_ratio'] = (
                self._obj['cash'] / self._obj['current_liabilities']
            )
        
        return pd.DataFrame(ratios, index=self._obj.index)
    
    def growth_rates(self, period='year'):
        """计算增长率"""
        growth = {}
        
        numeric_cols = self._obj.select_dtypes(include=[np.number]).columns
        
        for col in numeric_cols:
            if f'{col}_prev' in self._obj.columns:
                growth[f'{col}_growth'] = (
                    (self._obj[col] - self._obj[f'{col}_prev']) / 
                    self._obj[f'{col}_prev'].abs()
                )
        
        return pd.DataFrame(growth, index=self._obj.index)
    
    def dupont_analysis(self):
        """杜邦分析:净利率 × 资产周转率 × 权益乘数"""
        result = {}
        
        if all(col in self._obj.columns for col in ['net_margin', 'asset_turnover', 'equity_multiplier']):
            # 已有现成的比率列,直接使用
            for col in ['net_margin', 'asset_turnover', 'equity_multiplier']:
                result[col] = self._obj[col]
        else:
            # 尝试从原始数据计算
            if 'net_profit' in self._obj.columns and 'revenue' in self._obj.columns:
                result['net_margin'] = self._obj['net_profit'] / self._obj['revenue']
            
            if 'revenue' in self._obj.columns and 'assets' in self._obj.columns:
                result['asset_turnover'] = self._obj['revenue'] / self._obj['assets']
            
            if 'assets' in self._obj.columns and 'equity' in self._obj.columns:
                result['equity_multiplier'] = self._obj['assets'] / self._obj['equity']
        
        if all(key in result for key in ['net_margin', 'asset_turnover', 'equity_multiplier']):
            result['roe_dupont'] = (
                result['net_margin'] * 
                result['asset_turnover'] * 
                result['equity_multiplier']
            )
        
        return pd.DataFrame(result, index=self._obj.index)

# 测试金融访问器
print("\n📈 测试金融数据分析访问器:")

# 创建金融数据
finance_data = pd.DataFrame({
    'company': ['A', 'B', 'C', 'D'],
    'revenue': [1000, 1500, 1200, 1800],
    'revenue_prev': [900, 1400, 1100, 1600],
    'net_profit': [100, 180, 130, 220],
    'net_profit_prev': [90, 170, 120, 200],
    'assets': [5000, 6000, 5500, 7000],
    'equity': [3000, 3500, 3200, 4000],
    'current_assets': [2000, 2500, 2300, 2800],
    'current_liabilities': [1000, 1200, 1100, 1300],
    'cash': [500, 600, 550, 700]
})

print("原始金融数据:")
print(finance_data)

# 使用访问器
print("\n盈利能力比率:")
profitability = finance_data.finance.profitability_ratios()
print(profitability)

print("\n流动性比率:")
liquidity = finance_data.finance.liquidity_ratios()
print(liquidity)

print("\n增长率:")
growth = finance_data.finance.growth_rates()
print(growth)

print("\n杜邦分析:")
dupont = finance_data.finance.dupont_analysis()
print(dupont)
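
在进入 Series 访问器之前,再补一个关于 _validate 的小草图:由于脚本开头调用了 filterwarnings('ignore'),上面的演示看不到缺列警告。下面临时捕获警告,确认访问器在缺少必要列时只发出 UserWarning、不会中断计算。

python
# 缺少必要列时的行为演示:临时捕获警告
incomplete_df = pd.DataFrame({'revenue': [1000, 1500]})

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    partial_ratios = incomplete_df.finance.profitability_ratios()

print("\n缺少列时捕获到的警告:")
for w in caught:
    print(f"  {w.category.__name__}: {w.message}")

print("计算仍然返回结果(只是没有可计算的比率列):")
print(partial_ratios)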

3.2 Series扩展访问器

python
# 中文文本处理访问器
@pd.api.extensions.register_series_accessor("chinese")
class ChineseTextAccessor:
    """中文文本处理访问器"""
    
    def __init__(self, pandas_obj):
        self._obj = pandas_obj
    
    def clean(self, remove_punctuation=True, remove_numbers=False, 
              remove_english=False, trim_spaces=True):
        """清洗中文文本"""
        import re
        
        def clean_text(text):
            if pd.isna(text):
                return text
            
            text = str(text)
            
            if trim_spaces:
                text = text.strip()
            
            if remove_punctuation:
                # 移除标点符号(包括中文标点),只保留中文、英文、数字和空白
                text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s]', '', text)
            
            if remove_numbers:
                text = re.sub(r'\d+', '', text)
            
            if remove_english:
                text = re.sub(r'[a-zA-Z]+', '', text)
            
            # 合并多个空格
            text = re.sub(r'\s+', ' ', text)
            
            return text
        
        return self._obj.apply(clean_text)
    
    def extract_chinese(self):
        """提取纯中文"""
        import re
        
        def extract(text):
            if pd.isna(text):
                return text
            
            # 匹配中文字符
            chinese_chars = re.findall(r'[\u4e00-\u9fa5]', str(text))
            return ''.join(chinese_chars)
        
        return self._obj.apply(extract)
    
    def word_count(self):
        """中文字数统计"""
        import re
        
        def count_words(text):
            if pd.isna(text):
                return 0
            
            # 统计中文字符
            chinese_chars = re.findall(r'[\u4e00-\u9fa5]', str(text))
            return len(chinese_chars)
        
        return self._obj.apply(count_words)
    
    def contains_keywords(self, keywords):
        """检查是否包含关键词"""
        if isinstance(keywords, str):
            keywords = [keywords]
        
        def check_contains(text):
            if pd.isna(text):
                return False
            
            text = str(text)
            return any(keyword in text for keyword in keywords)
        
        return self._obj.apply(check_contains)
    
    def sentiment_analysis(self, method='simple'):
        """简单情感分析"""
        # 简单的情感词典
        positive_words = {'好', '优秀', '满意', '喜欢', '推荐', '赞', '棒'}
        negative_words = {'差', '糟糕', '不满意', '讨厌', '垃圾', '坏'}
        
        def analyze_sentiment(text):
            if pd.isna(text):
                return 0
            
            text = str(text)
            # 用子串匹配统计词典命中次数(set(text)只会拆成单字,匹配不到"满意"等多字词)
            positive_count = sum(1 for word in positive_words if word in text)
            negative_count = sum(1 for word in negative_words if word in text)
            
            if positive_count > negative_count:
                return 1
            elif negative_count > positive_count:
                return -1
            else:
                return 0
        
        if method == 'simple':
            return self._obj.apply(analyze_sentiment)
        else:
            raise ValueError(f"不支持的 sentiment 分析方法: {method}")

# 测试中文文本访问器
print("\n🇨🇳 测试中文文本处理访问器:")

# 创建中文文本数据
text_data = pd.Series([
    '这是一个非常好的产品,我很满意!',
    '质量太差了,非常不满意。',
    '服务一般,还可以改进。',
    '推荐购买,性价比很高。',
    '垃圾产品,不要买!',
    '优秀的表现,点赞!'
], name="用户评价")

print("原始文本:")
print(text_data)

print("\n文本清洗:")
cleaned = text_data.chinese.clean()
print(cleaned)

print("\n纯中文提取:")
chinese_only = text_data.chinese.extract_chinese()
print(chinese_only)

print("\n字数统计:")
word_counts = text_data.chinese.word_count()
print(word_counts)

print("\n关键词检查(包含'好'或'差'):")
has_keywords = text_data.chinese.contains_keywords(['好', '差'])
print(has_keywords)

print("\n情感分析:")
sentiments = text_data.chinese.sentiment_analysis()
print(sentiments)

# 综合应用
text_data_df = pd.DataFrame({
    'text': text_data,
    'cleaned': cleaned,
    'chinese_only': chinese_only,
    'word_count': word_counts,
    'sentiment': sentiments
})

print("\n处理后的完整数据:")
print(text_data_df)
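
访问器的方法返回的仍然是普通的 Series,因此可以和 pandas 原生方法自由组合、继续链式调用 .chinese。下面是一个简单的组合草图。

python
# 链式调用:先清洗,再统计字数
cleaned_word_counts = text_data.chinese.clean().chinese.word_count()
print("\n清洗后再统计字数:")
print(cleaned_word_counts)

# 与原生 pandas 布尔筛选混用:挑出判定为正面的评价
positive_reviews = text_data[text_data.chinese.sentiment_analysis() > 0]
print("\n判定为正面的评价:")
print(positive_reviews)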

四、 完整业务扩展实战:电商数据分析套件

python
# 电商数据分析扩展套件
class EcommerceExtension:
    """电商数据分析扩展套件"""
    
    @staticmethod
    def register_all():
        """注册所有电商扩展"""
        EcommerceExtension.register_dataframe_accessor()
        EcommerceExtension.register_series_accessor()
        EcommerceExtension.register_extension_types()
    
    @staticmethod
    def register_dataframe_accessor():
        """注册DataFrame访问器"""
        
        @pd.api.extensions.register_dataframe_accessor("ecom")
        class EcommerceDataFrameAccessor:
            """电商DataFrame访问器"""
            
            def __init__(self, pandas_obj):
                self._obj = pandas_obj
                self._validate()
            
            def _validate(self):
                """验证电商数据格式"""
                required_for_analysis = [
                    'order_id', 'user_id', 'product_id', 
                    'quantity', 'price', 'order_date'
                ]
                
                missing = [
                    col for col in required_for_analysis 
                    if col not in self._obj.columns
                ]
                
                if missing:
                    warnings.warn(
                        f"缺少电商分析推荐列: {missing}",
                        UserWarning
                    )
            
            def customer_analysis(self):
                """客户分析"""
                if 'user_id' not in self._obj.columns:
                    raise ValueError("需要 'user_id' 列")
                
                analysis = {}
                
                # RFM分析
                if all(col in self._obj.columns for col in ['user_id', 'order_date', 'price']):
                    # 计算最近购买时间
                    latest_date = self._obj['order_date'].max()
                    recency = (
                        latest_date - self._obj.groupby('user_id')['order_date'].max()
                    ).dt.days
                    
                    # 计算购买频率
                    frequency = self._obj.groupby('user_id').size()
                    
                    # 计算消费金额
                    monetary = self._obj.groupby('user_id')['price'].sum()
                    
                    analysis['rfm'] = pd.DataFrame({
                        'recency': recency,
                        'frequency': frequency,
                        'monetary': monetary
                    })
                
                # 客户生命周期价值
                if all(col in self._obj.columns for col in ['user_id', 'price', 'order_date']):
                    # 简单CLV计算
                    clv_data = self._obj.groupby('user_id').agg({
                        'price': 'sum',
                        'order_date': ['min', 'max', 'nunique']
                    })
                    
                    clv_data.columns = ['total_spent', 'first_purchase', 
                                      'last_purchase', 'order_count']
                    
                    # 计算平均订单价值
                    clv_data['avg_order_value'] = (
                        clv_data['total_spent'] / clv_data['order_count']
                    )
                    
                    analysis['clv'] = clv_data
                
                return analysis
            
            def product_analysis(self):
                """商品分析"""
                if 'product_id' not in self._obj.columns:
                    raise ValueError("需要 'product_id' 列")
                
                analysis = {}
                
                # 商品销售统计
                product_stats = self._obj.groupby('product_id').agg({
                    'quantity': 'sum',
                    'price': ['sum', 'mean', 'count']
                })
                
                product_stats.columns = [
                    'total_quantity', 'total_revenue', 
                    'avg_price', 'order_count'
                ]
                
                analysis['product_stats'] = product_stats
                
                # 商品关联分析(简单版)
                if 'order_id' in self._obj.columns:
                    # 找出经常一起购买的商品
                    order_products = (
                        self._obj.groupby('order_id')['product_id']
                        .apply(lambda x: list(x.unique()))
                    )
                    
                    # 创建商品共现矩阵(简化版)
                    from collections import defaultdict
                    cooccurrence = defaultdict(int)
                    
                    for products in order_products:
                        for i in range(len(products)):
                            for j in range(i + 1, len(products)):
                                pair = tuple(sorted([products[i], products[j]]))
                                cooccurrence[pair] += 1
                    
                    analysis['product_pairs'] = pd.Series(cooccurrence)
                
                return analysis
            
            def sales_trends(self, freq='D'):
                """销售趋势分析"""
                if 'order_date' not in self._obj.columns:
                    raise ValueError("需要 'order_date' 列")
                
                # 设置日期索引
                df = self._obj.copy()
                df.set_index('order_date', inplace=True)
                
                trends = {}
                
                # 时间序列聚合
                if 'price' in df.columns:
                    revenue_trend = df['price'].resample(freq).sum()
                    trends['revenue'] = revenue_trend
                
                if 'quantity' in df.columns:
                    quantity_trend = df['quantity'].resample(freq).sum()
                    trends['quantity'] = quantity_trend
                
                if 'order_id' in df.columns:
                    order_trend = df['order_id'].resample(freq).nunique()
                    trends['orders'] = order_trend
                
                return pd.DataFrame(trends)
            
            def cohort_analysis(self):
                """用户群组分析"""
                if not all(col in self._obj.columns for col in ['user_id', 'order_date']):
                    raise ValueError("需要 'user_id' 和 'order_date' 列")
                
                # 创建群组(按用户首次购买月份)
                df = self._obj.copy()
                df['cohort'] = df.groupby('user_id')['order_date'].transform('min').dt.to_period('M')
                df['order_period'] = df['order_date'].dt.to_period('M')
                
                # 计算群组大小
                cohort_sizes = df.groupby('cohort')['user_id'].nunique()
                
                # 计算留存率
                cohort_data = df.groupby(['cohort', 'order_period']).agg({
                    'user_id': 'nunique',
                    'price': 'sum',
                    'quantity': 'sum'
                }).reset_index()
                
                cohort_data['period_number'] = (
                    cohort_data['order_period'] - cohort_data['cohort']
                ).apply(lambda x: x.n)
                
                # 创建留存矩阵
                retention_matrix = cohort_data.pivot_table(
                    index='cohort',
                    columns='period_number',
                    values='user_id',
                    aggfunc='sum'
                )
                
                # 计算留存率
                retention_rate = retention_matrix.divide(cohort_sizes, axis=0)
                
                return {
                    'cohort_sizes': cohort_sizes,
                    'cohort_data': cohort_data,
                    'retention_matrix': retention_matrix,
                    'retention_rate': retention_rate
                }
        
        return EcommerceDataFrameAccessor
    
    @staticmethod
    def register_series_accessor():
        """注册Series访问器"""
        
        @pd.api.extensions.register_series_accessor("ecom")
        class EcommerceSeriesAccessor:
            """电商Series访问器"""
            
            def __init__(self, pandas_obj):
                self._obj = pandas_obj
            
            def price_analysis(self):
                """价格分析"""
                if not pd.api.types.is_numeric_dtype(self._obj.dtype):
                    raise ValueError("价格数据必须是数值类型")
                
                analysis = {}
                
                # 基础统计
                analysis['stats'] = {
                    'mean': self._obj.mean(),
                    'median': self._obj.median(),
                    'std': self._obj.std(),
                    'min': self._obj.min(),
                    'max': self._obj.max(),
                    'skew': self._obj.skew(),
                    'kurtosis': self._obj.kurtosis()
                }
                
                # 价格分布分析
                q25, q50, q75 = self._obj.quantile([0.25, 0.5, 0.75])
                analysis['quartiles'] = {
                    'q25': q25,
                    'q50': q50,
                    'q75': q75,
                    'iqr': q75 - q25
                }
                
                # 价格段分析
                bins = [0, 10, 50, 100, 200, 500, 1000, float('inf')]
                labels = ['<10', '10-50', '50-100', '100-200', 
                         '200-500', '500-1000', '>1000']
                
                price_segments = pd.cut(self._obj, bins=bins, labels=labels)
                analysis['segments'] = price_segments.value_counts().sort_index()
                
                return analysis
            
            def discount_calculation(self, original_prices, discount_type='percentage'):
                """折扣计算"""
                if discount_type == 'percentage':
                    # 百分比折扣
                    discounts = (original_prices - self._obj) / original_prices * 100
                elif discount_type == 'absolute':
                    # 绝对折扣
                    discounts = original_prices - self._obj
                else:
                    raise ValueError(f"不支持的折扣类型: {discount_type}")
                
                return discounts

# 注册电商扩展
EcommerceExtension.register_all()

# 测试电商扩展
print("\n🛒 测试电商数据分析扩展:")

# 创建电商数据集
np.random.seed(42)
n_orders = 1000

ecom_data = pd.DataFrame({
    'order_id': range(1000, 1000 + n_orders),
    'user_id': np.random.randint(1, 101, n_orders),
    'product_id': np.random.choice(['P001', 'P002', 'P003', 'P004', 'P005'], n_orders),
    'quantity': np.random.randint(1, 5, n_orders),
    'price': np.random.uniform(10, 1000, n_orders).round(2),
    'order_date': pd.date_range('2023-01-01', periods=n_orders, freq='H'),
    'category': np.random.choice(['电子产品', '服装', '家居', '食品'], n_orders)
})

print(f"电商数据形状: {ecom_data.shape}")
print("\n前5行数据:")
print(ecom_data.head())

# 使用电商扩展进行分析
print("\n1. 客户分析:")
customer_analysis = ecom_data.ecom.customer_analysis()
print("RFM分析:")
print(customer_analysis.get('rfm', {}).head())

print("\n2. 商品分析:")
product_analysis = ecom_data.ecom.product_analysis()
print("商品销售统计:")
print(product_analysis.get('product_stats', {}))

print("\n3. 销售趋势:")
sales_trends = ecom_data.ecom.sales_trends(freq='D')
print("日销售趋势:")
print(sales_trends.head())

print("\n4. 价格分析:")
price_analysis = ecom_data['price'].ecom.price_analysis()
print("价格统计:")
for key, value in price_analysis['stats'].items():
    print(f"  {key}: {value:.2f}")

print("\n价格分段:")
print(price_analysis['segments'])

print("\n5. 用户群组分析:")
try:
    cohort_results = ecom_data.ecom.cohort_analysis()
    print("群组留存率矩阵:")
    print(cohort_results['retention_rate'].head())
except Exception as e:
    print(f"群组分析错误: {e}")

五、性能优化扩展:高效数据处理工具

python
# 性能优化扩展
class PerformanceExtension:
    """性能优化扩展"""
    
    @staticmethod
    def register_all():
        """注册所有性能扩展"""
        PerformanceExtension.register_dataframe_accessor()
        PerformanceExtension.register_series_accessor()
    
    @staticmethod
    def register_dataframe_accessor():
        """注册DataFrame性能访问器"""
        
        @pd.api.extensions.register_dataframe_accessor("perf")
        class PerformanceDataFrameAccessor:
            """DataFrame性能优化访问器"""
            
            def __init__(self, pandas_obj):
                self._obj = pandas_obj
            
            def optimize_memory(self, inplace=False):
                """内存优化"""
                df = self._obj if inplace else self._obj.copy()
                
                for col in df.columns:
                    col_dtype = df[col].dtype
                    
                    # 整数优化:从最小的类型开始尝试,否则能放进int8的列只会降到int32
                    if col_dtype == 'int64':
                        c_min = df[col].min()
                        c_max = df[col].max()
                        
                        if c_min >= np.iinfo(np.int8).min and c_max <= np.iinfo(np.int8).max:
                            df[col] = df[col].astype('int8')
                        elif c_min >= np.iinfo(np.int16).min and c_max <= np.iinfo(np.int16).max:
                            df[col] = df[col].astype('int16')
                        elif c_min >= np.iinfo(np.int32).min and c_max <= np.iinfo(np.int32).max:
                            df[col] = df[col].astype('int32')
                    
                    # 浮点数优化
                    elif col_dtype == 'float64':
                        df[col] = df[col].astype('float32')
                    
                    # 字符串优化
                    elif col_dtype == 'object':
                        num_unique = df[col].nunique()
                        num_total = len(df[col])
                        
                        if num_unique / num_total < 0.5:
                            df[col] = df[col].astype('category')
                
                return df
            
            def chunked_apply(self, func, chunk_size=10000, **kwargs):
                """分块应用函数"""
                results = []
                
                for i in range(0, len(self._obj), chunk_size):
                    chunk = self._obj.iloc[i:i + chunk_size]
                    result = func(chunk, **kwargs)
                    results.append(result)
                
                # 根据返回类型合并结果
                if isinstance(results[0], pd.DataFrame):
                    return pd.concat(results, ignore_index=True)
                elif isinstance(results[0], pd.Series):
                    return pd.concat(results)
                elif isinstance(results[0], (int, float)):
                    return sum(results)
                else:
                    return results
            
            def parallel_apply(self, func, axis=0, n_jobs=-1, **kwargs):
                """并行应用函数"""
                try:
                    from joblib import Parallel, delayed
                    import multiprocessing
                    
                    if n_jobs == -1:
                        n_jobs = multiprocessing.cpu_count()
                    
                    if axis == 0:  # 按列
                        columns = self._obj.columns
                        results = Parallel(n_jobs=n_jobs)(
                            delayed(func)(self._obj[col], **kwargs)
                            for col in columns
                        )
                        return pd.Series(results, index=columns)
                    
                    elif axis == 1:  # 按行
                        # 分块处理
                        n_rows = len(self._obj)
                        chunk_size = max(1, n_rows // (n_jobs * 2))
                        
                        def process_chunk(chunk):
                            return chunk.apply(func, axis=1, **kwargs)
                        
                        results = Parallel(n_jobs=n_jobs)(
                            delayed(process_chunk)(self._obj.iloc[i:i + chunk_size])
                            for i in range(0, n_rows, chunk_size)
                        )
                        
                        return pd.concat(results, ignore_index=True)
                    
                except ImportError:
                    warnings.warn("joblib未安装,使用串行版本")
                    return self._obj.apply(func, axis=axis, **kwargs)
            
            def cache_operation(self, operation_name, func, *args, **kwargs):
                """缓存操作结果"""
                import hashlib
                import pickle
                
                # 生成缓存键
                cache_key_data = {
                    'operation': operation_name,
                    'args': args,
                    'kwargs': kwargs,
                    'data_hash': hashlib.md5(
                        pickle.dumps(self._obj.values)
                    ).hexdigest()
                }
                
                cache_key = hashlib.md5(
                    pickle.dumps(cache_key_data)
                ).hexdigest()
                
                # 检查缓存(这里用字典模拟,实际可用redis等)
                if hasattr(self, '_cache'):
                    if cache_key in self._cache:
                        print(f"使用缓存: {operation_name}")
                        return self._cache[cache_key]
                
                # 执行操作
                result = func(self._obj, *args, **kwargs)
                
                # 存储到缓存
                if not hasattr(self, '_cache'):
                    self._cache = {}
                self._cache[cache_key] = result
                
                return result
            
            def profile_operation(self, func, *args, **kwargs):
                """性能分析操作"""
                import cProfile
                import pstats
                import io
                import time
                
                print(f"\n性能分析: {func.__name__}")
                print("=" * 50)
                
                # 时间分析
                start_time = time.time()
                result = func(self._obj, *args, **kwargs)
                elapsed_time = time.time() - start_time
                
                print(f"执行时间: {elapsed_time:.4f} 秒")
                
                # CPU分析
                pr = cProfile.Profile()
                pr.enable()
                _ = func(self._obj.copy(), *args, **kwargs)
                pr.disable()
                
                s = io.StringIO()
                ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')
                ps.print_stats(10)  # 显示前10个最耗时的函数
                
                print("CPU分析:")
                print(s.getvalue())
                
                return result

# 注册性能扩展
PerformanceExtension.register_all()

# 测试性能扩展
print("\n⚡ 测试性能优化扩展:")

# 创建大型测试数据集
large_df = pd.DataFrame({
    'A': np.random.randint(0, 100, 1000000),
    'B': np.random.randn(1000000),
    'C': ['category_' + str(i % 100) for i in range(1000000)],
    'D': np.random.choice(['X', 'Y', 'Z'], 1000000)
})

print(f"原始数据形状: {large_df.shape}")
print(f"原始内存使用: {large_df.memory_usage(deep=True).sum() / 1024**2:.2f} MB")

# 内存优化
print("\n1. 内存优化:")
optimized_df = large_df.perf.optimize_memory()
print(f"优化后内存: {optimized_df.memory_usage(deep=True).sum() / 1024**2:.2f} MB")
print(f"内存节省: {(1 - optimized_df.memory_usage(deep=True).sum()/large_df.memory_usage(deep=True).sum()) * 100:.1f}%")

# 分块处理示例
print("\n2. 分块处理演示:")

def complex_processing(chunk):
    """复杂处理函数(先复制切片,避免对视图赋值触发 SettingWithCopyWarning)"""
    chunk = chunk.copy()
    chunk['processed'] = chunk['A'] * 2 + chunk['B'] * 3
    return chunk[['processed']]

chunked_result = large_df.perf.chunked_apply(complex_processing, chunk_size=50000)
print(f"分块处理结果形状: {chunked_result.shape}")

# 缓存演示
print("\n3. 缓存操作演示:")

def expensive_operation(df):
    """模拟耗时操作"""
    import time
    time.sleep(1)  # 模拟耗时操作
    return df['A'].mean()

import time  # 前面只在函数内部导入过 time,这里的计时需要在脚本层导入

# 第一次执行(会较慢)
print("第一次执行(无缓存)...")
start = time.time()
result1 = large_df.perf.cache_operation('mean_calc', expensive_operation)
time1 = time.time() - start
print(f"结果: {result1}, 时间: {time1:.2f}秒")

# 第二次执行(使用缓存)
print("\n第二次执行(使用缓存)...")
start = time.time()
result2 = large_df.perf.cache_operation('mean_calc', expensive_operation)
time2 = time.time() - start
print(f"结果: {result2}, 时间: {time2:.2f}秒")
print(f"缓存加速: {time1/time2:.1f}倍")

六、打包与分发扩展

6.1 创建可安装的扩展包

python
# 扩展包的目录结构
"""
my_pandas_extensions/
├── setup.py
├── README.md
├── LICENSE
├── my_pandas_extensions/
│   ├── __init__.py
│   ├── finance.py
│   ├── ecommerce.py
│   ├── text.py
│   └── performance.py
└── tests/
    └── test_extensions.py
"""

# setup.py 示例
setup_py_content = '''
from setuptools import setup, find_packages

with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

setup(
    name="my-pandas-extensions",
    version="0.1.0",
    author="Your Name",
    author_email="your.email@example.com",
    description="A collection of useful pandas extensions for business analysis",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/yourusername/my-pandas-extensions",
    packages=find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    python_requires=">=3.7",
    install_requires=[
        "pandas>=1.0.0",
        "numpy>=1.18.0",
    ],
    extras_require={
        "dev": [
            "pytest>=6.0",
            "pytest-cov>=2.0",
            "black>=20.0",
            "flake8>=3.8",
        ],
        "ecommerce": [
            "scikit-learn>=0.24",
        ],
    },
)
'''

print("📦 扩展包结构示例:")
print(setup_py_content)
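
# (补充示意)pyproject.toml:较新的 setuptools 也支持用 pyproject.toml 声明打包元数据,
# 可与 setup.py 二选一;以下字段均为示例值,发布前请替换为真实项目信息
pyproject_toml_content = '''
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "my-pandas-extensions"
version = "0.1.0"
description = "A collection of useful pandas extensions for business analysis"
readme = "README.md"
requires-python = ">=3.7"
dependencies = [
    "pandas>=1.0.0",
    "numpy>=1.18.0",
]

[project.optional-dependencies]
dev = ["pytest>=6.0", "pytest-cov>=2.0", "black>=20.0", "flake8>=3.8"]
'''

print("\n📦 pyproject.toml 示例(可选的现代打包方式):")
print(pyproject_toml_content)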

# __init__.py 示例
init_py_content = '''
"""
My Pandas Extensions - A collection of useful pandas extensions
"""

from .finance import FinanceAccessor
from .ecommerce import EcommerceExtension
from .text import ChineseTextAccessor
from .performance import PerformanceExtension

# 自动注册所有扩展
def register_all_extensions():
    """Register all pandas extensions"""
    import pandas as pd
    
    # 注册访问器(若各模块内部已经用装饰器注册过,这里的重复注册会触发覆盖警告,二者取其一即可)
    pd.api.extensions.register_dataframe_accessor("finance")(FinanceAccessor)
    pd.api.extensions.register_series_accessor("chinese")(ChineseTextAccessor)
    
    # 注册电商扩展
    EcommerceExtension.register_all()
    
    # 注册性能扩展
    PerformanceExtension.register_all()
    
    print("All pandas extensions registered successfully!")

# 版本信息
__version__ = "0.1.0"
__all__ = [
    "FinanceAccessor",
    "EcommerceExtension", 
    "ChineseTextAccessor",
    "PerformanceExtension",
    "register_all_extensions",
]
'''

print("\n__init__.py 内容示例:")
print(init_py_content)

# 使用示例文件
usage_example = '''
# 使用扩展包示例
import pandas as pd
import my_pandas_extensions as mpe

# 自动注册所有扩展
mpe.register_all_extensions()

# 使用金融分析扩展
df = pd.DataFrame({
    'revenue': [1000, 1500, 1200],
    'net_profit': [100, 180, 130],
    'assets': [5000, 6000, 5500],
    'equity': [3000, 3500, 3200]
})

# 计算财务比率
ratios = df.finance.profitability_ratios()
print(ratios)

# 使用中文文本处理
text_series = pd.Series(['这是一个测试', '你好世界'])
cleaned = text_series.chinese.clean()
print(cleaned)
'''

print("\n📖 使用示例:")
print(usage_example)

# 测试文件示例
test_file_content = '''
import pytest
import pandas as pd
import numpy as np
import my_pandas_extensions as mpe

class TestFinanceExtensions:
    """测试金融扩展"""
    
    def setup_method(self):
        """测试准备"""
        mpe.register_all_extensions()
        
        self.finance_df = pd.DataFrame({
            'revenue': [1000, 1500, 1200],
            'net_profit': [100, 180, 130],
            'assets': [5000, 6000, 5500],
            'equity': [3000, 3500, 3200]
        })
    
    def test_profitability_ratios(self):
        """测试盈利能力比率计算"""
        ratios = self.finance_df.finance.profitability_ratios()
        
        assert 'net_margin' in ratios.columns
        assert 'roe' in ratios.columns
        
        expected_net_margin = self.finance_df['net_profit'] / self.finance_df['revenue']
        pd.testing.assert_series_equal(
            ratios['net_margin'], 
            expected_net_margin,
            check_names=False  # 两个Series的name不同(net_margin 与 None)
        )
    
    def test_missing_columns(self):
        """测试缺失列的情况"""
        df = pd.DataFrame({'revenue': [1000, 1500]})
        
        # 应该发出警告但不会报错
        with pytest.warns(UserWarning):
            ratios = df.finance.profitability_ratios()
            assert ratios.empty

class TestTextExtensions:
    """测试文本扩展"""
    
    def setup_method(self):
        mpe.register_all_extensions()
        
        self.text_series = pd.Series([
            '  你好,世界!  ',
            '测试-数据  ',
            '重要:信息'
        ])
    
    def test_text_cleaning(self):
        """测试文本清洗"""
        cleaned = self.text_series.chinese.clean()
        
        expected = pd.Series(['你好世界', '测试数据', '重要信息'])
        pd.testing.assert_series_equal(cleaned, expected)
    
    def test_word_count(self):
        """测试字数统计"""
        counts = self.text_series.chinese.word_count()
        
        expected = pd.Series([4, 4, 4])  # 每个字符串有4个中文字符
        pd.testing.assert_series_equal(counts, expected)

if __name__ == '__main__':
    pytest.main([__file__, '-v'])
'''

print("\n🧪 测试文件示例:")
print(test_file_content)

6.2 最佳实践与部署指南

python
# 扩展开发最佳实践
class ExtensionBestPractices:
    """扩展开发最佳实践指南"""
    
    @staticmethod
    def get_guidelines():
        """获取扩展开发指南"""
        
        guidelines = {
            '设计原则': [
                '1. 单一职责:每个扩展只做一件事',
                '2. 向后兼容:确保现有代码不受影响',
                '3. 明确错误:提供清晰的错误信息',
                '4. 性能考虑:避免不必要的内存复制',
                '5. 类型安全:正确处理数据类型',
            ],
            
            '命名规范': [
                '1. 访问器名称:使用有意义的单数名词(如:.finance, .text)',
                '2. 方法命名:使用动词开头(如:calculate_ratios, clean_text)',
                '3. 类命名:使用有意义的名称并以Accessor/Extension结尾',
                '4. 避免冲突:避免使用pandas内置的名称',
            ],
            
            '错误处理': [
                '1. 验证输入:在访问器中验证DataFrame/Series结构',
                '2. 友好错误:提供清晰的错误信息和修复建议',
                '3. 优雅降级:当缺少必要数据时返回有意义的默认值或警告',
                '4. 类型检查:验证输入参数的类型和范围',
            ],
            
            '性能优化': [
                '1. 向量化:尽可能使用向量化操作而不是循环',
                '2. 内存效率:避免不必要的数据复制',
                '3. 延迟计算:只在需要时计算结果',
                '4. 缓存:为昂贵操作实现缓存机制',
            ],
            
            '文档规范': [
                '1. 完整的docstring:包含参数、返回值、示例',
                '2. 类型提示:使用Python类型提示',
                '3. 示例代码:提供清晰的用法示例',
                '4. 测试覆盖:编写全面的单元测试',
            ],
            
            '部署发布': [
                '1. 版本控制:使用语义化版本控制',
                '2. 依赖管理:明确声明依赖关系',
                '3. 测试发布:先在测试PyPI发布',
                '4. 持续集成:设置CI/CD流水线',
            ]
        }
        
        return guidelines

# 显示最佳实践
print("\n📚 扩展开发最佳实践:")
guidelines = ExtensionBestPractices.get_guidelines()

for category, items in guidelines.items():
    print(f"\n{category}:")
    for item in items:
        print(f"  {item}")

# 完整示例:业务就绪的扩展
class ProductionReadyExtension:
    """生产就绪的扩展示例"""
    
    @staticmethod
    def create_business_extension():
        """创建生产就绪的业务扩展"""
        
        business_extension_template = '''
"""
业务数据验证扩展 - 生产就绪版本
"""

import pandas as pd
import numpy as np
from typing import Dict, List, Optional, Union, Any
import warnings
from datetime import datetime

class BusinessDataValidator:
    """业务数据验证器"""
    
    @staticmethod
    def validate_financial_data(df: pd.DataFrame) -> Dict[str, Any]:
        """
        验证金融数据质量
        
        Parameters
        ----------
        df : pd.DataFrame
            包含金融数据的DataFrame
            
        Returns
        -------
        Dict[str, Any]
            验证结果,包含:
            - is_valid: 是否通过验证
            - issues: 发现的问题列表
            - summary: 验证摘要
        """
        
        issues = []
        
        # 1. 检查必要列
        required_columns = ['revenue', 'net_profit', 'assets']
        missing_columns = [col for col in required_columns if col not in df.columns]
        
        if missing_columns:
            issues.append({
                'type': 'missing_columns',
                'severity': 'high',
                'message': f'缺少必要列: {missing_columns}',
                'suggestion': '请确保数据包含这些列'
            })
        
        # 2. 检查数据类型
        numeric_columns = ['revenue', 'net_profit', 'assets']
        for col in numeric_columns:
            if col in df.columns:
                if not pd.api.types.is_numeric_dtype(df[col]):
                    issues.append({
                        'type': 'invalid_dtype',
                        'severity': 'medium',
                        'message': f'列 "{col}" 应为数值类型',
                        'suggestion': f'使用 pd.to_numeric() 转换数据类型'
                    })
        
        # 3. 检查数据范围
        if 'revenue' in df.columns:
            negative_revenue = df['revenue'] < 0
            if negative_revenue.any():
                issues.append({
                    'type': 'invalid_range',
                    'severity': 'high',
                    'message': f'发现 {negative_revenue.sum()} 条负收入记录',
                    'suggestion': '检查数据输入错误'
                })
        
        # 4. 检查一致性
        if all(col in df.columns for col in ['net_profit', 'revenue']):
            invalid_margin = df['net_profit'] > df['revenue']
            if invalid_margin.any():
                issues.append({
                    'type': 'inconsistency',
                    'severity': 'high',
                    'message': f'发现 {invalid_margin.sum()} 条净利大于收入的记录',
                    'suggestion': '验证数据计算逻辑'
                })
        
        # 5. 检查缺失值
        for col in required_columns:
            if col in df.columns:
                missing_count = df[col].isna().sum()
                if missing_count > 0:
                    issues.append({
                        'type': 'missing_values',
                        'severity': 'medium',
                        'message': f'列 "{col}" 有 {missing_count} 个缺失值',
                        'suggestion': '使用填充或删除处理缺失值'
                    })
        
        # 生成验证结果
        is_valid = len([i for i in issues if i['severity'] == 'high']) == 0
        
        result = {
            'is_valid': is_valid,
            'issues': issues,
            'summary': {
                'total_issues': len(issues),
                'high_severity': len([i for i in issues if i['severity'] == 'high']),
                'medium_severity': len([i for i in issues if i['severity'] == 'medium']),
                'low_severity': len([i for i in issues if i['severity'] == 'low'])
            },
            'timestamp': datetime.now().isoformat(),
            'data_shape': df.shape
        }
        
        return result
    
    @staticmethod
    def fix_common_issues(df: pd.DataFrame, validation_result: Dict[str, Any]) -> pd.DataFrame:
        """
        自动修复常见问题
        
        Parameters
        ----------
        df : pd.DataFrame
            原始数据
        validation_result : Dict[str, Any]
            验证结果
            
        Returns
        -------
        pd.DataFrame
            修复后的数据
        """
        
        df_fixed = df.copy()
        
        for issue in validation_result['issues']:
            if issue['type'] == 'missing_values':
                # 自动填充缺失值
                col = issue['message'].split('"')[1]
                if col in df_fixed.columns:
                    if pd.api.types.is_numeric_dtype(df_fixed[col]):
                        df_fixed[col] = df_fixed[col].fillna(df_fixed[col].median())
        
        return df_fixed


@pd.api.extensions.register_dataframe_accessor("validate")
class DataValidationAccessor:
    """数据验证访问器"""
    
    def __init__(self, pandas_obj):
        self._obj = pandas_obj
        self.validator = BusinessDataValidator()
    
    def financial_data(self) -> Dict[str, Any]:
        """验证金融数据"""
        return self.validator.validate_financial_data(self._obj)
    
    def fix_issues(self, validation_result: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
        """修复数据问题"""
        if validation_result is None:
            validation_result = self.financial_data()
        
        return self.validator.fix_common_issues(self._obj, validation_result)
    
    def generate_report(self, validation_result: Dict[str, Any]) -> str:
        """生成验证报告"""
        
        report = []
        report.append("数据验证报告")
        report.append("=" * 50)
        report.append(f"验证时间: {validation_result['timestamp']}")
        report.append(f"数据形状: {validation_result['data_shape']}")
        report.append(f"是否有效: {'是' if validation_result['is_valid'] else '否'}")
        report.append(f"问题总数: {validation_result['summary']['total_issues']}")
        report.append("")
        
        if validation_result['issues']:
            report.append("发现的问题:")
            for i, issue in enumerate(validation_result['issues'], 1):
                report.append(f"{i}. [{issue['severity'].upper()}] {issue['message']}")
                report.append(f"   建议: {issue['suggestion']}")
        else:
            report.append("✓ 没有发现问题")
        
        return '\\n'.join(report)


def register_extensions():
    """注册所有扩展"""
    # 访问器已在装饰器中注册
    print("数据验证扩展已注册")
        '''
        
        return business_extension_template

# 显示生产就绪扩展
print("\n🏭 生产就绪扩展示例:")
production_extension = ProductionReadyExtension.create_business_extension()
print(production_extension)
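
最后给一个使用上面模板的小草图。为了能在同一脚本里直接演示,这里用 exec 执行模板字符串;实际项目中应把模板保存为独立的 .py 模块再 import。示例数据为虚构,故意包含一条负收入和一个缺失值。

python
# 用法草图:执行模板并验证一份有问题的示例数据
exec(production_extension)  # 实际项目中请改为 import 已保存的模块

demo_df = pd.DataFrame({
    'revenue': [1000, -50, 1200],      # 故意放入一条负收入
    'net_profit': [100, 20, np.nan],   # 故意放入一个缺失值
    'assets': [5000, 600, 5500]
})

validation = demo_df.validate.financial_data()
print(demo_df.validate.generate_report(validation))

fixed_df = demo_df.validate.fix_issues(validation)
print("\n自动修复缺失值后的数据:")
print(fixed_df)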
