BassModel
class pymc_marketing.bass.model.BassModel(model_config=None, sampler_config=None)
Bass diffusion model for product adoption forecasting.
Wraps the functional create_bass_model() inside the ModelBuilder interface, providing standardised .fit(), .save(), .load() and related methods. The underlying pm.Model is accessible via model.model for direct use with PyMC functions.
Parameters:
Notes
Data format
When using xr.Dataset, the T coordinate is required and represents the time index. An observed data variable can hold adoption counts (omit it for prior-predictive sampling only).

Single-product: 1-D observed with T as the only dimension:

    xr.Dataset(
        {"observed": ("T", counts)},
        coords={"T": np.arange(N)},
    )

Multi-product: observed with T and product dimensions:

    xr.Dataset(
        {"observed": (("T", "product"), counts)},
        coords={"T": np.arange(N), "product": ["A", "B", "C"]},
    )
Other input types (np.ndarray, pd.Series, pd.DataFrame) are auto-converted via to_bass_dataset().
Examples
Fit a single-product model

    import numpy as np
    import arviz as az

    from pymc_marketing.bass import BassModel

    # Placeholder adoption counts: 50 periods of Poisson noise
    y = np.random.poisson(lam=100, size=50)

    model = BassModel()
    idata = model.fit(data=y)

    print(az.summary(idata, var_names=["m", "p", "q"]))
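
The example above fits constant-rate Poisson noise, which exercises the API but contains no adoption curve. As a rough alternative for smoke tests, the sketch below simulates counts from the textbook discrete-time Bass recursion using NumPy only. The function name `simulate_bass_counts` and the parameter values are illustrative assumptions and are not part of pymc-marketing.

```python
import numpy as np


def simulate_bass_counts(m, p, q, n_periods, seed=0):
    """Simulate noisy per-period adoption counts (hypothetical helper).

    Uses the discrete-time Bass recursion
        expected_t = (p + q * N_{t-1} / m) * (m - N_{t-1}),
    where N_{t-1} is the expected cumulative number of adopters so far,
    and adds Poisson observation noise around each expected count.
    """
    rng = np.random.default_rng(seed)
    counts = np.zeros(n_periods, dtype=np.int64)
    cumulative = 0.0  # expected cumulative adopters before period t
    for t in range(n_periods):
        remaining = max(m - cumulative, 0.0)  # market not yet reached
        expected = (p + q * cumulative / m) * remaining
        counts[t] = rng.poisson(expected)
        cumulative += expected
    return counts


# Assumed values: market size 5,000, innovation 0.03, imitation 0.38
y = simulate_bass_counts(m=5_000, p=0.03, q=0.38, n_periods=50)
print(y[:10], y.sum())
```

Because np.ndarray inputs are documented as auto-converted, an array like `y` should be usable as `model.fit(data=y)`. With data of this shape the posterior for m, p and q has an actual curve to recover.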
Multi-product with custom priors

    import arviz as az
    import numpy as np
    import xarray as xr
    from pymc_extras.prior import Prior

    from pymc_marketing.bass import BassModel

    # Three products observed over 50 time steps
    data = xr.Dataset(
        {"observed": (("T", "product"), np.random.poisson(100, size=(50, 3)))},
        coords={"T": np.arange(50), "product": ["A", "B", "C"]},
    )

    model = BassModel(
        model_config={
            "m": Prior("Normal", mu=5_000, sigma=1_000),
            "p": Prior("Beta", alpha=1.5, beta=20),
            "q": Prior("Beta", alpha=2, beta=5),
            "likelihood": Prior("Poisson"),
        },
    )
    idata = model.fit(data=data)

    print(az.summary(idata, var_names=["m", "p", "q"]))
Generate synthetic data and fit
Build the model without an observed variable (only a T coordinate), draw a prior predictive sample, then fit to it:

    import numpy as np
    import pymc as pm
    import xarray as xr

    from pymc_marketing.bass import BassModel

    # Dataset with only the T coordinate, no observed variable
    ds = xr.Dataset({"T": np.arange(50)})

    model = BassModel()
    model.build_model(data=ds)

    with model.model:
        prior = pm.sample_prior_predictive(draws=50, random_seed=42)

    # Take a single prior draw as the synthetic data
    y_sim = prior.prior["y"].sel(draw=0, chain=0)

    # Now fit the model to the synthetic data
    idata = model.fit(data=y_sim.values)
Posterior predictive checks
Generate posterior predictive samples after fitting:

    pp_data = model.sample_posterior_predictive(X=new_time_points)

The posterior contains deterministics such as adopters, innovators, imitators, and peak that can be analysed directly via idata.posterior, e.g.:

    az.plot_forest(idata.posterior["peak"], combined=True)
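
To give a feel for what these deterministics represent, here is a minimal sketch of the standard closed-form Bass model in plain Python. It splits the adoption rate into an innovator and an imitator component and computes the time of peak adoption. Whether BassModel computes adopters, innovators, imitators and peak with exactly these expressions is an assumption; the function names and parameter values are illustrative only.

```python
import math


def bass_cdf(t, p, q):
    """Cumulative adoption fraction F(t) of the standard Bass model."""
    decay = math.exp(-(p + q) * t)
    return (1.0 - decay) / (1.0 + (q / p) * decay)


def bass_components(t, m, p, q):
    """Split the adoption rate at time t into innovator and imitator parts.

    The Bass hazard is p + q * F(t), so the adoption rate is
    m * (p + q * F) * (1 - F) = innovators + imitators.
    """
    F = bass_cdf(t, p, q)
    innovators = m * p * (1.0 - F)      # adopt independently of others
    imitators = m * q * F * (1.0 - F)   # adopt through word of mouth
    return innovators, imitators, innovators + imitators


def peak_time(p, q):
    """Time of maximum adoption rate (positive only when q > p)."""
    return math.log(q / p) / (p + q)


# Assumed values: market size 5,000, innovation 0.03, imitation 0.38
m, p, q = 5_000, 0.03, 0.38
t_star = peak_time(p, q)
inn, imi, total = bass_components(t_star, m, p, q)
print(f"peak at t={t_star:.2f}: {total:.1f} adopters per period "
      f"({inn:.1f} innovators, {imi:.1f} imitators)")
```

In this formulation the peak falls where the cumulative fraction reaches (q - p) / (2q), which is why a larger imitation coefficient q relative to p pushes the peak later and makes it sharper.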
Methods
BassModel.__init__([model_config, ...]): Initialize model configuration and sampler configuration for the model.
Convert the model configuration and sampler configuration from the attributes to keyword arguments.
BassModel.build_from_idata(idata): Rebuild the model from an InferenceData object.
BassModel.build_model([data]): Build the Bass diffusion model from the given data.
Create attributes for the inference data.
BassModel.fit(data[, progressbar, random_seed]): Fit the Bass diffusion model via MCMC.
BassModel.graphviz(**kwargs): Get the graphviz representation of the model.
Create the model configuration and sampler configuration from the InferenceData to keyword arguments.
BassModel.load(fname[, check]): Create a ModelBuilder instance from a file.
BassModel.load_from_idata(idata[, check]): Create a ModelBuilder instance from an InferenceData object.
BassModel.sample_posterior_predictive(X[, ...]): Sample from the model's posterior predictive distribution.
BassModel.save(fname, **kwargs): Save the model's inference data to a file.
BassModel.set_idata_attrs([idata]): Set attributes on an InferenceData object.
BassModel.table(**model_table_kwargs): Get the summary table of the model.
Attributes
default_model_config: Default model configuration with weakly informative priors.
default_sampler_config: Default sampler configuration.
fit_result: Get the posterior fit_result.
id: Generate a unique hash value for the model.
output_var: Return the name of the output variable.
posterior: Access the 'posterior' attribute of the InferenceData object.
posterior_predictive: Access the 'posterior_predictive' attribute of the InferenceData object.
predictions: Access the 'predictions' attribute of the InferenceData object.
prior: Access the 'prior' attribute of the InferenceData object.
prior_predictive: Access the 'prior_predictive' attribute of the InferenceData object.
version
idata
sampler_config
model_config