MaxCompute Python UDF开发指南:从入门到精通

MaxCompute UDF基础概念

UDF(User-Defined Function)即用户自定义函数,当MaxCompute内建函数无法满足业务需求时,可以自行编写代码创建自定义函数。MaxCompute支持三种类型的UDF:

  • UDF(用户自定义标量函数) :一进一出,读入一行数据,输出一个值
  • UDTF(用户自定义表值函数) :一进多出,读入一行数据,输出多个值(可视为一张表)
  • UDAF(用户自定义聚合函数) :多进一出,将多条输入记录聚合成一个输出值

Python UDF开发基础

函数签名与数据类型

在开发Python UDF时,必须导入annotate模块并使用@annotate()注解定义函数签名:

python 复制代码
python
from odps.udf import annotate

@annotate("输入参数类型->返回值类型")

MaxCompute支持的数据类型包括:

  • 基本类型:bigint, string, double, boolean, datetime
  • 扩展类型:decimal, float, binary, date, char, varchar
  • 复杂类型:array, map, struct及其嵌套类型

UDF代码基本结构

一个完整的Python UDF至少包含以下部分:

python 复制代码
python
from odps.udf import annotate  # 导入注解模块

@annotate("string,bigint->string")  # 函数签名
class MyUDF(object):  # 自定义Python类
    def evaluate(self, text, number):  # 必须实现evaluate方法
        if text is None or number is None:
            return None
        return text + str(number)

标量函数UDF详解与案例

基本案例:字符串处理

python 复制代码
python
from odps.udf import annotate

@annotate("string->string")
class ToUpperCase(object):
    def evaluate(self, s):
        if s is None:
            return None
        return s.upper()

实用案例:日期格式转换

python 复制代码
python
from odps.udf import annotate
from datetime import datetime

@annotate("string,string->string")
class DateFormat(object):
    def evaluate(self, date_str, format_str):
        if date_str is None or format_str is None:
            return None
        try:
            # 假设输入格式为'yyyy-MM-dd'
            dt = datetime.strptime(date_str, '%Y-%m-%d')
            return dt.strftime(format_str)
        except:
            return None

使用示例:

sql 复制代码
sql
SELECT date_format('2025-03-17', '%Y年%m月%d日') FROM dual;
-- 结果: 2025年03月17日

聚合函数UDAF详解与案例

UDAF需要实现四个关键方法:

  1. new_buffer:创建中间结果缓冲区
  2. iterate:处理每条输入记录
  3. merge:合并中间结果
  4. terminate:生成最终结果

计算平均值案例

python 复制代码
python
from odps.udf import annotate
from odps.udf import BaseUDAF

@annotate('double->double')
class MyAverage(BaseUDAF):
    def new_buffer(self):
        # 返回[sum, count]初始值
        return [0.0, 0]
        
    def iterate(self, buffer, value):
        # 处理每个输入值
        if value is not None:
            buffer[0] += value  # 累加和
            buffer[1] += 1      # 计数
            
    def merge(self, buffer, partial_buffer):
        # 合并两个buffer
        buffer[0] += partial_buffer[0]  # 合并和
        buffer[1] += partial_buffer[1]  # 合并计数
        
    def terminate(self, buffer):
        # 计算最终结果
        if buffer[1] == 0:
            return None
        return buffer[0] / buffer[1]  # 返回平均值

使用示例:

sql 复制代码
sql
SELECT my_average(salary) FROM employees GROUP BY department_id;

带权重的中位数计算

python 复制代码
python
from odps.udf import annotate
from odps.udf import BaseUDAF

@annotate('double,double->double')
class WeightedMedian(BaseUDAF):
    def new_buffer(self):
        return {'values': [], 'weights': []}
        
    def iterate(self, buffer, value, weight):
        if value is not None and weight is not None:
            buffer['values'].append(value)
            buffer['weights'].append(weight)
            
    def merge(self, buffer, partial_buffer):
        buffer['values'].extend(partial_buffer['values'])
        buffer['weights'].extend(partial_buffer['weights'])
        
    def terminate(self, buffer):
        if not buffer['values']:
            return None
            
        # 按值排序
        pairs = sorted(zip(buffer['values'], buffer['weights']))
        values = [p[0] for p in pairs]
        weights = [p[1] for p in pairs]
        
        # 计算总权重
        total_weight = sum(weights)
        if total_weight == 0:
            return None
            
        # 找到中位数位置
        cumulative_weight = 0
        for i, w in enumerate(weights):
            cumulative_weight += w
            if cumulative_weight >= total_weight / 2:
                return values[i]

表值函数UDTF详解与案例

UDTF是唯一能够返回多行数据的UDF类型,需要实现process方法并通过forward输出结果。

字符串分割案例

python 复制代码
python
from odps.udf import annotate
from odps.udf import BaseUDTF

@annotate('string,string->string,bigint')
class SplitAndCount(BaseUDTF):
    def process(self, text, delimiter):
        if text is None or delimiter is None:
            return
            
        parts = text.split(delimiter)
        for i, part in enumerate(parts):
            self.forward(part, i)  # 输出多行结果

