BassModel

class pymc_marketing.bass.model.BassModel(model_config=None, sampler_config=None)

Bass diffusion model for product adoption forecasting.

Wraps the functional create_bass_model() inside the ModelBuilder interface, providing standardised .fit(), .save(), .load() and related methods. The underlying pm.Model is accessible via model.model for direct use with PyMC functions.
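For orientation, the textbook Bass model drives cumulative adoption F(t) with the hazard p + q·F(t). Here p is the coefficient of innovation, q the coefficient of imitation and m the market potential. The NumPy sketch below evaluates the closed-form curves and splits the adoption rate into innovators and imitators. It assumes the textbook continuous-time parameterisation starting at t = 0, which may not match the exact form used inside create_bass_model().

```python
import numpy as np


def bass_curves(t, m, p, q):
    """Evaluate textbook Bass diffusion curves at times ``t``.

    Assumption: continuous-time Bass model with F(0) = 0; the library's
    internal parameterisation may differ.
    Returns the cumulative share F(t), the adoption rate m * f(t), and
    its split into innovators and imitators.
    """
    t = np.asarray(t, dtype=float)
    decay = np.exp(-(p + q) * t)
    # Closed-form share of the market that has adopted by time t.
    F = (1.0 - decay) / (1.0 + (q / p) * decay)
    # Adoption driven by external influence (innovation) ...
    innovators = m * p * (1.0 - F)
    # ... and by word of mouth from earlier adopters (imitation).
    imitators = m * q * F * (1.0 - F)
    # Total adoption rate: m * (p + q * F) * (1 - F).
    adopters = innovators + imitators
    return F, adopters, innovators, imitators


t = np.arange(50)
F, adopters, innovators, imitators = bass_curves(t, m=5_000, p=0.03, q=0.38)
```

The innovator/imitator split is what the adopters, innovators and imitators deterministics refer to conceptually.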

Parameters:
model_config : dict, optional

Dictionary with keys "m", "p", "q" and "likelihood", each mapping to a Prior (or an equivalent dict). See default_model_config for the defaults.

sampler_config : dict, optional

Dictionary of sampler settings (draws, tune, chains, …). See default_sampler_config for the defaults.

Notes

Data format

When passing an xr.Dataset, the T coordinate is required and represents the time index. An optional observed data variable holds the adoption counts; omit it to build the model for prior-predictive sampling only.

Single-product: 1-D observed with T as the only dimension:

xr.Dataset(
    {"observed": ("T", counts)},
    coords={"T": np.arange(N)},
)

Multi-product: 2-D observed with T and product dimensions:

xr.Dataset(
    {"observed": (("T", "product"), counts)},
    coords={"T": np.arange(N), "product": ["A", "B", "C"]},
)

Other input types (np.ndarray, pd.Series, pd.DataFrame) are auto-converted via to_bass_dataset().
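If the counts already sit in a wide pandas DataFrame (one row per period, one column per product), the multi-product layout can also be built by hand. This is a sketch that assumes the DataFrame index is the time index; the DataFrame itself is hypothetical, and the code does not reproduce whatever extra handling to_bass_dataset() performs.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical wide table: rows are time periods, columns are products.
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.poisson(100, size=(50, 3)), columns=["A", "B", "C"])

# Map the table onto the (T, product) layout described above.
# Assumption: the DataFrame index is the time index.
data = xr.Dataset(
    {"observed": (("T", "product"), df.to_numpy())},
    coords={"T": df.index.to_numpy(), "product": df.columns.to_list()},
)

# Single-product layout: select one product and drop the scalar coordinate,
# leaving a 1-D "observed" variable over T.
single = data.sel(product="A", drop=True)
```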

Examples

Fit a single-product model

import numpy as np
import arviz as az
from pymc_marketing.bass import BassModel

y = np.random.poisson(lam=100, size=50)
model = BassModel()
idata = model.fit(data=y)
print(az.summary(idata, var_names=["m", "p", "q"]))

Multi-product with custom priors

import xarray as xr
from pymc_extras.prior import Prior

data = xr.Dataset(
    {"observed": (("T", "product"), np.random.poisson(100, size=(50, 3)))},
    coords={"T": np.arange(50), "product": ["A", "B", "C"]},
)
model = BassModel(
    model_config={
        "m": Prior("Normal", mu=5_000, sigma=1_000),
        "p": Prior("Beta", alpha=1.5, beta=20),
        "q": Prior("Beta", alpha=2, beta=5),
        "likelihood": Prior("Poisson"),
    },
)
idata = model.fit(data=data)
print(az.summary(idata, var_names=["m", "p", "q"]))
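Before fitting, it can help to check what custom priors imply. The sketch below draws from the same distributions with NumPy rather than through Prior, so it only approximates the model's prior; for example, it ignores any per-product dims. It reports the prior means of p and q and the share of draws with q > p, which under the textbook Bass model is the condition for adoption to peak after launch. It also reports how often the Normal prior on m goes negative.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 20_000

# Same distributions as the custom model_config above, sampled with NumPy
# (an approximation of the model's prior, not PyMC's own prior sampling).
m = rng.normal(5_000, 1_000, size=n)
p = rng.beta(1.5, 20, size=n)
q = rng.beta(2, 5, size=n)

print(f"prior mean p: {p.mean():.3f}  (analytic {1.5 / 21.5:.3f})")
print(f"prior mean q: {q.mean():.3f}  (analytic {2 / 7:.3f})")
# Share of draws where imitation dominates innovation (sales peak after launch).
print(f"P(q > p): {(q > p).mean():.2f}")
# Share of draws with a non-physical negative market potential.
print(f"P(m < 0): {(m < 0).mean():.4f}")
```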

Generate synthetic data and fit

Build the model without an observed variable (only a T coordinate), draw a prior predictive sample, then fit to it:

import xarray as xr
import pymc as pm

ds = xr.Dataset(coords={"T": np.arange(50)})  # T coordinate only, no observed variable
model = BassModel()
model.build_model(data=ds)

with model.model:
    prior = pm.sample_prior_predictive(draws=50, random_seed=42)
    y_sim = prior.prior["y"].sel(draw=0, chain=0)

# Now fit the model to the synthetic data
idata = model.fit(data=y_sim.values)
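The same idea can be reproduced without PyMC, which is handy for unit tests or quick experiments. This sketch assumes a Poisson likelihood whose mean in period t is the expected number of new adopters between t and t + 1 under the textbook Bass curve, m·(F(t+1) − F(t)). The parameter values are illustrative, and the library's own discretisation may differ.

```python
import numpy as np


def bass_cdf(t, p, q):
    """Textbook Bass cumulative adoption share F(t) (assumed parameterisation)."""
    decay = np.exp(-(p + q) * np.asarray(t, dtype=float))
    return (1.0 - decay) / (1.0 + (q / p) * decay)


rng = np.random.default_rng(42)
n_periods, m, p, q = 50, 5_000, 0.03, 0.38

t = np.arange(n_periods + 1)
# Expected new adopters per period: the increment of cumulative adoption.
expected = m * np.diff(bass_cdf(t, p, q))
# Poisson noise around the expected counts (assumed likelihood).
y_sim = rng.poisson(expected)
```

The resulting y_sim is a 1-D NumPy array, the same kind of input passed to model.fit(data=...) above.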

Posterior predictive checks

Generate posterior predictive samples after fitting (new_time_points is a placeholder for the time points passed as X):

pp_data = model.sample_posterior_predictive(X=new_time_points)
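Posterior predictive draws are usually summarised per time step. This sketch assumes the samples come as an array with leading (chain, draw) dimensions followed by T, as in ArviZ's InferenceData convention. On that basis it computes a point forecast and an equal-tailed 94% band with NumPy. The random array is only a stand-in for real samples.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for posterior predictive samples with shape (chain, draw, T).
samples = rng.poisson(100, size=(4, 500, 20))

# Pool chains and draws, then summarise each time step.
flat = samples.reshape(-1, samples.shape[-1])
mean = flat.mean(axis=0)
# Equal-tailed 94% interval (3% and 97% quantiles) per time step.
lower, upper = np.quantile(flat, [0.03, 0.97], axis=0)
```

On an xarray DataArray the equivalent reduction is .quantile([0.03, 0.97], dim=("chain", "draw")).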

The posterior contains deterministics such as adopters, innovators, imitators, and peak that can be analysed directly via idata.posterior, e.g.:

az.plot_forest(idata.posterior["peak"], combined=True)
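Under the textbook Bass model the adoption rate peaks at t* = ln(q/p) / (p + q) with height m(p + q)²/(4q) when q > p. Otherwise the rate is highest at launch (t = 0, height m·p). This sketch recomputes both quantities draw by draw as a cross-check. It assumes m, p and q are stored with (chain, draw) dimensions. Whether the result matches the model's peak deterministic depends on how create_bass_model() defines it, which is an assumption here. The parameter draws are stand-ins.

```python
import numpy as np


def bass_peak(m, p, q):
    """Time and height of the maximum adoption rate of the textbook Bass curve.

    Assumption: textbook continuous-time parameterisation. When q <= p the
    rate is highest at launch, so the peak time is clipped to 0.
    """
    m, p, q = (np.asarray(x, dtype=float) for x in (m, p, q))
    t_peak = np.maximum(np.log(q / p) / (p + q), 0.0)
    rate_peak = np.where(q > p, m * (p + q) ** 2 / (4.0 * q), m * p)
    return t_peak, rate_peak


# Stand-in posterior draws with shape (chain, draw).
rng = np.random.default_rng(1)
m_draws = rng.normal(5_000, 200, size=(4, 500))
p_draws = rng.beta(3, 97, size=(4, 500))
q_draws = rng.beta(38, 62, size=(4, 500))

t_peak, rate_peak = bass_peak(m_draws, p_draws, q_draws)
print(f"peak time: {t_peak.mean():.1f} (sd {t_peak.std():.1f})")
```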

Methods

BassModel.__init__([model_config, ...])

Initialize model configuration and sampler configuration for the model.

BassModel.attrs_to_init_kwargs(attrs)

Convert the model configuration and sampler configuration from the attributes to keyword arguments.

BassModel.build_from_idata(idata)

Rebuild the model from an InferenceData object.

BassModel.build_model([data])

Build the Bass diffusion model from the given data.

BassModel.create_idata_attrs()

Create attributes for the inference data.

BassModel.fit(data[, progressbar, random_seed])

Fit the Bass diffusion model via MCMC.

BassModel.graphviz(**kwargs)

Get the graphviz representation of the model.

BassModel.idata_to_init_kwargs(idata)

Convert the model configuration and sampler configuration stored in the InferenceData to keyword arguments.

BassModel.load(fname[, check])

Create a ModelBuilder instance from a file.

BassModel.load_from_idata(idata[, check])

Create a ModelBuilder instance from an InferenceData object.

BassModel.sample_posterior_predictive(X[, ...])

Sample from the model's posterior predictive distribution.

BassModel.save(fname, **kwargs)

Save the model's inference data to a file.

BassModel.set_idata_attrs([idata])

Set attributes on an InferenceData object.

BassModel.table(**model_table_kwargs)

Get the summary table of the model.

Attributes

default_model_config

Default model configuration with weakly informative priors.

default_sampler_config

Default sampler configuration.

fit_result

Get the posterior fit_result.

id

Generate a unique hash value for the model.

output_var

Return the name of the output variable.

posterior

Access the 'posterior' attribute of the InferenceData object.

posterior_predictive

Access the 'posterior_predictive' attribute of the InferenceData object.

predictions

Access the 'predictions' attribute of the InferenceData object.

prior

Access the 'prior' attribute of the InferenceData object.

prior_predictive

Access the 'prior_predictive' attribute of the InferenceData object.

version

idata

sampler_config

model_config