# -*- coding: utf-8 -*-
# 导入 DataFrame 和 Series 类
from pandas import DataFrame, Series
# 导入自定义函数
from ._core import get_offset, verify_series
from ._math import zero
# 定义函数 _above_below,用于比较两个 Series 对象的大小关系
def _above_below(series_a: Series, series_b: Series, above: bool = True, asint: bool = True, offset: int = None, **kwargs):
# 确保 series_a 和 series_b 是 Series 对象
series_a = verify_series(series_a)
series_b = verify_series(series_b)
offset = get_offset(offset)
# 将 series_a 和 series_b 中的零值替换为 NaN
series_a.apply(zero)
series_b.apply(zero)
# 计算结果
if above:
current = series_a >= series_b
else:
current = series_a <= series_b
if asint:
current = current.astype(int)
# 偏移
if offset != 0:
current = current.shift(offset)
# 设置名称和类别
current.name = f"{series_a.name}_{'A' if above else 'B'}_{series_b.name}"
current.category = "utility"
return current
# 定义函数 above,用于比较两个 Series 对象的大小关系,series_a 大于等于 series_b
def above(series_a: Series, series_b: Series, asint: bool = True, offset: int = None, **kwargs):
return _above_below(series_a, series_b, above=True, asint=asint, offset=offset, **kwargs)
# 定义函数 above_value,用于比较 Series 对象和给定值的大小关系,series_a 大于等于 value
def above_value(series_a: Series, value: float, asint: bool = True, offset: int = None, **kwargs):
if not isinstance(value, (int, float, complex)):
print("[X] value is not a number")
return
series_b = Series(value, index=series_a.index, name=f"{value}".replace(".", "_"))
return _above_below(series_a, series_b, above=True, asint=asint, offset=offset, **kwargs)
# 定义函数 below,用于比较两个 Series 对象的大小关系,series_a 小于等于 series_b
def below(series_a: Series, series_b: Series, asint: bool = True, offset: int = None, **kwargs):
return _above_below(series_a, series_b, above=False, asint=asint, offset=offset, **kwargs)
# 定义函数 below_value,用于比较 Series 对象和给定值的大小关系,series_a 小于等于 value
def below_value(series_a: Series, value: float, asint: bool = True, offset: int = None, **kwargs):
if not isinstance(value, (int, float, complex)):
print("[X] value is not a number")
return
series_b = Series(value, index=series_a.index, name=f"{value}".replace(".", "_"))
return _above_below(series_a, series_b, above=False, asint=asint, offset=offset, **kwargs)
# 定义函数 cross_value,用于判断 Series 对象和给定值是否交叉,above 为 True 表示交叉在上方
def cross_value(series_a: Series, value: float, above: bool = True, asint: bool = True, offset: int = None, **kwargs):
series_b = Series(value, index=series_a.index, name=f"{value}".replace(".", "_"))
return cross(series_a, series_b, above, asint, offset, **kwargs)
# 定义函数 cross,用于判断两个 Series 对象是否交叉,above 为 True 表示交叉在上方
def cross(series_a: Series, series_b: Series, above: bool = True, asint: bool = True, offset: int = None, **kwargs):
series_a = verify_series(series_a)
series_b = verify_series(series_b)
offset = get_offset(offset)
series_a.apply(zero)
series_b.apply(zero)
# 计算结果
current = series_a > series_b # current is above
previous = series_a.shift(1) < series_b.shift(1) # previous is below
# above if both are true, below if both are false
cross = current & previous if above else ~current & ~previous
if asint:
cross = cross.astype(int)
# 偏移
if offset != 0:
cross = cross.shift(offset)
# 设置名称和类别
# 设置交叉系列的名称,根据条件选择不同的后缀
cross.name = f"{series_a.name}_{'XA' if above else 'XB'}_{series_b.name}"
# 设置交叉系列的类别为"utility"
cross.category = "utility"
# 返回交叉系列对象
return cross
# 根据给定的指标、阈值和参数,生成包含交叉信号的数据框
def signals(indicator, xa, xb, cross_values, xserie, xserie_a, xserie_b, cross_series, offset) -> DataFrame:
# 创建一个空的数据框
df = DataFrame()
# 如果 xa 不为空且为整数或浮点数类型
if xa is not None and isinstance(xa, (int, float)):
# 如果需要计算交叉值
if cross_values:
# 计算指标在阈值 xa 以上交叉的起始点
crossed_above_start = cross_value(indicator, xa, above=True, offset=offset)
# 计算指标在阈值 xa 以上交叉的结束点
crossed_above_end = cross_value(indicator, xa, above=False, offset=offset)
# 将交叉信号起始点和结束点添加到数据框中
df[crossed_above_start.name] = crossed_above_start
df[crossed_above_end.name] = crossed_above_end
else:
# 计算指标在阈值 xa 以上的信号
crossed_above = above_value(indicator, xa, offset=offset)
# 将信号添加到数据框中
df[crossed_above.name] = crossed_above
# 如果 xb 不为空且为整数或浮点数类型
if xb is not None and isinstance(xb, (int, float)):
# 如果需要计算交叉值
if cross_values:
# 计算指标在阈值 xb 以下交叉的起始点
crossed_below_start = cross_value(indicator, xb, above=True, offset=offset)
# 计算指标在阈值 xb 以下交叉的结束点
crossed_below_end = cross_value(indicator, xb, above=False, offset=offset)
# 将交叉信号起始点和结束点添加到数据框中
df[crossed_below_start.name] = crossed_below_start
df[crossed_below_end.name] = crossed_below_end
else:
# 计算指标在阈值 xb 以下的信号
crossed_below = below_value(indicator, xb, offset=offset)
# 将信号添加到数据框中
df[crossed_below.name] = crossed_below
# 如果 xserie_a 为空,则使用默认值 xserie
if xserie_a is None:
xserie_a = xserie
# 如果 xserie_b 为空,则使用默认值 xserie
if xserie_b is None:
xserie_b = xserie
# 如果 xserie_a 不为空且为有效的数据序列
if xserie_a is not None and verify_series(xserie_a):
# 如果需要计算交叉序列
if cross_series:
# 计算指标与 xserie_a 交叉的起始点
cross_serie_above = cross(indicator, xserie_a, above=True, offset=offset)
else:
# 计算指标在 xserie_a 以上的信号
cross_serie_above = above(indicator, xserie_a, offset=offset)
# 将信号添加到数据框中
df[cross_serie_above.name] = cross_serie_above
# 如果 xserie_b 不为空且为有效的数据序列
if xserie_b is not None and verify_series(xserie_b):
# 如果需要计算交叉序列
if cross_series:
# 计算指标与 xserie_b 交叉的起始点
cross_serie_below = cross(indicator, xserie_b, above=False, offset=offset)
else:
# 计算指标在 xserie_b 以下的信号
cross_serie_below = below(indicator, xserie_b, offset=offset)
# 将信号添加到数据框中
df[cross_serie_below.name] = cross_serie_below
# 返回生成的数据框
return df
.\pandas-ta\pandas_ta\utils\_time.py
py复制代码
# -*- coding: utf-8 -*-
# 从 datetime 模块中导入 datetime 类
from datetime import datetime
# 从 time 模块中导入 localtime 和 perf_counter 函数
from time import localtime, perf_counter
# 从 typing 模块中导入 Tuple 类型
from typing import Tuple
# 从 pandas 模块中导入 DataFrame 和 Timestamp 类
from pandas import DataFrame, Timestamp
# 从 pandas_ta 模块中导入 EXCHANGE_TZ 和 RATE 变量
from pandas_ta import EXCHANGE_TZ, RATE
# 定义函数 df_dates,接受一个 DataFrame 和日期元组作为参数,返回过滤后的 DataFrame
def df_dates(df: DataFrame, dates: Tuple[str, list] = None) -> DataFrame:
"""Yields the DataFrame with the given dates"""
# 若日期元组为空,则返回 None
if dates is None: return None
# 如果日期元组不是列表类型,则将其转换为列表
if not isinstance(dates, list):
dates = [dates]
# 返回过滤后的 DataFrame,只包含日期元组中指定的日期
return df[df.index.isin(dates)]
# 定义函数 df_month_to_date,接受一个 DataFrame 作为参数,返回当月的 DataFrame
def df_month_to_date(df: DataFrame) -> DataFrame:
"""Yields the Month-to-Date (MTD) DataFrame"""
# 获取当前日期的月初日期,并判断 DataFrame 的索引是否大于等于该日期
in_mtd = df.index >= Timestamp.now().strftime("%Y-%m-01")
# 如果有数据在当月,则返回当月的 DataFrame
if any(in_mtd): return df[in_mtd]
# 否则返回原始 DataFrame
return df
# 定义函数 df_quarter_to_date,接受一个 DataFrame 作为参数,返回当季的 DataFrame
def df_quarter_to_date(df: DataFrame) -> DataFrame:
"""Yields the Quarter-to-Date (QTD) DataFrame"""
# 获取当前日期,并遍历季度开始的月份
now = Timestamp.now()
for m in [1, 4, 7, 10]:
# 如果当前月份小于等于遍历到的月份
if now.month <= m:
# 获取季度开始日期,并判断 DataFrame 的索引是否大于等于该日期
in_qtr = df.index >= datetime(now.year, m, 1).strftime("%Y-%m-01")
# 如果有数据在当季,则返回当季的 DataFrame
if any(in_qtr): return df[in_qtr]
# 否则返回从当前月份开始的 DataFrame
return df[df.index >= now.strftime("%Y-%m-01")]
# 定义函数 df_year_to_date,接受一个 DataFrame 作为参数,返回当年的 DataFrame
def df_year_to_date(df: DataFrame) -> DataFrame:
"""Yields the Year-to-Date (YTD) DataFrame"""
# 获取当前日期的年初日期,并判断 DataFrame 的索引是否大于等于该日期
in_ytd = df.index >= Timestamp.now().strftime("%Y-01-01")
# 如果有数据在当年,则返回当年的 DataFrame
if any(in_ytd): return df[in_ytd]
# 否则返回原始 DataFrame
return df
# 定义函数 final_time,接受一个起始时间参数,返回从起始时间到当前时间的耗时
def final_time(stime: float) -> str:
"""Human readable elapsed time. Calculates the final time elasped since
stime and returns a string with microseconds and seconds."""
# 计算当前时间与起始时间的差值
time_diff = perf_counter() - stime
# 返回耗时的字符串,包含毫秒和秒
return f"{time_diff * 1000:2.4f} ms ({time_diff:2.4f} s)"
# 定义函数 get_time,接受交易所名称、是否显示全信息和是否返回字符串作为参数,返回当前时间及交易所时间信息
def get_time(exchange: str = "NYSE", full:bool = True, to_string:bool = False) -> Tuple[None, str]:
"""Returns Current Time, Day of the Year and Percentage, and the current
time of the selected Exchange."""
# 默认交易所时区为东部时间(NYSE)
tz = EXCHANGE_TZ["NYSE"]
# 如果传入的交易所名称为字符串类型
if isinstance(exchange, str):
# 将交易所名称转换为大写
exchange = exchange.upper()
# 获取对应交易所的时区信息
tz = EXCHANGE_TZ[exchange]
# 获取当前时间
today = Timestamp.now()
# 格式化日期字符串
date = f"{today.day_name()} {today.month_name()} {today.day}, {today.year}"
# 获取当前时间在交易所时区的时间
_today = today.timetuple()
exchange_time = f"{(_today.tm_hour + tz) % 24}:{_today.tm_min:02d}:{_today.tm_sec:02d}"
# 如果需要显示全信息
if full:
# 获取本地时间信息
lt = localtime()
local_ = f"Local: {lt.tm_hour}:{lt.tm_min:02d}:{lt.tm_sec:02d} {lt.tm_zone}"
# 计算当天在一年中的日期和百分比
doy = f"Day {today.dayofyear}/365 ({100 * round(today.dayofyear/365, 2):.2f}%)"
exchange_ = f"{exchange}: {exchange_time}"
# 构建包含完整信息的字符串
s = f"{date}, {exchange_}, {local_}, {doy}"
else:
# 构建简略信息的字符串
s = f"{date}, {exchange}: {exchange_time}"
# 如果需要返回字符串,则返回构建的字符串,否则打印字符串并返回 None
return s if to_string else print(s)
# 定义函数 total_time,接受一个 DataFrame 和时间间隔类型参数作为输入,返回 DataFrame 的总时间间隔
def total_time(df: DataFrame, tf: str = "years") -> float:
"""Calculates the total time of a DataFrame. Difference of the Last and
First index. Options: 'months', 'weeks', 'days', 'hours', 'minutes'
and 'seconds'. Default: 'years'.
Useful for annualization."""
# 计算 DataFrame 的总时间间
TimeFrame = {
"years": time_diff.days / RATE["TRADING_DAYS_PER_YEAR"], # 计算时间差对应的年数
"months": time_diff.days / 30.417, # 计算时间差对应的月数
"weeks": time_diff.days / 7, # 计算时间差对应的周数
"days": time_diff.days, # 计算时间差对应的天数
"hours": time_diff.days * 24, # 计算时间差对应的小时数
"minutes": time_diff.total_seconds() / 60, # 计算时间差对应的分钟数
"seconds": time_diff.total_seconds() # 计算时间差对应的秒数
}
if isinstance(tf, str) and tf in TimeFrame.keys(): # 检查 tf 是否为字符串且在 TimeFrame 字典的键中
return TimeFrame[tf] # 返回对应 tf 的时间差
return TimeFrame["years"] # 如果 tf 不在 TimeFrame 字典的键中,则返回默认的年数时间差
# 将 DataFrame 的索引转换为 UTC 时区,或者使用 tz_convert 将索引设置为 UTC 时区
def to_utc(df: DataFrame) -> DataFrame:
# 检查 DataFrame 是否为空
if not df.empty:
try:
# 尝试将索引本地化为 UTC 时区
df.index = df.index.tz_localize("UTC")
except TypeError:
# 如果出现 TypeError,则使用 tz_convert 将索引转换为 UTC 时区
df.index = df.index.tz_convert("UTC")
# 返回处理后的 DataFrame
return df
# 别名
mtd = df_month_to_date
qtd = df_quarter_to_date
ytd = df_year_to_date
.\pandas-ta\pandas_ta\utils\__init__.py
py复制代码
# 设置文件编码为UTF-8,确保可以正确处理中文等特殊字符
# 导入模块中的所有内容,这些模块包括 _candles、_core、_math、_signals、_time、_metrics 和 data
from ._candles import *
from ._core import *
from ._math import *
from ._signals import *
from ._time import *
from ._metrics import *
from .data import *
.\pandas-ta\pandas_ta\volatility\aberration.py
py复制代码
# -*- coding: utf-8 -*-
# from numpy import sqrt as npsqrt # 导入 numpy 中的 sqrt 函数并重命名为 npsqrt(已注释掉)
from pandas import DataFrame # 从 pandas 库中导入 DataFrame 类
from .atr import atr # 从当前包中的 atr 模块中导入 atr 函数
from pandas_ta.overlap import hlc3, sma # 从 pandas_ta.overlap 模块中导入 hlc3 和 sma 函数
from pandas_ta.utils import get_offset, verify_series # 从 pandas_ta.utils 模块中导入 get_offset 和 verify_series 函数
def aberration(high, low, close, length=None, atr_length=None, offset=None, **kwargs):
"""Indicator: Aberration (ABER)"""
# Validate arguments
# 确认参数合法性,若参数未指定或非正整数,则使用默认值
length = int(length) if length and length > 0 else 5
atr_length = int(atr_length) if atr_length and atr_length > 0 else 15
_length = max(atr_length, length) # 选择最大长度作为计算时使用的长度
high = verify_series(high, _length) # 确认 high Series 的合法性和长度
low = verify_series(low, _length) # 确认 low Series 的合法性和长度
close = verify_series(close, _length) # 确认 close Series 的合法性和长度
offset = get_offset(offset) # 获取偏移量
if high is None or low is None or close is None: return # 如果输入数据有缺失,则返回空值
# Calculate Result
# 计算结果
atr_ = atr(high=high, low=low, close=close, length=atr_length) # 计算 ATR 指标
jg = hlc3(high=high, low=low, close=close) # 计算 JG(typical price,即三价均价)
zg = sma(jg, length) # 计算 ZG(SMA of JG)
sg = zg + atr_ # 计算 SG(ZG + ATR)
xg = zg - atr_ # 计算 XG(ZG - ATR)
# Offset
# 偏移结果
if offset != 0:
zg = zg.shift(offset) # 对 ZG 进行偏移
sg = sg.shift(offset) # 对 SG 进行偏移
xg = xg.shift(offset) # 对 XG 进行偏移
atr_ = atr_.shift(offset) # 对 ATR 进行偏移
# Handle fills
# 处理填充缺失值
if "fillna" in kwargs:
zg.fillna(kwargs["fillna"], inplace=True) # 使用指定值填充 ZG 中的缺失值
sg.fillna(kwargs["fillna"], inplace=True) # 使用指定值填充 SG 中的缺失值
xg.fillna(kwargs["fillna"], inplace=True) # 使用指定值填充 XG 中的缺失值
atr_.fillna(kwargs["fillna"], inplace=True) # 使用指定值填充 ATR 中的缺失值
if "fill_method" in kwargs:
zg.fillna(method=kwargs["fill_method"], inplace=True) # 使用指定的填充方法填充 ZG 中的缺失值
sg.fillna(method=kwargs["fill_method"], inplace=True) # 使用指定的填充方法填充 SG 中的缺失值
xg.fillna(method=kwargs["fill_method"], inplace=True) # 使用指定的填充方法填充 XG 中的缺失值
atr_.fillna(method=kwargs["fill_method"], inplace=True) # 使用指定的填充方法填充 ATR 中的缺失值
# Name and Categorize it
# 命名和分类
_props = f"_{length}_{atr_length}" # 用于生成属性名称的后缀
zg.name = f"ABER_ZG{_props}" # 设置 ZG Series 的名称
sg.name = f"ABER_SG{_props}" # 设置 SG Series 的名称
xg.name = f"ABER_XG{_props}" # 设置 XG Series 的名称
atr_.name = f"ABER_ATR{_props}" # 设置 ATR Series 的名称
zg.category = sg.category = "volatility" # 设置 ZG 和 SG Series 的分类为波动性
xg.category = atr_.category = zg.category # 设置 XG 和 ATR Series 的分类与 ZG 相同
# Prepare DataFrame to return
# 准备要返回的 DataFrame
data = {zg.name: zg, sg.name: sg, xg.name: xg, atr_.name: atr_} # 构建数据字典
aberdf = DataFrame(data) # 使用数据字典创建 DataFrame
aberdf.name = f"ABER{_props}" # 设置 DataFrame 的名称
aberdf.category = zg.category # 设置 DataFrame 的分类与 ZG 相同
return aberdf # 返回计算结果的 DataFrame
aberration.__doc__ = \
"""Aberration
A volatility indicator similar to Keltner Channels.
Sources:
Few internet resources on definitive definition.
Request by Github user homily, issue #46
Calculation:
Default Inputs:
length=5, atr_length=15
ATR = Average True Range
SMA = Simple Moving Average
ATR = ATR(length=atr_length)
JG = TP = HLC3(high, low, close)
ZG = SMA(JG, length)
SG = ZG + ATR
XG = ZG - ATR
Args:
high (pd.Series): Series of 'high's
low (pd.Series): Series of 'low's
close (pd.Series): Series of 'close's
length (int): The short period. Default: 5
atr_length (int): The short period. Default: 15
offset (int): How many periods to offset the result. Default: 0
Kwargs:
fillna (value, optional): pd.DataFrame.fillna(value)
""" # 设置 aberration 函数的文档字符串
fill_method (value, optional): 填充方法的类型
# 返回一个 pandas DataFrame,包含 zg、sg、xg、atr 列
.\pandas-ta\pandas_ta\volatility\accbands.py
py复制代码
# -*- coding: utf-8 -*-
# 从 pandas 库中导入 DataFrame 类
from pandas import DataFrame
# 从 pandas_ta 库中的 overlap 模块中导入 ma 函数
from pandas_ta.overlap import ma
# 从 pandas_ta 库中的 utils 模块中导入 get_drift, get_offset, non_zero_range, verify_series 函数
from pandas_ta.utils import get_drift, get_offset, non_zero_range, verify_series
# 定义函数 accbands,用于计算加速带指标
def accbands(high, low, close, length=None, c=None, drift=None, mamode=None, offset=None, **kwargs):
"""Indicator: Acceleration Bands (ACCBANDS)"""
# 验证参数
# 若 length 存在且大于 0,则将其转换为整数类型,否则设为 20
length = int(length) if length and length > 0 else 20
# 若 c 存在且大于 0,则将其转换为浮点数类型,否则设为 4
c = float(c) if c and c > 0 else 4
# 若 mamode 不为字符串类型,则设为 "sma"
mamode = mamode if isinstance(mamode, str) else "sma"
# 验证 high、low、close 系列,使其长度为 length
high = verify_series(high, length)
low = verify_series(low, length)
close = verify_series(close, length)
# 获取 drift 和 offset
drift = get_drift(drift)
offset = get_offset(offset)
# 若 high、low、close 存在空值,则返回空值
if high is None or low is None or close is None: return
# 计算结果
# 计算 high 和 low 的非零范围
high_low_range = non_zero_range(high, low)
# 计算 high_low_range 与 (high + low) 的比值
hl_ratio = high_low_range / (high + low)
# 将 hl_ratio 乘以 c
hl_ratio *= c
# 计算下轨线 _lower
_lower = low * (1 - hl_ratio)
# 计算上轨线 _upper
_upper = high * (1 + hl_ratio)
# 计算移动平均值
lower = ma(mamode, _lower, length=length)
mid = ma(mamode, close, length=length)
upper = ma(mamode, _upper, length=length)
# 对结果进行位移
if offset != 0:
lower = lower.shift(offset)
mid = mid.shift(offset)
upper = upper.shift(offset)
# 处理填充
if "fillna" in kwargs:
lower.fillna(kwargs["fillna"], inplace=True)
mid.fillna(kwargs["fillna"], inplace=True)
upper.fillna(kwargs["fillna"], inplace=True)
if "fill_method" in kwargs:
lower.fillna(method=kwargs["fill_method"], inplace=True)
mid.fillna(method=kwargs["fill_method"], inplace=True)
upper.fillna(method=kwargs["fill_method"], inplace=True)
# 命名和分类
lower.name = f"ACCBL_{length}"
mid.name = f"ACCBM_{length}"
upper.name = f"ACCBU_{length}"
mid.category = upper.category = lower.category = "volatility"
# 准备返回的 DataFrame
data = {lower.name: lower, mid.name: mid, upper.name: upper}
accbandsdf = DataFrame(data)
accbandsdf.name = f"ACCBANDS_{length}"
accbandsdf.category = mid.category
return accbandsdf
# 设置函数文档字符串
accbands.__doc__ = \
"""Acceleration Bands (ACCBANDS)
Acceleration Bands created by Price Headley plots upper and lower envelope
bands around a simple moving average.
Sources:
https://www.tradingtechnologies.com/help/x-study/technical-indicator-definitions/acceleration-bands-abands/
Calculation:
Default Inputs:
length=10, c=4
EMA = Exponential Moving Average
SMA = Simple Moving Average
HL_RATIO = c * (high - low) / (high + low)
LOW = low * (1 - HL_RATIO)
HIGH = high * (1 + HL_RATIO)
if 'ema':
LOWER = EMA(LOW, length)
MID = EMA(close, length)
UPPER = EMA(HIGH, length)
else:
LOWER = SMA(LOW, length)
MID = SMA(close, length)
UPPER = SMA(HIGH, length)
Args:
high (pd.Series): Series of 'high's
low (pd.Series): Series of 'low's
close (pd.Series): Series of 'close's
"""
# 表示参数 `length` 是一个整数,代表周期。默认值为 10
length (int): It's period. Default: 10
# 表示参数 `c` 是一个整数,代表乘数。默认值为 4
c (int): Multiplier. Default: 4
# 表示参数 `mamode` 是一个字符串,参见 `ta.ma` 的帮助文档。默认值为 'sma'
mamode (str): See ```help(ta.ma)```py. Default: 'sma'
# 表示参数 `drift` 是一个整数,代表差异周期。默认值为 1
drift (int): The difference period. Default: 1
# 表示参数 `offset` 是一个整数,代表结果的偏移周期数。默认值为 0
offset (int): How many periods to offset the result. Default: 0
# 函数参数,用于指定填充缺失值的值,可选参数
fillna (value, optional): pd.DataFrame.fillna(value)
# 函数参数,指定填充方法的类型,可选参数
fill_method (value, optional): Type of fill method
# 返回值,返回一个 DataFrame,包含 lower、mid、upper 列
pd.DataFrame: lower, mid, upper columns.
.\pandas-ta\pandas_ta\volatility\atr.py
py复制代码
# -*- coding: utf-8 -*-
# 导入 true_range 模块
from .true_range import true_range
# 导入 Imports 模块
from pandas_ta import Imports
# 导入 ma 模块
from pandas_ta.overlap import ma
# 导入 get_drift, get_offset, verify_series 模块
from pandas_ta.utils import get_drift, get_offset, verify_series
# 定义 ATR 函数,计算平均真实范围
def atr(high, low, close, length=None, mamode=None, talib=None, drift=None, offset=None, **kwargs):
"""Indicator: Average True Range (ATR)"""
# 验证参数
length = int(length) if length and length > 0 else 14
mamode = mamode.lower() if mamode and isinstance(mamode, str) else "rma"
high = verify_series(high, length)
low = verify_series(low, length)
close = verify_series(close, length)
drift = get_drift(drift)
offset = get_offset(offset)
mode_tal = bool(talib) if isinstance(talib, bool) else True
if high is None or low is None or close is None: return
# 计算结果
if Imports["talib"] and mode_tal:
from talib import ATR
atr = ATR(high, low, close, length)
else:
tr = true_range(high=high, low=low, close=close, drift=drift)
atr = ma(mamode, tr, length=length)
percentage = kwargs.pop("percent", False)
if percentage:
atr *= 100 / close
# 偏移
if offset != 0:
atr = atr.shift(offset)
# 处理填充
if "fillna" in kwargs:
atr.fillna(kwargs["fillna"], inplace=True)
if "fill_method" in kwargs:
atr.fillna(method=kwargs["fill_method"], inplace=True)
# 命名和分类
atr.name = f"ATR{mamode[0]}_{length}{'p' if percentage else ''}"
atr.category = "volatility"
return atr
# 设置 ATR 函数的文档字符串
atr.__doc__ = \
"""Average True Range (ATR)
Averge True Range is used to measure volatility, especially volatility caused by
gaps or limit moves.
Sources:
https://www.tradingview.com/wiki/Average_True_Range_(ATR)
Calculation:
Default Inputs:
length=14, drift=1, percent=False
EMA = Exponential Moving Average
SMA = Simple Moving Average
WMA = Weighted Moving Average
RMA = WildeR's Moving Average
TR = True Range
tr = TR(high, low, close, drift)
if 'ema':
ATR = EMA(tr, length)
elif 'sma':
ATR = SMA(tr, length)
elif 'wma':
ATR = WMA(tr, length)
else:
ATR = RMA(tr, length)
if percent:
ATR *= 100 / close
Args:
high (pd.Series): Series of 'high's
low (pd.Series): Series of 'low's
close (pd.Series): Series of 'close's
length (int): It's period. Default: 14
mamode (str): See ```help(ta.ma)```py. Default: 'rma'
talib (bool): If TA Lib is installed and talib is True, Returns the TA Lib
version. Default: True
drift (int): The difference period. Default: 1
offset (int): How many periods to offset the result. Default: 0
Kwargs:
percent (bool, optional): Return as percentage. Default: False
fillna (value, optional): pd.DataFrame.fillna(value)
fill_method (value, optional): Type of fill method
Returns:
pd.Series: New feature generated.
"""