MMMPlotSuite v2 Migration Guide

PyMC-Marketing is moving from the monolithic MMMPlotSuite to a namespace-based plotting API backed by arviz-plots. The new API uses PlotCollection for subplot layout, dimension-aware faceting, and multi-backend rendering. The legacy suite will be removed in pymc-marketing 2.0.0.

This guide covers how to opt in to the new API and what has changed.

Opting In

Set plot_suite = "new" on an existing MMM or TimeSliceCrossValidator instance to switch it to the new API. Until you do, accessing mmm.plot emits a FutureWarning.

from pymc_marketing.mmm import MMM, TimeSliceCrossValidator

mmm = MMM(...)
mmm.plot_suite = "new"

cv = TimeSliceCrossValidator(
    n_init=100,
    forecast_horizon=12,
    date_column="date",
)
cv.plot_suite = "new"

Common Arguments

The plotting methods share a common set of parameters. They are described once here, and later sections refer to them by name. Not every method accepts every parameter: waterfall(), for example, takes neither backend nor return_as_pc.

idata (az.InferenceData | None, default None): Overrides the model's fitted data for a single call. When None, the data stored on the model instance is used.

hdi_prob (float, default 0.94): Probability mass of the credible (HDI) interval. Replaces hdi_probs: list[float]; only one level per call.

dims (dict[str, Any] | None, default None): Subsets coordinates, e.g. {"channel": ["tv", "radio"]}. Size-1 dims are preserved as facets rather than squeezed out.

figsize (tuple[float, float] | None, default None): Shorthand for figure_kwargs={"figsize": ...} passed to PlotCollection.

backend (str | None, default None, meaning matplotlib): Rendering backend: "matplotlib", "plotly", or "bokeh".

return_as_pc (bool, default False): Returns a PlotCollection instead of (Figure, NDArray[Axes]). Required when backend is not "matplotlib".

**pc_kwargs: Forwarded to PlotCollection.wrap(). Controls col_wrap, layout, and aesthetic mappings.

*_kwargs (dict | None, default None): Per-visual-element kwargs such as line_kwargs, hdi_kwargs, or scatter_kwargs, forwarded to the underlying arviz-plots visual.

fig, axes = mmm.plot.diagnostics.posterior_predictive(
    hdi_prob=0.89,
    dims={"channel": ["tv"]},
    figsize=(10, 4),
    line_kwargs={"color": "blue"},
    hdi_kwargs={"alpha": 0.3},
)
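Because hdi_prob accepts a single level, a legacy call such as hdi_probs=[0.5, 0.94] becomes one call per level. The following is a minimal sketch of that conversion, assuming you hold the legacy keyword arguments in a dict; split_hdi_probs is a hypothetical migration helper, not part of pymc-marketing.

```python
from typing import Any


def split_hdi_probs(legacy_kwargs: dict[str, Any]) -> list[dict[str, Any]]:
    """Split one legacy kwargs dict into one kwargs dict per HDI level.

    Hypothetical migration helper: the legacy API took hdi_probs: list[float],
    the new API takes a single hdi_prob: float per call.
    """
    kwargs = dict(legacy_kwargs)  # shallow copy; the caller's dict is not modified
    probs = kwargs.pop("hdi_probs", None)
    if probs is None:
        # Nothing to split: return the other arguments unchanged.
        return [kwargs]
    if isinstance(probs, (int, float)):
        probs = [probs]  # tolerate a bare number as well as a list
    return [{**kwargs, "hdi_prob": float(p)} for p in probs]
```

Each returned dict can then be unpacked into its own call, e.g. mmm.plot.diagnostics.posterior_predictive(**kwargs), giving one figure per level.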

Non-Matplotlib Backends (not fully supported)

Any method that accepts backend can render with Plotly or Bokeh instead of Matplotlib. When using a non-matplotlib backend, you must pass return_as_pc=True, which returns a PlotCollection instead of a (Figure, axes) tuple.

Note: Non-matplotlib backend support has not been fully tested and is likely to contain issues. Use at your own risk.

waterfall() in the decomposition namespace does not accept backend or return_as_pc — it always returns a matplotlib (Figure, axes) tuple.

# Plotly backend — requires return_as_pc=True
pc = mmm.plot.diagnostics.posterior_predictive(
    backend="plotly",
    return_as_pc=True,
)
pc.show()
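If you build plot calls programmatically, it can help to check the backend/return_as_pc combination before calling. The sketch below encodes only the rules stated in this section (the three backend names, and return_as_pc=True for anything other than matplotlib); check_backend_kwargs is a hypothetical helper, and the real methods may validate differently.

```python
from __future__ import annotations

SUPPORTED_BACKENDS = ("matplotlib", "plotly", "bokeh")


def check_backend_kwargs(backend: str | None, return_as_pc: bool) -> str:
    """Validate a backend/return_as_pc pair before calling a plot method.

    Hypothetical pre-flight check encoding the rules in this guide:
    backend=None means matplotlib, and any other backend needs
    return_as_pc=True. Returns the resolved backend name.
    """
    resolved = "matplotlib" if backend is None else backend
    if resolved not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"unknown backend {backend!r}; expected one of {SUPPORTED_BACKENDS}"
        )
    if resolved != "matplotlib" and not return_as_pc:
        raise ValueError(f"backend={resolved!r} requires return_as_pc=True")
    return resolved
```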

mmm.plot.diagnostics

Method names

Legacy → New
mmm.plot.posterior_predictive(...) → mmm.plot.diagnostics.posterior_predictive(...)
mmm.plot.prior_predictive(...) → mmm.plot.diagnostics.prior_predictive(...)
mmm.plot.residuals_over_time(...) → mmm.plot.diagnostics.residuals_over_time(...)
mmm.plot.residuals_posterior_distribution(...) → mmm.plot.diagnostics.residuals_distribution(...)
mmm.plot.posterior_distribution(...) → mmm.plot.diagnostics.posterior(...)
mmm.plot.prior_vs_posterior(...) → mmm.plot.diagnostics.prior_vs_posterior(...)

Argument changes

Old → New (notes)
hdi_probs: list[float] → hdi_prob: float (single level per call)
var: list[str] on posterior_distribution → var_names: list[str] | str | None (also accepts a single string)
(none) → group: str = "posterior" (new on posterior(); pass "prior" to plot the prior instead)
(none) → kind: str = "kde" (new on posterior() and prior_vs_posterior(); controls the plot type)
(none) → aggregation (new on residuals_distribution(); dimension to aggregate over)
(none) → quantiles (new on residuals_distribution(); quantile lines to overlay)

fig, axes = mmm.plot.diagnostics.posterior_predictive()
fig, axes = mmm.plot.diagnostics.posterior(var_names=["alpha", "beta"])
fig, axes = mmm.plot.diagnostics.residuals_distribution(quantiles=[0.025, 0.5, 0.975])
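For posterior_distribution() callers, the keyword changes in the table above can be applied mechanically. A minimal sketch, assuming the legacy keyword arguments are available as a dict; translate_posterior_kwargs is hypothetical, handles only the var and hdi_probs rows, and raises rather than silently dropping extra HDI levels.

```python
from typing import Any

# Legacy keyword -> new keyword for
# mmm.plot.posterior_distribution(...) -> mmm.plot.diagnostics.posterior(...)
RENAMED_KWARGS = {"var": "var_names"}


def translate_posterior_kwargs(legacy_kwargs: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical helper: rename legacy keywords for diagnostics.posterior()."""
    new_kwargs: dict[str, Any] = {}
    for key, value in legacy_kwargs.items():
        if key == "hdi_probs":
            # The new API plots one HDI level per call; refuse to drop levels silently.
            if len(value) != 1:
                raise ValueError("hdi_prob takes a single level; make one call per level")
            new_kwargs["hdi_prob"] = float(value[0])
        else:
            # Renamed keywords are mapped; everything else passes through unchanged.
            new_kwargs[RENAMED_KWARGS.get(key, key)] = value
    return new_kwargs
```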

mmm.plot.decomposition

Method names

Legacy → New
mmm.plot.contributions_over_time(...) → mmm.plot.decomposition.contributions_over_time(...)
mmm.plot.waterfall_components_decomposition(...) → mmm.plot.decomposition.waterfall(...)
mmm.plot.channel_parameter(...) → mmm.plot.decomposition.channel_share_hdi(...)

Argument changes

Old → New (notes)
hdi_probs: list[float] → hdi_prob: float (single level per call)
original_scale: bool = False → original_scale: bool = True (default flipped to True)
(none) → include: list[Literal["channels", "baseline", "controls", "seasonality"]] | None (new on contributions_over_time(); filters which components appear)

waterfall() does not accept backend or return_as_pc — it always returns (Figure, NDArray[Axes]).

fig, axes = mmm.plot.decomposition.contributions_over_time(
    include=["channels", "baseline"]
)
fig, axes = mmm.plot.decomposition.waterfall()
fig, axes = mmm.plot.decomposition.channel_share_hdi(hdi_prob=0.89)
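Because original_scale now defaults to True, migrated calls that relied on the old default render on a different scale. One way to keep legacy output during the transition, sketched here with a hypothetical with_legacy_scale wrapper, is to pass original_scale=False explicitly to methods that take it, unless the caller already chose a value.

```python
from typing import Any


def with_legacy_scale(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical wrapper: default original_scale to False, as in the legacy suite.

    Keys in kwargs are applied last, so an explicit original_scale
    from the caller always wins.
    """
    return {"original_scale": False, **kwargs}
```

For example, mmm.plot.decomposition.contributions_over_time(**with_legacy_scale({"include": ["channels"]})) keeps the legacy scale while still using the new method.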

mmm.plot.sensitivity

Method names

Legacy → New
mmm.plot.sensitivity_analysis(...) → mmm.plot.sensitivity.analysis(...)
mmm.plot.uplift_curve(...) → mmm.plot.sensitivity.uplift(...)
mmm.plot.marginal_curve(...) → mmm.plot.sensitivity.marginal(...)

Argument changes

Old → New (notes)
hdi_probs: list[float] → hdi_prob: float (single level per call)
(none) → x_sweep_axis: Literal["relative", "absolute"] = "relative" (new; controls the x-axis scale)
(none) → apply_cost_per_unit: bool = True (new; scales the spend axis by cost per unit)
(none) → aggregation: dict[str, str | list[str]] | None (new; aggregates over dimensions before plotting)

fig, axes = mmm.plot.sensitivity.analysis()
fig, axes = mmm.plot.sensitivity.analysis(x_sweep_axis="absolute", hdi_prob=0.89)

mmm.plot.transformation

Method names

Legacy → New
mmm.plot.saturation_scatterplot(...) → mmm.plot.transformation.saturation_scatterplot(...)
mmm.plot.saturation_curves(...) → mmm.plot.transformation.saturation_curves(...)
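The four Method names tables in this guide (diagnostics, decomposition, sensitivity, transformation) are regular enough to apply textually. The sketch below transcribes them into a dict and rewrites .plot.<legacy_name>( call sites with a regular expression. migrate_plot_calls is a hypothetical helper: it renames methods only, leaves arguments such as hdi_probs or var untouched, and its output should be reviewed as a diff.

```python
import re

# Legacy method name -> new namespaced path, transcribed from the
# Method names tables in this guide.
LEGACY_TO_NEW = {
    "posterior_predictive": "diagnostics.posterior_predictive",
    "prior_predictive": "diagnostics.prior_predictive",
    "residuals_over_time": "diagnostics.residuals_over_time",
    "residuals_posterior_distribution": "diagnostics.residuals_distribution",
    "posterior_distribution": "diagnostics.posterior",
    "prior_vs_posterior": "diagnostics.prior_vs_posterior",
    "contributions_over_time": "decomposition.contributions_over_time",
    "waterfall_components_decomposition": "decomposition.waterfall",
    "channel_parameter": "decomposition.channel_share_hdi",
    "sensitivity_analysis": "sensitivity.analysis",
    "uplift_curve": "sensitivity.uplift",
    "marginal_curve": "sensitivity.marginal",
    "saturation_scatterplot": "transformation.saturation_scatterplot",
    "saturation_curves": "transformation.saturation_curves",
}

# ".plot.<name>(" where <name> is a single identifier. Already-migrated
# calls such as ".plot.diagnostics.posterior(" do not match, because a
# "." rather than "(" follows the first identifier after ".plot.".
_LEGACY_CALL = re.compile(r"\.plot\.(\w+)\(")


def migrate_plot_calls(source: str) -> str:
    """Hypothetical helper: rewrite legacy mmm.plot method calls in source text."""

    def _rename(match):
        new_path = LEGACY_TO_NEW.get(match.group(1))
        # Unknown names (e.g. cv.plot.predictions) are left as they are.
        return match.group(0) if new_path is None else f".plot.{new_path}("

    return _LEGACY_CALL.sub(_rename, source)
```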

Argument changes

Old → New (notes)
original_scale: bool = False → original_scale: bool = True (default flipped to True)
hdi_probs: list[float] → hdi_prob: float | None = 0.94 (single level; pass None to suppress the HDI band)
(none) → apply_cost_per_unit: bool = True (new on both methods)
(none) → n_samples: int = 10 (new on saturation_curves(); number of posterior draws to overlay)
(none) → random_seed (new on saturation_curves(); for reproducible sample selection)
(none) → mean_curve_kwargs (new on saturation_curves(); styles the mean curve separately)
(none) → sample_curves_kwargs (new on saturation_curves(); styles the individual sample curves)
(none) → curves: xr.DataArray (now a required positional argument of saturation_curves())

fig, axes = mmm.plot.transformation.saturation_scatterplot()
fig, axes = mmm.plot.transformation.saturation_curves(
    curves=saturation_curve_data,  # an xr.DataArray of sampled saturation curves, prepared beforehand
    n_samples=20,
)
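To make n_samples and random_seed concrete, here is a NumPy sketch of the general technique: pool posterior draws across chains, compute the mean curve over all of them, and pick n_samples draws without replacement with a seeded generator. The (chain, draw, x) layout and the pick_sample_curves name are assumptions for illustration; this is not pymc-marketing's implementation.

```python
import numpy as np


def pick_sample_curves(curves, n_samples=10, random_seed=None):
    """Return (mean_curve, sample_curves) from an array shaped (chain, draw, x).

    Hypothetical helper. Draws are pooled across chains, then n_samples of
    them are chosen without replacement; the same random_seed gives the
    same draws.
    """
    curves = np.asarray(curves, dtype=float)
    n_chain, n_draw, n_x = curves.shape
    pooled = curves.reshape(n_chain * n_draw, n_x)  # one row per posterior draw
    mean_curve = pooled.mean(axis=0)  # averaged over all draws, not only the sampled ones
    rng = np.random.default_rng(random_seed)
    n = min(n_samples, pooled.shape[0])  # cannot sample more draws than exist
    idx = rng.choice(pooled.shape[0], size=n, replace=False)
    return mean_curve, pooled[idx]
```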

Budget Plots

Budget plots have moved from mmm.plot to optimizer.plot. BudgetPlots is stateless: all data is passed per call through samples.

Argument changes

Old → New (notes)
hdi_probs: list[float] → hdi_prob: float (single level per call)
(none) → samples: xr.Dataset (required; the output of allocate_budget())

from pymc_marketing.mmm import MMM, BudgetOptimizerWrapper

mmm = MMM(...)
mmm.plot_suite = "new"
optimizer = BudgetOptimizerWrapper(
    model=mmm, start_date="2024-01-01", end_date="2024-12-31"
)
samples = optimizer.allocate_budget(...)

fig, axes = optimizer.plot.allocation_roas(samples=samples)
fig, axes = optimizer.plot.contribution_over_time(samples=samples)

Cross-Validation Plots

CV plots use MMMCVPlotSuite when cv.plot_suite = "new".

Argument changes

Old → New (notes)
hdi_probs: list[float] → hdi_prob: float (single level per call)
(none) → var_names: list[str] | None (new on param_stability(); filters which parameters appear)

cv = TimeSliceCrossValidator(n_init=100, forecast_horizon=12, date_column="date")
cv.plot_suite = "new"
cv_idata = cv.run(X, y, mmm=mmm)

fig, axes = cv.plot.predictions(cv_idata)
fig, axes = cv.plot.param_stability(cv_idata, var_names=["alpha"])
fig, axes = cv.plot.crps(cv_idata)
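cv.plot.crps() plots the continuous ranked probability score (CRPS), presumably of each fold's out-of-sample predictions. As a reference for reading that plot, here is the standard sample-based estimator, CRPS = E|X - y| - 0.5 E|X - X'|. It is not necessarily how pymc-marketing computes the score, and empirical_crps is a hypothetical name.

```python
import numpy as np


def empirical_crps(samples, observed: float) -> float:
    """Sample-based CRPS for one observation: E|X - y| - 0.5 * E|X - X'|.

    X and X' are independent draws from the predictive samples; lower is
    better, and 0 means every sample equals the observation.
    """
    x = np.asarray(samples, dtype=float).ravel()
    spread_to_obs = np.mean(np.abs(x - observed))
    # Mean over all ordered pairs (including i == j): the plug-in estimator.
    spread_within = np.mean(np.abs(x[:, None] - x[None, :]))
    return float(spread_to_obs - 0.5 * spread_within)
```

The first term rewards predictions close to the observation; the second credits spread among the samples, so a well-calibrated but uncertain forecast is not penalized as heavily as a confidently wrong one.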

Removal Timeline

The legacy MMMPlotSuite (the default when mmm.plot_suite is not set) will be removed in pymc-marketing 2.0.0. Until then, accessing mmm.plot without opting in emits a FutureWarning.

To silence the warning, set mmm.plot_suite = "new".
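To find every remaining legacy access before 2.0.0, you can escalate the FutureWarning to an error instead of silencing it. The sketch below uses only the standard warnings module; emits_future_warning is a hypothetical helper, and in practice you would pass it a function that touches mmm.plot.

```python
import warnings


def emits_future_warning(fn, *args, **kwargs) -> bool:
    """Run fn with FutureWarning escalated to an error.

    Returns True if fn triggered a FutureWarning (for example, legacy
    mmm.plot access before opting in) and False otherwise.
    """
    with warnings.catch_warnings():
        # Scoped to this block: the global warning filters are restored on exit.
        warnings.simplefilter("error", category=FutureWarning)
        try:
            fn(*args, **kwargs)
        except FutureWarning:
            return True
    return False
```

With pytest, running the suite with -W error::FutureWarning applies the same escalation globally.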