ModelResults class for RL models

There is one class to inspect model fits of RL models (fitted on choices alone): RLModelResults_2A.

The main functions of this class are:

  • Assess the model’s convergence and MCMC diagnostics, to make sure that the sampling was successful. This step is crucial and should preferably be done first.

  • Provide a measure of the model’s quantitative fit to the data (i.e., the Watanabe-Akaike information criterion). This is important when comparing the quantitative fit of several competing models to the same data.

  • Visualize and make interval-based inferences (using either Bayesian Credible Intervals or Highest Density Intervals) on the posterior distributions of the model’s parameters. This is important when specific hypotheses were made about the parameters’ values.

  • Calculate and visualize posterior predictive distributions of the observed data. This step is important to assess the qualitative fit of the model to the data. Qualitative fit should be assessed not only when comparing different competing models, but also when a single candidate model is fitted. Different ways of calculating posterior predictive distributions are provided, together with different plotting options. In general, emphasis is given to calculating posterior predictive distributions across conditions. This allows us to assess whether a certain behavioral pattern observed in the data (e.g., due to experimental manipulations) can also be reproduced by the model.
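The two interval types named in the list above differ in how they pick their bounds. The following is a minimal sketch, not the library's own implementation: `bci` and `hdi` are hypothetical helper names, and the gamma-distributed draws are only a stand-in for a skewed posterior.

```python
import numpy as np

def bci(samples, alpha=0.05):
    """Equal-tailed Bayesian credible interval: cut alpha/2 from each tail."""
    lower, upper = np.quantile(samples, [alpha / 2, 1 - alpha / 2])
    return float(lower), float(upper)

def hdi(samples, alpha=0.05):
    """Narrowest interval holding (1 - alpha) of the samples (unimodal case)."""
    sorted_samples = np.sort(np.asarray(samples))
    n = sorted_samples.size
    n_inside = int(np.ceil((1 - alpha) * n))  # samples the interval must contain
    # Width of every window made of n_inside consecutive sorted samples.
    widths = sorted_samples[n_inside - 1:] - sorted_samples[: n - n_inside + 1]
    start = int(np.argmin(widths))
    return float(sorted_samples[start]), float(sorted_samples[start + n_inside - 1])

rng = np.random.default_rng(0)
skewed = rng.gamma(shape=2.0, scale=1.0, size=5000)  # stand-in skewed posterior
bci_low, bci_high = bci(skewed)
hdi_low, hdi_high = hdi(skewed)
```

For a symmetric distribution the two intervals nearly coincide. For a skewed one, the HDI shifts toward the region of highest density and is never wider than the equal-tailed interval at the same alpha level.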

All models

class rlssm.fits_RL.ModelResults(model_label, data_info, parameters_info, priors, rhat, waic, last_values, samples, trial_samples)
plot_posteriors(gridsize=100, clip=None, show_intervals='HDI', alpha_intervals=0.05, intervals_kws=None, **kwargs)

Plots posterior distributions of the model’s parameters.

If the model is hierarchical, then only the group parameters are plotted. In particular, group means are plotted in the first row and group standard deviations are plotted in the second row. By default, 95 percent HDI are shown. The kernel density estimation is calculated using scipy.stats.gaussian_kde.

Parameters
  • gridsize (int, default to 100) – Resolution of the kernel density estimation function, default to 100.

  • clip (tuple of (float, float), optional) – Range for the kernel density estimation function. Default is min and max values of the distribution.

  • show_intervals (str, default to "HDI") – Either "HDI", "BCI", or None. HDI is better when the distribution is not symmetrical. If None, then no intervals are shown.

  • alpha_intervals (float, default to .05) – Alpha level for the intervals. Default is 5 percent which gives 95 percent BCIs and HDIs.

  • intervals_kws (dict) – Additional arguments for matplotlib.axes.Axes.fill_between that shows shaded intervals. By default, they are 50 percent transparent.

  • **kwargs – Additional parameters for seaborn.FacetGrid.

Returns

g

Return type

seaborn.FacetGrid
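How gridsize and clip shape the density curve can be illustrated with scipy.stats.gaussian_kde directly. This is a sketch under assumptions: `kde_on_grid` is a hypothetical helper, the normal draws stand in for posterior samples, and the library's internal handling of these arguments may differ.

```python
import numpy as np
from scipy.stats import gaussian_kde

def kde_on_grid(samples, gridsize=100, clip=None):
    """Evaluate a Gaussian KDE on `gridsize` evenly spaced points within `clip`."""
    samples = np.asarray(samples)
    # With no clip given, fall back to the min and max of the samples.
    low, high = (samples.min(), samples.max()) if clip is None else clip
    grid = np.linspace(low, high, gridsize)
    density = gaussian_kde(samples)(grid)
    return grid, density

rng = np.random.default_rng(1)
samples = rng.normal(loc=0.3, scale=0.1, size=2000)  # stand-in posterior draws

grid, density = kde_on_grid(samples)                       # default range
grid_clipped, density_clipped = kde_on_grid(samples, gridsize=50, clip=(0.0, 1.0))
```

A larger gridsize gives a smoother-looking curve at a higher evaluation cost, while clip restricts the x-range, which is useful for parameters bounded between 0 and 1.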

to_pickle(filename=None)

Pickle the fitted model’s results object to file.

This can be used to store the model’s result and read them and inspect them at a later stage, without having to refit the model.
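The store-and-reload round trip can be sketched with the standard library pickle module. The dictionary below is only a stand-in for a fitted results object, the file name is hypothetical, and the assumption that a stored object can be read back with pickle.load is not confirmed by the text above.

```python
import os
import pickle
import tempfile

# Stand-in for a fitted results object; a real one would come from fitting a model.
results = {"model_label": "example_model", "samples": [0.1, 0.2, 0.3]}

with tempfile.TemporaryDirectory() as tmp_dir:
    filename = os.path.join(tmp_dir, "results.pkl")  # hypothetical file name

    # Store the object once, right after fitting.
    with open(filename, "wb") as f:
        pickle.dump(results, f)

    # Later (for example, in a new session) read it back without refitting.
    with open(filename, "rb") as f:
        restored = pickle.load(f)
```

Pickled files should only be loaded from trusted sources, because unpickling can execute arbitrary code.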

Parameters

filename (str, optional) – File path where the pickled object will be stored. If not specified, is set to

Reinforcement learning models

class rlssm.fits_RL.RLModelResults_2A(model_label, data_info, parameters_info, priors, rhat, waic, last_values, samples, trial_samples)

RLModelResults_2A allows you to perform various model checks on fitted RL models.

In particular, this can be used to visualize the estimated posterior distributions and to calculate and visualize the estimated posterior predictive distributions.

get_grouped_posterior_predictives_summary(grouping_vars, n_posterior_predictives=500)

Calculates summary of posterior predictives of choices, separately for a list of grouping variables.

The mean proportion of choices (in this case coded as accuracy) is calculated for each posterior sample across all trials within each combination of conditions.

For example, if grouping_vars=['reward', 'difficulty'], posterior predictives will be collapsed for all combinations of levels of the reward and difficulty variables.

Parameters
  • grouping_vars (list of strings) – They should be existing grouping variables in the data.

  • n_posterior_predictives (int) – Number of posterior samples to use for posterior predictives calculation. If n_posterior_predictives exceeds the number of available posterior samples, all available posterior samples are used.

Returns

out – Pandas DataFrame. The column contains the mean accuracy. The row index is a pandas.MultiIndex, with the grouping variables as the higher levels and the sample number as the lowest level.

Return type

DataFrame
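The shape of such a grouped summary can be sketched with plain pandas. Everything here is illustrative: the trial data, the random stand-in predictions, and the column name `mean_accuracy` are assumptions, and the library's actual column naming may differ.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(2)
n_samples, n_trials = 4, 12

# Hypothetical trial-level data with two grouping variables.
data = pd.DataFrame({
    "reward": np.tile(["low", "high"], n_trials // 2),
    "difficulty": np.repeat(["easy", "hard"], n_trials // 2),
})
# Stand-in posterior predictives: one row per posterior sample, one column per trial.
pp_acc = rng.integers(0, 2, size=(n_samples, n_trials))

grouping_vars = ["reward", "difficulty"]
# Stack one copy of the trial data per posterior sample, then average per cell.
long = pd.concat(
    [data.assign(sample=s, accuracy=pp_acc[s]) for s in range(n_samples)],
    ignore_index=True,
)
out = long.groupby(grouping_vars + ["sample"])[["accuracy"]].mean()
out = out.rename(columns={"accuracy": "mean_accuracy"})
```

Each row of `out` is the predicted mean accuracy of one posterior sample in one reward-by-difficulty cell, so the spread across samples within a cell reflects the model's predictive uncertainty for that condition.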

get_posterior_predictives(n_posterior_predictives=500)

Calculates posterior predictives of choices.

Parameters

n_posterior_predictives (int) – Number of posterior samples to use for posterior predictives calculation. If n_posterior_predictives exceeds the number of available posterior samples, all available posterior samples are used.

Returns

pp_acc – Array of shape (n_samples, n_trials).

Return type

numpy.ndarray
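The returned shape and the fallback rule for n_posterior_predictives can be illustrated as follows. This is a sketch under assumptions, not the library's procedure: `simulate_choices` is a hypothetical helper, and `p_correct` stands in for per-trial probabilities of a correct choice implied by each posterior draw.

```python
import numpy as np

def simulate_choices(p_correct, n_posterior_predictives=500, seed=0):
    """Draw binary accuracies from an (n_draws, n_trials) array of probabilities."""
    rng = np.random.default_rng(seed)
    n_draws = p_correct.shape[0]
    # If more samples are requested than exist, use all available draws.
    n_used = min(n_posterior_predictives, n_draws)
    rows = rng.choice(n_draws, size=n_used, replace=False)
    return rng.binomial(1, p_correct[rows])

p_correct = np.full((200, 30), 0.7)   # 200 posterior draws, 30 trials (made up)
pp_acc = simulate_choices(p_correct, n_posterior_predictives=500)  # capped at 200
pp_small = simulate_choices(p_correct, n_posterior_predictives=50)
```

Here `pp_acc` has 200 rows rather than 500, because only 200 posterior draws are available.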

get_posterior_predictives_df(n_posterior_predictives=500)

Calculates posterior predictives of choices.

Parameters

n_posterior_predictives (int) – Number of posterior samples to use for posterior predictives calculation. If n_posterior_predictives exceeds the number of available posterior samples, all available posterior samples are used.

Returns

out – Data frame of shape (n_samples, n_trials).

Return type

DataFrame

get_posterior_predictives_summary(n_posterior_predictives=500)

Calculates summary of posterior predictives of choices.

The mean proportion of choices (in this case coded as accuracy) is calculated for each posterior sample across all trials.

Parameters

n_posterior_predictives (int) – Number of posterior samples to use for posterior predictives calculation. If n_posterior_predictives exceeds the number of available posterior samples, all available posterior samples are used.

Returns

out – Data frame, where every row corresponds to a posterior sample. The column contains the mean accuracy for each posterior sample over the whole dataset.

Return type

DataFrame
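The summary amounts to a row-wise mean over an array of shape (n_samples, n_trials). A minimal sketch with made-up predictions follows; the column name `mean_accuracy` is an assumption.

```python
import numpy as np
import pandas as pd

# Made-up posterior predictives: 3 posterior samples, 4 trials, 1 = correct choice.
pp_acc = np.array([
    [1, 1, 0, 1],
    [0, 1, 0, 1],
    [1, 1, 1, 1],
])

# One row per posterior sample, holding its mean accuracy over all trials.
out = pd.DataFrame({"mean_accuracy": pp_acc.mean(axis=1)})
# out["mean_accuracy"] is [0.75, 0.5, 1.0]
```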

plot_mean_grouped_posterior_predictives(grouping_vars, n_posterior_predictives=500, **kwargs)

Plots the mean posterior predictives of choices, separately for either 1 or 2 grouping variables.

The first grouping variable will be plotted on the x-axis. The second grouping variable, if provided, will be shown with a different color per variable level.

Parameters
  • grouping_vars (list of strings) – They should be existing grouping variables in the data. The list should be of length 1 or 2.

  • n_posterior_predictives (int) – Number of posterior samples to use for posterior predictives calculation. If n_posterior_predictives exceeds the number of available posterior samples, all available posterior samples are used.

  • x_order (list of strings) – Order to plot the levels of the first grouping variable in, otherwise the levels are inferred from the data objects.

  • hue_order (list of strings) – Order to plot the levels of the second grouping variable (when provided) in, otherwise the levels are inferred from the data objects.

  • hue_labels (list of strings) – Labels corresponding to hue_order in the legend. Advised to specify hue_order when using this to avoid confusion. Only makes sense when the second grouping variable is provided.

  • show_data (bool) – Whether to show a vertical line for the mean data. Set to False to not show it.

  • show_intervals (either "HDI", "BCI", or None) – HDI is better when the distribution is not symmetrical. If None, then no intervals are shown.

  • alpha_intervals (float) – Alpha level for the intervals. Default is 5 percent which gives 95 percent BCIs and HDIs.

  • palette (palette name, list, or dict) – Colors to use for the different levels of the second grouping variable (when provided). Should be something that can be interpreted by color_palette(), or a dictionary mapping hue levels to matplotlib colors.

  • color (matplotlib color) – Color for both the mean data and intervals. Only used when there is 1 grouping variable.

  • ax (matplotlib axis, optional) – If provided, plot on this axis. Default is set to current Axes.

  • intervals_kws (dictionary) – Additional arguments for the matplotlib fill_between function that shows shaded intervals. By default, they are 50 percent transparent.

Returns

ax – Returns the matplotlib.axes.Axes object with the plot for further tweaking.

Return type

matplotlib.axes.Axes

plot_mean_posterior_predictives(n_posterior_predictives, **kwargs)

Plots the mean posterior predictives of choices.

The mean proportion of choices (in this case coded as accuracy) is calculated for each posterior sample across all trials and then plotted as a distribution. The mean accuracy in the data is plotted as a vertical line. This makes it possible to compare the observed mean with the BCI or HDI of the predictions.

Parameters
  • n_posterior_predictives (int) – Number of posterior samples to use for posterior predictives calculation. If n_posterior_predictives exceeds the number of available posterior samples, all available posterior samples are used.

  • show_data (bool) – Whether to show a vertical line for the mean data. Set to False to not show it.

  • color (matplotlib color) – Color for both the mean data and intervals.

  • ax (matplotlib axis, optional) – If provided, plot on this axis. Default is set to current Axes.

  • gridsize (int) – Resolution of the kernel density estimation function, default to 100.

  • clip (tuple) – Range for the kernel density estimation function. Default is min and max values of the distribution.

  • show_intervals (either "HDI", "BCI", or None) – HDI is better when the distribution is not symmetrical. If None, then no intervals are shown.

  • alpha_intervals (float) – Alpha level for the intervals. Default is 5 percent which gives 95 percent BCIs and HDIs.

  • intervals_kws (dictionary) – Additional arguments for the matplotlib fill_between function that shows shaded intervals. By default, they are 50 percent transparent.

Returns

ax – Returns the matplotlib.axes.Axes object with the plot for further tweaking.

Return type

matplotlib.axes.Axes
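The kind of figure described for plot_mean_posterior_predictives can be approximated by hand with matplotlib and scipy. This is only a sketch: the beta-distributed values stand in for the mean accuracy of each posterior sample, `observed_mean` is a made-up data value, and an equal-tailed BCI is used for the shaded interval.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import gaussian_kde

rng = np.random.default_rng(3)
pp_mean_acc = rng.beta(70, 30, size=500)  # stand-in mean accuracy per posterior sample
observed_mean = 0.72                      # hypothetical mean accuracy in the data
alpha_intervals = 0.05

# Density of the predicted mean accuracy on a 100-point grid.
grid = np.linspace(pp_mean_acc.min(), pp_mean_acc.max(), 100)
density = gaussian_kde(pp_mean_acc)(grid)

# Equal-tailed 95 percent interval of the predictions.
low, high = np.quantile(pp_mean_acc, [alpha_intervals / 2, 1 - alpha_intervals / 2])

fig, ax = plt.subplots()
ax.plot(grid, density)
inside = (grid >= low) & (grid <= high)
ax.fill_between(grid, density, where=inside, alpha=0.5)  # shaded interval
ax.axvline(observed_mean)                                # mean of the observed data
ax.set_xlabel("mean accuracy")
```

If the vertical line falls inside the shaded region, the observed mean accuracy is compatible with the model's predictions at the chosen alpha level.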