目录
描述:
1 本例结合实际应用场景描述散点图的使用。在财报分析中,需要将数值放在同行业中进行比较,从而判断是否异常。
2 散点图显示部分可以当通用工具使用。
3 数据计算使用compile动态执行代码,返回固定格式的数据进行显示,尽最大可能实现工具的灵活性。
效果:

代码:
返回结果对象
@dataclass
class MultiScatterObj:
    """Result object returned by a user script's ``execute_caculate()``.

    On success ``status`` is ``'ok'`` and the data fields are filled in;
    on failure ``status`` is ``'error'`` and ``error_msg`` holds the reason.
    """
    # Chart title; the GUI also uses it as the downloaded CSV file name.
    title: str = ''
    # Calculated data. The GUI expects 'ticker', 'reportDate' and the
    # indicator columns named in col_dict. A factory is used so that every
    # instance gets its own frame instead of sharing one default object.
    df: pd.DataFrame = field(default_factory=pd.DataFrame)
    # {column name in df: indicator display name}
    col_dict: dict = field(default_factory=dict)
    # {classification standard: {industry name: [ticker, ...]}}
    type_dict: dict = field(default_factory=dict)
    # Ticker to highlight in the scatter plot ('' for none).
    focus_target: str = ''
    # Error description (e.g. a traceback) when status == 'error'.
    error_msg: str = ''
    # 'ok' or 'error'.
    status: str = 'ok'

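1 title 标题
2 df 计算结果数据,需包含 ticker、reportDate 两列以及 col_dict 中列出的指标列
3 type_dict 行业分类标准,结构为 {分类标准: {行业名称: [股票代码, ...]}}
4 col_dict 计算的指标,结构为 {df中的列名: 指标显示名称}
5 focus_target 要关注的股票
6 status、error_msg 执行状态('ok' 或 'error')及出错信息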
字符型横坐标
class StrAxisItem(pg.AxisItem):
    """Axis item that shows string labels instead of numeric tick values.

    ``ticks`` is a sequence of ``(value, label)`` pairs. A tick whose scaled
    value equals one of the ``value`` entries is drawn with the matching
    label; any other tick gets an empty string.
    """

    def __init__(self, ticks, *args, **kwargs):
        pg.AxisItem.__init__(self, *args, **kwargs)
        self.x_values = [pair[0] for pair in ticks]
        self.x_strings = [pair[1] for pair in ticks]

    def tickStrings(self, values, scale, spacing):
        """Map each tick value to its label (first match), or '' if none."""
        labels = []
        for raw in values:
            scaled = raw * scale
            try:
                pos = self.x_values.index(scaled)
            except ValueError:
                labels.append('')
            else:
                labels.append(self.x_strings[pos])
        return labels
通用散点图工具
class ScatterGraphWidget(pg.PlotWidget):
    """Generic scatter-plot widget whose x axis is labelled with tickers.

    Data is supplied through :meth:`set_data` as a DataFrame with columns
    ``x`` (integers 0..n-1 that double as the row position), ``ticker``
    (the x-axis label) and ``target`` (the y value). The mean and median of
    ``target`` are drawn as horizontal lines. An optional focus ticker is
    marked with a triangle, and the currently selected point with a star.
    """

    def __init__(self):
        super().__init__()
        self.init_data()

    def init_data(self):
        """Initialise state and colours, and load the ticker -> name mapping."""
        self.whole_df = None
        self.cur_targetItem = None
        self.focus_targetItem = None
        self.color_point = (255, 255, 0)
        self.color_star = (220, 20, 60)
        self.color_focus = (255, 140, 0)
        self.color_mean = (255, 0, 255)
        self.color_median = (34, 139, 34)
        with open(DATA_DIR + 'ticker_name.json', 'r', encoding='utf-8') as f:
            self.ticker_name_dict = json.load(f)

    def _has_data(self):
        """Return True when a non-empty DataFrame is currently plotted."""
        return self.whole_df is not None and not self.whole_df.empty

    def _reset_items(self):
        """Remove every plot item and forget the marker references."""
        self.clear()
        # clear() already removed the markers from the plot; drop the stale
        # references so they are not removed (or reused) later.
        self.cur_targetItem = None
        self.focus_targetItem = None

    def _make_target(self, index_x, symbol, color):
        """Create a fixed TargetItem on the point at row position ``index_x``."""
        return pg.TargetItem(
            pos=(index_x, self.y[index_x]),
            movable=False,
            size=20,
            symbol=symbol,
            pen=color,
            brush=color,
        )

    def _mark_current(self, index_x):
        """Move the star marker (current selection) to point ``index_x``."""
        if self.cur_targetItem is not None:
            self.removeItem(self.cur_targetItem)
        self.cur_targetItem = self._make_target(index_x, 'star', self.color_star)
        self.addItem(self.cur_targetItem)

    def _add_reference_line(self, value, text, color):
        """Add a horizontal line at ``value`` labelled ``'<value> <text>'``."""
        line = pg.InfiniteLine(
            pos=(0, value),
            movable=False,
            angle=0,
            pen=pg.mkPen({'color': color, 'width': 2}),
            label=f'{value:,} {text}',
            labelOpts={
                'position': 0.05,
                'color': (255, 255, 255),
                'movable': True,
                'fill': (color[0], color[1], color[2], 100),
            },
        )
        self.addItem(line)

    def set_data(self, df: pd.DataFrame, focus_ticker: str = None):
        """Replace the plotted data.

        :param df: DataFrame with columns ``x`` (0..n-1, equal to the row
            position), ``ticker`` and ``target``. An empty frame only clears
            the plot.
        :param focus_ticker: optional ticker to highlight with a triangle.
        """
        self._reset_items()
        if df.empty:
            # Forget the old frame so wheelEvent/set_targetItem do not act on
            # data that is no longer shown.
            self.whole_df = None
            return
        self.whole_df = df
        self.xTicks = self.whole_df.loc[:, ['x', 'ticker']].values
        self.x = self.whole_df['x'].to_list()
        self.y = self.whole_df['target'].to_list()
        horAxis = StrAxisItem(ticks=self.xTicks, orientation='bottom')
        self.setAxisItems({'bottom': horAxis})

        scatters = pg.ScatterPlotItem(
            hoverable=True,
            hoverPen=pg.mkPen('g'),
            tip=None
        )
        scatters.addPoints([
            {
                'pos': (x0, y0),
                'size': 10,
                'pen': {'color': self.color_point, 'width': 2},
                'brush': pg.mkBrush(color=self.color_point),
            }
            for x0, y0 in zip(self.x, self.y)
        ])
        self.addItem(scatters)
        # Text item the hover handler fills with "<name> <value>".
        self.label = pg.TextItem()
        self.addItem(self.label, ignoreBounds=True)

        if focus_ticker:
            focus_rows = df.loc[df['ticker'] == focus_ticker]
            if not focus_rows.empty:
                index_x = focus_rows.iloc[0]['x']
                self.focus_targetItem = self._make_target(index_x, 't1', self.color_focus)
                self.addItem(self.focus_targetItem)

        # Mean and median reference lines.
        self._add_reference_line(self.whole_df['target'].mean(), '平均数', self.color_mean)
        self._add_reference_line(self.whole_df['target'].median(), '中位数', self.color_median)

        scatters.sigClicked.connect(self.scatters_sigClicked)
        scatters.sigHovered.connect(self.scatters_sigHovered)
        self.enableAutoRange()

    def set_content_empty(self):
        """Remove everything from the plot and forget the current data."""
        self._reset_items()
        self.whole_df = None

    def set_targetItem(self, ticker: str):
        """Star-mark ``ticker``; do nothing if it is not currently plotted."""
        if not self._has_data():
            return
        rows = self.whole_df.loc[self.whole_df['ticker'] == ticker]
        if rows.empty:
            return
        self._mark_current(rows.iloc[0]['x'])

    def scatters_sigClicked(self, plot, points):
        """Star-mark the clicked point.

        NOTE(review): the original comment says the clicked ticker should
        also be sent to a "left-side chart"; that is not implemented here.
        """
        if len(points) <= 0:
            return
        self._mark_current(points[0].index())

    def scatters_sigHovered(self, plot, points):
        """Show the hovered point's company name and value next to it."""
        if len(points) <= 0:
            return
        index_x = points[0].index()
        x_str = self.xTicks[index_x][1]
        y_val = self.y[index_x]
        x_str00 = self.ticker_name_dict.get(x_str, x_str)
        html_str = ('<p style="color:white;font-size:18px;font-weight:bold;">'
                    + x_str00 + ' ' + f'{y_val:,}' + '</p>')
        self.label.setHtml(html_str)
        self.label.setPos(points[0].pos())

    def wheelEvent(self, ev):
        """Zoom the x range around the cursor and fit y to the visible points.

        Falls back to pyqtgraph's default handling when nothing is plotted.
        """
        if not self._has_data():
            super().wheelEvent(ev)
            return
        count = len(self.whole_df)
        delta = ev.angleDelta().x()
        if delta == 0:
            delta = ev.angleDelta().y()
        s = 1.001 ** delta
        before_xmin, before_xmax = self.viewRange()[0]
        val_x = self.getViewBox().mapSceneToView(ev.position()).x()
        after_xmin = int(val_x - (val_x - before_xmin) // s)
        after_xmax = int(val_x + (before_xmax - val_x) // s)
        # Clamp to the data: x runs from 0 to count - 1.
        if after_xmin < 1:
            after_xmin = 0
        if after_xmin >= count:
            # Keep (up to) the last three points visible.
            after_xmin = max(count - 3, 0)
        if after_xmax < 1:
            after_xmax = min(count - 1, 1)
        if after_xmax >= count:
            after_xmax = count - 1
        visible = self.whole_df.loc[
            (self.whole_df['x'] >= after_xmin) & (self.whole_df['x'] <= after_xmax),
            'target'].dropna()
        self.setXRange(after_xmin, after_xmax)
        # setYRange rejects NaN bounds, so only fit y when values exist.
        if not visible.empty:
            self.setYRange(visible.min(), visible.max())
1)本例中的散点图增加了一些与实际业务相关的元素(平均数线、中位数线、关注股票标记)
2)set_data方法需要传入df和focus_ticker(可为空)
2.1)df必须包含x、ticker两个字段:x为从0开始连续递增的整数(同时用作数据行的位置下标),ticker为横坐标上显示的字符
2.2)df中的target字段为y轴数值
工具主界面
class PyExcuteGraphShowWidgetII(QWidget):
    """Main window: run a user-chosen .py file and show its result as scatter plots.

    The chosen file must define ``execute_caculate()`` returning a
    ``MultiScatterObj``. The user then picks a classification standard, a
    report date and an indicator. The indicator values of the peer tickers
    for that date are drawn by ``ScatterGraphWidget``, and the peer tickers
    are listed in a table on the right.
    """

    def __init__(self):
        super().__init__()
        self.setWindowTitle('py文件执行并显示结果(散点图)')
        self.setMinimumSize(QSize(1000, 800))

        # File selection / execute / download row.
        label00 = QLabel('选择py文件:')
        self.lineedit_file = QLineEdit()
        btn_choice = QPushButton('选择文件', clicked=self.btn_choice_clicked)
        self.btn_excute = QPushButton('执行', clicked=self.btn_excute_clicked)
        btn_download = QPushButton('下载数据', clicked=self.btn_download_clicked)

        self.label_title = QLabel('指标', alignment=Qt.AlignmentFlag.AlignHCenter)
        self.label_title.setStyleSheet("font-size:28px;color:#CC2EFA;")

        # Classification standard selector.
        label01 = QLabel('行业标准')
        self.combo_type = QComboBox()
        self.combo_type.currentIndexChanged.connect(self.combo_type_currentIndexChanged)
        self.label_industry = QLabel('所属行业')
        self.label_industry.setStyleSheet("font-size:20px;color:#CC2EFA;")

        # Report date selector.
        label02 = QLabel('日期')
        self.combo_date = QComboBox()
        self.combo_date.currentIndexChanged.connect(self.combo_date_currentIndexChanged)
        self.label_date = QLabel('日期')
        self.label_date.setStyleSheet("font-size:20px;color:#CC2EFA;")

        # Indicator radio buttons (created in reset_content).
        self.radioButtonGroup = QButtonGroup()
        self.radioButtonGroup.buttonClicked.connect(self.radioButtonGroup_buttonClicked)
        self.layout_radio = QHBoxLayout()
        groupBox = QGroupBox('指标')
        groupBox.setLayout(self.layout_radio)

        self.pw = ScatterGraphWidget()

        # Peer ticker table.
        self.label_num = QLabel('几个')
        self.table = QTableWidget()
        self.table.setColumnCount(2)
        self.table.setHorizontalHeaderLabels(['代码', '名'])
        self.table.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers)
        self.table.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection)
        self.table.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows)
        self.table.itemClicked.connect(self.table_itemClicked)

        layout00 = QHBoxLayout()
        layout00.addWidget(label00)
        layout00.addWidget(self.lineedit_file)
        layout00.addWidget(btn_choice)
        layout00.addWidget(self.btn_excute)
        layout00.addWidget(btn_download)
        layout01 = QVBoxLayout()
        layout01.addWidget(label01)
        layout01.addWidget(self.combo_type)
        layout01.addWidget(self.label_industry)
        layout02 = QVBoxLayout()
        layout02.addWidget(label02)
        layout02.addWidget(self.combo_date)
        layout02.addWidget(self.label_date)
        layout03 = QHBoxLayout()
        layout03.addLayout(layout01)
        layout03.addLayout(layout02)
        layout04 = QVBoxLayout()
        layout04.addLayout(layout03)
        layout04.addWidget(groupBox)
        layout04.addWidget(self.pw)
        layout05 = QVBoxLayout()
        layout05.addWidget(self.label_num)
        layout05.addWidget(self.table)
        layout06 = QHBoxLayout()
        layout06.addLayout(layout04, 5)
        layout06.addLayout(layout05, 1)
        layout = QVBoxLayout()
        layout.addLayout(layout00)
        layout.addWidget(self.label_title)
        layout.addLayout(layout06)
        self.setLayout(layout)
        self.open_init()

    def open_init(self):
        """Initialise state and load the JSON lookup files from DATA_DIR."""
        self.whole_resObj: MultiScatterObj = None
        self.whole_current_df: pd.DataFrame = pd.DataFrame()
        self.whole_current_show_df: pd.DataFrame = pd.DataFrame()
        with open(DATA_DIR + 'industry_type.json', 'r', encoding='utf-8') as f:
            self.industry_type = json.load(f)
        with open(DATA_DIR + 'ticker_name.json', 'r', encoding='utf-8') as f:
            self.ticker_name_dict = json.load(f)
        self.radio_list = []
        self.whole_current_ticker_list = []

    def btn_choice_clicked(self):
        """Let the user pick the .py file to execute."""
        file_path, _ = QFileDialog.getOpenFileName(self, '选择文件')
        if file_path:
            self.lineedit_file.setText(file_path)

    def btn_excute_clicked(self):
        """Run ``execute_caculate()`` from the chosen file and display the result.

        SECURITY: the file is compiled and executed with the full privileges
        of this process; only run files you trust.
        """
        import traceback

        file_path = self.lineedit_file.text()
        if len(file_path) <= 0:
            QMessageBox.information(self, '提示', '请选择要执行的py文件', QMessageBox.StandardButton.Ok)
            return
        try:
            with open(file_path, 'r', encoding='utf-8') as fr:
                py_code = fr.read()
            namespace = {}
            # Compile with the real path so tracebacks point into the file.
            fun_code = compile(py_code, file_path, 'exec')
            exec(fun_code, namespace)
            res = namespace['execute_caculate']()
        except Exception:
            # Read/compile/run errors (or a missing execute_caculate) must not
            # escape the Qt slot; report them to the user instead.
            QMessageBox.information(self, '执行过程报错', traceback.format_exc(), QMessageBox.StandardButton.Ok)
            return
        if res.status == 'error':
            QMessageBox.information(self, '执行过程报错', res.error_msg, QMessageBox.StandardButton.Ok)
            return
        self.label_title.setText(res.title)
        self.whole_resObj = res
        self.reset_content()
        QMessageBox.information(self, '提示', '执行完毕', QMessageBox.StandardButton.Ok)

    def btn_download_clicked(self):
        """Save the result DataFrame as ``<title>.csv`` in a chosen directory.

        Indicator columns are renamed to their display names from col_dict.
        """
        if self.whole_resObj is None or self.whole_resObj.status == 'error':
            QMessageBox.information(self, '提示', '数据为空', QMessageBox.StandardButton.Ok)
            return
        dir_name = QFileDialog.getExistingDirectory(self, '选择保存位置')
        if not dir_name:
            return
        df = self.whole_resObj.df.rename(columns=self.whole_resObj.col_dict)
        file_path = os.path.join(dir_name, self.whole_resObj.title + '.csv')
        try:
            df.to_csv(file_path, encoding='utf-8', index=False)
        except OSError as e:
            # e.g. no permission, or the title contains path-illegal characters.
            QMessageBox.information(self, '提示', f'保存失败: {e}', QMessageBox.StandardButton.Ok)
            return
        QMessageBox.information(self, '提示', '下载完毕', QMessageBox.StandardButton.Ok)

    def combo_type_currentIndexChanged(self, cur_i: int):
        """Load the peer tickers and report dates of the selected standard."""
        cur_type = self.combo_type.currentText()
        # clear() emits this signal with index -1 and empty text; ignore it
        # (and any standard that has no industries).
        if (cur_i < 0 or self.whole_resObj is None
                or not self.whole_resObj.type_dict.get(cur_type)):
            return
        industries = self.whole_resObj.type_dict[cur_type]
        self.label_industry.setText(';'.join(industries.keys()))
        # Only the first industry of this standard is used as the peer group.
        key0 = next(iter(industries))
        ticker_list = industries[key0]
        df = self.whole_resObj.df
        self.whole_current_df = df.loc[df['ticker'].isin(ticker_list)].copy()
        self.label_num.setText(f'共{len(ticker_list)}个')
        self.table.setRowCount(len(ticker_list))
        for i, item in enumerate(ticker_list):
            item_name = self.ticker_name_dict.get(item, item)
            self.table.setItem(i, 0, QTableWidgetItem(str(item)))
            self.table.setItem(i, 1, QTableWidgetItem(str(item_name)))
        self.table.resizeColumnsToContents()
        self.whole_current_ticker_list = ticker_list
        date_list = sorted(set(self.whole_current_df['reportDate'].to_list()), reverse=True)
        # Replace, not append to, the previous standard's dates. clear() emits
        # index -1 (ignored by the date handler); addItems() then selects the
        # newest date, which redraws the chart.
        self.combo_date.clear()
        if not date_list:
            self.whole_current_show_df = pd.DataFrame()
            self.pw.set_content_empty()
            return
        self.combo_date.addItems(date_list)

    def combo_date_currentIndexChanged(self, cur_i: int):
        """Show the selected date and redraw with the first indicator selected."""
        cur_date = self.combo_date.currentText()
        self.label_date.setText(cur_date)
        if cur_i < 0:
            return
        # Indicator buttons are registered with ids starting at 2 (reset_content).
        a0 = self.radioButtonGroup.button(2)
        if a0 is None:
            return
        a0.setChecked(True)
        self.radioButtonGroup_buttonClicked(a0)

    def radioButtonGroup_buttonClicked(self, a0):
        """Plot the indicator whose display name is the clicked button's text."""
        indicator_name = a0.text()
        # col_dict maps column name -> display name; find the column name.
        target_name = next(
            (k for k, v in self.whole_resObj.col_dict.items() if v == indicator_name), '')
        if not target_name:
            return
        df = self.whole_resObj.df
        df = df.loc[
            df['ticker'].isin(self.whole_current_ticker_list)
            & (df['reportDate'] == self.combo_date.currentText())
        ].copy()
        df.rename(columns={target_name: 'target'}, inplace=True)
        # ScatterGraphWidget uses x both as the axis value and the row position.
        df['x'] = range(len(df))
        self.whole_current_show_df = df.copy()
        self.pw.set_data(df.copy(), self.whole_resObj.focus_target)

    def table_itemClicked(self, cur_item):
        """Star-mark the clicked ticker in the chart, if it has data for the date."""
        cur_row = cur_item.row()
        ticker = self.table.item(cur_row, 0).text()
        show_df = self.whole_current_show_df
        # The initial frame has no columns at all, so check emptiness first.
        if show_df.empty or show_df.loc[show_df['ticker'] == ticker].empty:
            QMessageBox.information(self, '提示', f'{ticker},在该日期没有数据', QMessageBox.StandardButton.Ok)
            return
        self.pw.set_targetItem(ticker)

    def reset_content(self):
        """Rebuild the indicator radio buttons and reset all views for a new result."""
        for radio in self.radio_list:
            self.layout_radio.removeWidget(radio)
            self.radioButtonGroup.removeButton(radio)
            # removeWidget() only detaches the widget from the layout; it
            # would otherwise stay visible inside the group box.
            radio.deleteLater()
        self.radio_list.clear()
        for button_id, name in enumerate(self.whole_resObj.col_dict.values(), start=2):
            radio = QRadioButton(name)
            self.layout_radio.addWidget(radio)
            self.radioButtonGroup.addButton(radio, button_id)
            self.radio_list.append(radio)
        self.pw.set_content_empty()
        self.whole_current_show_df = pd.DataFrame()
        self.combo_type.clear()
        self.combo_date.clear()
        self.table.clearContents()
        self.label_num.setText('--')
        self.combo_type.addItems(list(self.whole_resObj.type_dict.keys()))
        self.label_title.setText(self.whole_resObj.title)
使用举例
需要导入的包和运行代码
import os,sys,json
import pandas as pd
import numpy as np
from PyQt6.QtCore import (
QSize,
Qt
)
from PyQt6.QtWidgets import (
QApplication, QButtonGroup, QRadioButton,
QMainWindow, QAbstractItemView,
QLabel,
QPushButton,
QComboBox,
QTableWidget,
QTableWidgetItem,
QTextEdit,
QWidget,
QVBoxLayout,
QHBoxLayout, QGridLayout,
QFileDialog,
QInputDialog,
QMessageBox,
QLineEdit,
QGroupBox, QScrollArea, QCompleter
)
import pyqtgraph as pg
from objects import MultiScatterObj
from settings import DATA_DIR
if __name__ == '__main__':
    app = QApplication(sys.argv)
    mw = PyExcuteGraphShowWidgetII()
    mw.show()
    # Propagate the Qt event loop's exit code to the calling process.
    sys.exit(app.exec())
1)一个py文件例子,内容如下,方法名固定为 execute_caculate(工具执行时按此名称调用)
def execute_caculate():
    """Example script entry point for the scatter tool.

    Flexible py-file execution: operating profit, non-operating expenses and
    non-operating income of every company that shares an industry with
    ``ticker``, for annual reports (report dates ending in -12-31).

    Returns a ``MultiScatterObj``. On any failure (including a failed
    database connection) it returns one with ``status='error'`` and the
    traceback in ``error_msg``.
    """
    import traceback, json, os
    import pandas as pd
    from utils import postgresql_utils
    from objects import MultiScatterObj
    from settings import DATA_DIR, INDUSTRY_DIR

    conn = None
    cur = None
    try:
        # Connect inside the try so connection errors are reported through
        # the returned object rather than escaping to the caller.
        conn = postgresql_utils.connect_db()
        cur = conn.cursor()
        ticker = '000638'
        type_dict = {}
        ticker_list = []
        for file_one in os.listdir(INDUSTRY_DIR):
            # The first six characters of the file name are used as the
            # classification-standard code.
            type_code = file_one[0:6]
            file_path = os.path.join(INDUSTRY_DIR, file_one)
            with open(file_path, 'r', encoding='utf-8') as f:
                j_obj = json.load(f)
            # Keep every industry of this standard that contains ticker.
            for k, v in j_obj.items():
                if ticker in v:
                    type_dict.setdefault(type_code, {})[k] = v
                    ticker_list.extend(v)
        ticker_list = list(set(ticker_list))
        # NOTE(review): the SQL is built by string interpolation. That is only
        # acceptable while tickers come from trusted local files; consider a
        # parameterized query.
        ticker_list_str = '\'' + '\',\''.join(ticker_list) + '\''
        sql_str = f'''
select ticker,reportDate,iii_operateProfit,add_nonoperateIncome,less_nonoperateExpenses from t_profit where ticker in ({ticker_list_str}) and reportDate like \'%-12-31\';
'''
        cur.execute(sql_str)
        res = cur.fetchall()
        col_list = ['ticker', 'reportDate', 'a0', 'a1', 'a2']
        col_dict = {
            'a0': '营业利润',
            'a1': '营业外收入',
            'a2': '营业外支出'
        }
        df = pd.DataFrame(columns=col_list, data=res)
        return MultiScatterObj(
            title=f'{ticker},营业利润、营业外收入、营业外支出',
            df=df,
            col_dict=col_dict,
            type_dict=type_dict,
            focus_target=ticker,
            status='ok'
        )
    except Exception:
        return MultiScatterObj(
            status='error',
            error_msg=traceback.format_exc()
        )
    finally:
        # Close only what was actually opened.
        if cur is not None:
            cur.close()
        if conn is not None:
            conn.close()
注意:例子中涉及到的PostgreSQL数据库和财报数据可以在往期博文中找到
2)点击"选择文件",选择 test002.py文件
3)点击"执行",选择行业、日期、指标,就能显示散点图
4)右侧股票列表,单击某个股票,就会在散点图中用红星标注