Faster serialization than JSON (pickle, parquet, feather, ...) in a plotly Dash Store?

Problem description

Context

In a Plotly Dash dashboard, I need to perform an expensive download from the DB only when a particular component updates (a DatePickerRange whose selected period determines what to download from the DB), and then share the resulting DataFrame with other components (e.g. dropdowns that filter the DataFrame), avoiding the expensive download each time.

The docs suggest using dash_core_components.Store as the output of the callback that returns the DataFrame serialized to JSON, and as the input of the other callbacks, which deserialize it from JSON back into a DataFrame.

Serialization to/from JSON is slow: every time I update a component, it takes about 30 seconds for the plot to update.

I tried faster serialization formats such as pickle, parquet and feather, but in the deserialization step I get an error saying the object is empty (no such error occurs with JSON).

Question

Is it possible to use a faster serialization method than JSON (e.g. pickle, feather or parquet, which take roughly half the time on my dataset) with a Dash Store? If so, how?

Code



import io
import traceback
import pandas as pd
from datetime import datetime, date, timedelta
import dash
import dash_core_components as dcc
import dash_html_components as html
import dash_bootstrap_components as dbc
from dash.dependencies import Input, Output
from plotly.subplots import make_subplots

app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
today = date.today()

app.layout = html.Div([
    dbc.Row(dbc.Col(html.H1('PMC'))),
    dbc.Row(dbc.Col(html.H5('analysis'))),
    html.Hr(),
    html.Br(),
    dbc.Container([
        dbc.Row([
            dbc.Col(
                dcc.DatePickerRange(
                    id='date_ranges',
                    start_date=today - timedelta(days=20),
                    end_date=today,
                    max_date_allowed=today,
                    display_format='MMM Do, YY',
                ),
                width=4
            ),
        ]),
        dbc.Row(
            dbc.Col(
                dcc.Dropdown(
                    id='dd_ycolnames',
                    options=options,  # options/default_options defined elsewhere
                    value=default_options,
                    multi=True,
                ),
            ),
        ),
        dbc.Row([
            dbc.Col(
                dcc.Graph(id='graph_subplots', figure={}),
                width=12
            ),
            dcc.Store(id='store'),
        ]),
    ]),
])


@app.callback(
    Output('store', 'data'),
    [Input(component_id='date_ranges', component_property='start_date'),
     Input(component_id='date_ranges', component_property='end_date')]
)
def load_dataset(date_ranges_start, date_ranges_end):
    # some expensive clean data step
    logger.info('loading dataset...')
    date_ranges1_start = datetime.strptime(date_ranges_start, '%Y-%m-%d')
    date_ranges1_end = datetime.strptime(date_ranges_end, '%Y-%m-%d')
    df = expensive_load_from_db(date_ranges1_start, date_ranges1_end)
    logger.info('dataset to json...')
    # return df.to_json(date_format='iso', orient='split')
    return df.to_parquet()  # <----------------------


@app.callback(
    Output(component_id='graph_subplots', component_property='figure'),
    [Input(component_id='store', component_property='data'),
     Input(component_id='dd_ycolnames', component_property='value')],
)
def update_plot(df_bin, y_colnames):
    logger.info('dataset from json')
    # df = pd.read_json(df_bin, orient='split')
    df = pd.read_parquet(io.BytesIO(df_bin))  # <----------------------
    logger.info('building plot...')
    traces = []
    for y_colname in y_colnames:
        if df[y_colname].dtype == 'bool':
            df[y_colname] = df[y_colname].astype('int')
        traces.append(
            {'x': df['date'], 'y': df[y_colname].values, 'name': y_colname},
        )
    fig = make_subplots(rows=len(y_colnames), cols=1, shared_xaxes=True, vertical_spacing=0.1)
    fig.layout.height = 1000
    for i, trace in enumerate(traces):
        fig.append_trace(trace, i + 1, 1)
    logger.info('plotted')
    return fig


if __name__ == '__main__':
    app.run_server(host='localhost', debug=True)

Solution

Due to the data exchange between client and server, you are currently limited to JSON serialization. One way to circumvent this limitation is via the ServersideOutput component from dash-extensions, which stores the data on the server. It defaults to file storage and pickle serialization, but you can use other storage backends (e.g. Redis) and/or serialization protocols (e.g. Arrow) as well. Here is a small example,

import time
import dash_core_components as dcc
import dash_html_components as html
import plotly.express as px
from dash_extensions.enrich import Dash, Output, Input, State, ServersideOutput

app = Dash(prevent_initial_callbacks=True)
app.layout = html.Div([
    html.Button("Query data", id="btn"),
    dcc.Dropdown(id="dd"),
    dcc.Graph(id="graph"),
    dcc.Loading(dcc.Store(id="store"), fullscreen=True, type="dot"),
])


@app.callback(ServersideOutput("store", "data"), Input("btn", "n_clicks"))
def query_data(n_clicks):
    time.sleep(1)
    return px.data.gapminder()  # no JSON serialization here


@app.callback(Output("dd", "options"), Input("store", "data"))
def update_dd(df):
    return [{"label": year, "value": year} for year in df["year"].unique()]  # no JSON de-serialization here


@app.callback(Output("graph", "figure"), [Input("dd", "value"), State("store", "data")])
def update_graph(value, df):
    df = df.query("year == {}".format(value))  # no JSON de-serialization here
    return px.sunburst(df, path=['continent', 'country'], values='pop', color='lifeExp', hover_data=['iso_alpha'])


if __name__ == '__main__':
    app.run_server()