使用示例:

sql 复制代码
sql
SELECT t.part, t.position 
FROM my_table LATERAL VIEW split_and_count(text, ',') t AS part, position;

行转列(展开JSON数组)

python 复制代码
python
from odps.udf import annotate
from odps.udf import BaseUDTF
import json

@annotate('string->string,string')
class ExpandJsonArray(BaseUDTF):
    def process(self, json_str):
        if json_str is None:
            return
            
        try:
            data = json.loads(json_str)
            if isinstance(data, list):
                for item in data:
                    if isinstance(item, dict):
                        for key, value in item.items():
                            self.forward(key, str(value))
        except:
            pass

资源引用

在UDF中可以引用文件资源和表资源,大大增强了UDF的功能。

引用文件资源

python 复制代码
python
from odps.udf import annotate
from odps.distcache import get_cache_file

@annotate('string->string')
class TranslateUDF(object):
    def __init__(self):
        # 加载字典文件
        dict_file = get_cache_file('translation_dict.csv')
        self.trans_dict = {}
        
        for line in dict_file:
            line = line.strip()
            if line:
                src, tgt = line.split(',')
                self.trans_dict[src] = tgt
        dict_file.close()
        
    def evaluate(self, word):
        if word is None:
            return None
        return self.trans_dict.get(word, word)

注册和使用方法:

sql 复制代码
sql
-- 上传资源文件
ADD FILE translation_dict.csv;

-- 创建函数
CREATE FUNCTION translate AS 'TranslateUDF' USING 'translate.py', 'translation_dict.csv';

-- 使用函数
SELECT translate(word) FROM words;

引用表资源

python 复制代码
python
from odps.udf import annotate
from odps.distcache import get_cache_table

@annotate('string->double')
class ProductPriceUDF(object):
    def __init__(self):
        # 加载价格表
        self.price_dict = {}
        for product_id, price in get_cache_table('product_prices'):
            self.price_dict[product_id] = float(price)
        
    def evaluate(self, product_id):
        if product_id is None:
            return None
        return self.price_dict.get(product_id, 0.0)

性能优化建议

  1. 避免初始化开销 :将耗时操作放在__init__中进行,而不是evaluate方法中
  2. 批量处理:对于大数据量处理,可以考虑使用UDTF实现批量处理
  3. 内存控制 :UDF处理大数据时可能超出默认内存,可设置set odps.sql.udf.joiner.jvm.memory=xxxx;
  4. 数据类型选择:尽量使用原生数据类型,避免不必要的类型转换
  5. 异常处理:确保UDF能够妥善处理NULL值和异常情况

实际应用案例:地理位置计算

基于GPS数据计算两点间距离:

python 复制代码
from odps.udf import annotate
import math

@annotate('double,double,double,double->double')
class GeoDistance(object):
    EARTH_RADIUS = 6371000  # 地球半径(米)
    
    def evaluate(self, lat1, lng1, lat2, lng2):
        if None in (lat1, lng1, lat2, lng2):
            return None
            
        rad_lat1 = math.radians(lat1)
        rad_lat2 = math.radians(lat2)
        rad_lng1 = math.radians(lng1)
        rad_lng2 = math.radians(lng2)
        
        a = rad_lat1 - rad_lat2
        b = rad_lng1 - rad_lng2
        
        s = 2 * math.asin(math.sqrt(math.pow(math.sin(a/2), 2) + 
                          math.cos(rad_lat1) * math.cos(rad_lat2) * 
                          math.pow(math.sin(b/2), 2)))
        
        s = s * self.EARTH_RADIUS
        return round(s, 2)  # 返回距离(米)

使用方法:

sql 复制代码
SELECT geo_distance(39.9087202, 116.3974799, 39.9846100, 116.3176590) AS distance;
-- 结果: 约10145.79米(北京天安门到北京大学的直线距离)
相关推荐
明月与玄武44 分钟前
Spring Boot中的拦截器!
java·spring boot·后端
菲兹园长1 小时前
SpringBoot统一功能处理
java·spring boot·后端
muxue1781 小时前
go语言封装、继承与多态:
开发语言·后端·golang
开心码农1号2 小时前
Go语言中 源文件开头的 // +build 注释的用法
开发语言·后端·golang
北极象2 小时前
Go主要里程碑版本及其新增特性
开发语言·后端·golang
lyrhhhhhhhh2 小时前
Spring框架(1)
java·后端·spring
喝养乐多长不高3 小时前
Spring Web MVC基础理论和使用
java·前端·后端·spring·mvc·springmvc
莫轻言舞4 小时前
SpringBoot整合PDF导出功能
spring boot·后端·pdf
一切顺势而行4 小时前
kafka 面试总结
分布式·面试·kafka
玄武后端技术栈4 小时前
什么是死信队列?死信队列是如何导致的?
后端·rabbitmq·死信队列