ModelResults class for RL models
There is one class to inspect model fits of RL models (fitted on choices alone): RLModelResults_2A.
The main functions of this class are:
Assess the model’s convergence and MCMC diagnostics, to make sure that the sampling was successful. This step is crucial and should preferably be done first.
Provide a measure of the model’s quantitative fit to the data (i.e., the Watanabe-Akaike information criterion). This is important when comparing several competing models on their quantitative fit to the data.
Visualize and make interval-based (either Bayesian Credible Intervals or Highest Density Intervals) inferences on the posterior distributions of the model’s parameters. This is important when specific hypotheses were made about the parameters’ values.
Calculate and visualize posterior predictive distributions of the observed data. This step is important to assess the qualitative fit of the model to the data. Qualitative fit should be assessed not only when comparing different competing models, but also when a single candidate model is fitted. Different ways of calculating posterior predictive distributions are provided, together with different plotting options. In general, emphasis is given to calculating posterior predictive distributions across conditions. This allows us to assess whether a certain behavioral pattern observed in the data (e.g., due to experimental manipulations) can also be reproduced by the model.
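As a rough illustration of the quantitative fit measure listed above, WAIC can be computed from a matrix of pointwise log-likelihoods. The sketch below uses NumPy on simulated values; the names `log_lik` and `waic_from_log_lik` are hypothetical, and the formula is the standard one on the deviance scale, which is an assumption about, not a description of, how the package fills its `waic` attribute.

```python
import numpy as np

def waic_from_log_lik(log_lik):
    """WAIC on the deviance scale from an (n_samples, n_trials) array."""
    log_lik = np.asarray(log_lik, dtype=float)
    n_samples = log_lik.shape[0]
    # Log pointwise predictive density: log of the mean likelihood per trial,
    # computed stably as log-sum-exp minus log(n_samples).
    lppd = np.sum(np.logaddexp.reduce(log_lik, axis=0) - np.log(n_samples))
    # Effective number of parameters: per-trial variance of the log-likelihood.
    p_waic = np.sum(np.var(log_lik, axis=0, ddof=1))
    return {"lppd": lppd, "p_waic": p_waic, "waic": -2.0 * (lppd - p_waic)}

rng = np.random.default_rng(0)
# Hypothetical log-likelihoods: 1000 posterior samples x 50 trials.
log_lik = np.log(rng.uniform(0.4, 0.9, size=(1000, 50)))
result = waic_from_log_lik(log_lik)
print(result)
```

On this scale, a lower WAIC indicates a better expected out-of-sample fit, so competing models fitted to the same data can be ranked by it.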
All models
- class rlssm.fits_RL.ModelResults(model_label, data_info, parameters_info, priors, rhat, waic, last_values, samples, trial_samples)
- plot_posteriors(gridsize=100, clip=None, show_intervals='HDI', alpha_intervals=0.05, intervals_kws=None, **kwargs)
Plots posterior distributions of the model’s parameters.
If the model is hierarchical, then only the group parameters are plotted. In particular, group means are plotted in the first row and group standard deviations are plotted in the second row. By default, 95 percent HDI are shown. The kernel density estimation is calculated using scipy.stats.gaussian_kde.
- Parameters
gridsize (int, default to 100) – Resolution of the kernel density estimation function.
clip (tuple of (float, float), optional) – Range for the kernel density estimation function. Default is min and max values of the distribution.
show_intervals (str, default to "HDI") – Either "HDI", "BCI", or None. HDI is preferable when the distribution is not symmetrical. If None, then no intervals are shown.
alpha_intervals (float, default to .05) – Alpha level for the intervals. Default is 5 percent which gives 95 percent BCIs and HDIs.
intervals_kws (dict) – Additional arguments for matplotlib.axes.Axes.fill_between that shows shaded intervals. By default, they are 50 percent transparent.
**kwargs – Additional parameters for seaborn.FacetGrid.
- Returns
g
- Return type
seaborn.FacetGrid
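To make the difference between the two interval types concrete, the sketch below computes both on a skewed sample with NumPy. The helper names `bci` and `hdi` are hypothetical and are not the package's own functions; the HDI is taken here as the narrowest window that contains a (1 - alpha) share of the sorted samples.

```python
import numpy as np

def bci(samples, alpha=0.05):
    """Equal-tailed Bayesian credible interval."""
    lower, upper = np.quantile(samples, [alpha / 2, 1 - alpha / 2])
    return float(lower), float(upper)

def hdi(samples, alpha=0.05):
    """Narrowest interval containing a (1 - alpha) share of the samples."""
    sorted_samples = np.sort(np.asarray(samples, dtype=float))
    n = sorted_samples.size
    n_included = int(np.ceil((1 - alpha) * n))
    # Width of every candidate window holding n_included consecutive samples.
    widths = sorted_samples[n_included - 1:] - sorted_samples[:n - n_included + 1]
    start = int(np.argmin(widths))
    return float(sorted_samples[start]), float(sorted_samples[start + n_included - 1])

rng = np.random.default_rng(1)
skewed = rng.gamma(shape=2.0, scale=1.0, size=20000)  # right-skewed "posterior"

bci_low, bci_high = bci(skewed)
hdi_low, hdi_high = hdi(skewed)
print("BCI:", (bci_low, bci_high))
print("HDI:", (hdi_low, hdi_high))
```

For a right-skewed distribution like this one, the HDI shifts toward the mode and is narrower than the equal-tailed BCI, which is why it is the preferred choice for asymmetrical posteriors.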
- to_pickle(filename=None)
Pickle the fitted model’s results object to file.
This can be used to store the model’s results, so that they can be read and inspected at a later stage without having to refit the model.
- Parameters
filename (str, optional) – File path where the pickled object will be stored. If not specified, is set to
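The store-and-reload workflow described for to_pickle can be sketched with the standard-library `pickle` module. This assumes, as the method name suggests, that the results object is written as an ordinary pickle file; the dictionary `results_stub` and the file name `results.pkl` are hypothetical placeholders, not the package's real object or default path.

```python
import os
import pickle
import tempfile

# Hypothetical stand-in for a fitted results object (a plain dict here).
results_stub = {
    "model_label": "RL_2A_example",
    "waic": {"waic": 123.4},
    "rhat": [1.00, 1.01, 1.00],
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "results.pkl")
    # Store the object once, right after fitting.
    with open(path, "wb") as f:
        pickle.dump(results_stub, f)
    # Later: read it back and inspect it without refitting.
    with open(path, "rb") as f:
        restored = pickle.load(f)

print(restored["model_label"], restored["waic"])
```

Pickle files should only be loaded from trusted sources, since unpickling can execute arbitrary code.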
Reinforcement learning models
- class rlssm.fits_RL.RLModelResults_2A(model_label, data_info, parameters_info, priors, rhat, waic, last_values, samples, trial_samples)
RLModelResults_2A allows you to perform various model checks on fitted RL models.
In particular, it can be used to visualize the estimated posterior distributions and to calculate and visualize the estimated posterior predictive distributions.
- get_grouped_posterior_predictives_summary(grouping_vars, n_posterior_predictives=500)
Calculates summary of posterior predictives of choices, separately for a list of grouping variables.
The mean proportion of choices (in this case coded as accuracy) is calculated for each posterior sample across all trials within each combination of conditions.
For example, if grouping_vars=['reward', 'difficulty'], posterior predictives will be collapsed for all combinations of levels of the reward and difficulty variables.
- Parameters
grouping_vars (list of strings) – They should be existing grouping variables in the data.
n_posterior_predictives (int) – Number of posterior samples to use for posterior predictives calculation. If n_posterior_predictives is larger than the number of posterior samples, then the calculation will continue with the total number of posterior samples.
- Returns
out – Pandas DataFrame. The column contains the mean accuracy. The row index is a pandas.MultiIndex, with the grouping variables as higher levels and the sample number as the lowest level.
- Return type
DataFrame
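The shape of this output can be illustrated with simulated values. The sketch below uses NumPy and pandas to average simulated accuracy within each combination of two grouping variables, separately for every posterior sample. The column names `reward` and `difficulty` come from the example above, while `p_correct`, `pp_acc`, `sample`, and `mean_accuracy` are hypothetical names; the package's own column labels may differ.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(2)
n_samples, n_trials = 200, 120

# Hypothetical trial-level design with two grouping variables.
data = pd.DataFrame({
    "reward": np.tile(["low", "high"], n_trials // 2),
    "difficulty": np.repeat(["easy", "hard"], n_trials // 2),
})
# Assumed probability of a correct choice in each trial.
p_correct = np.where(data["difficulty"] == "easy", 0.8, 0.6)

# Simulated posterior predictive accuracy, shape (n_samples, n_trials).
pp_acc = rng.binomial(1, p_correct, size=(n_samples, n_trials))

# Long format: one row per (sample, trial), then average within each cell.
long_df = pd.concat(
    [data.assign(sample=s, accuracy=pp_acc[s]) for s in range(n_samples)],
    ignore_index=True,
)
summary = (
    long_df.groupby(["reward", "difficulty", "sample"])["accuracy"]
    .mean()
    .to_frame("mean_accuracy")
)
print(summary.head())
```

Each row of `summary` is one posterior sample within one condition combination, so the spread of `mean_accuracy` within a combination reflects the model's predictive uncertainty for that cell.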
- get_posterior_predictives(n_posterior_predictives=500)
Calculates posterior predictives of choices.
- Parameters
n_posterior_predictives (int) – Number of posterior samples to use for posterior predictives calculation. If n_posterior_predictives is larger than the number of posterior samples, then the calculation will continue with the total number of posterior samples.
- Returns
pp_acc – Array of shape (n_samples, n_trials).
- Return type
numpy.ndarray
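The capping behavior and the returned shape can be sketched as follows. This is a simplified stand-in, not the package's implementation: it assumes that each posterior sample provides a probability of a correct choice for every trial, and that the retained samples are picked at random. The names `simulate_choices` and `p_samples` are hypothetical.

```python
import numpy as np

def simulate_choices(p_correct_samples, n_posterior_predictives=500, seed=0):
    """Draw one simulated dataset per retained posterior sample."""
    rng = np.random.default_rng(seed)
    p = np.asarray(p_correct_samples, dtype=float)
    # Cap the request at the number of available posterior samples.
    n_used = min(n_posterior_predictives, p.shape[0])
    kept = rng.choice(p.shape[0], size=n_used, replace=False)
    # Bernoulli draw per trial: 1 = correct choice, 0 = incorrect.
    return rng.binomial(1, p[kept])

# Hypothetical trial-level probabilities: 300 posterior samples x 40 trials.
p_samples = np.random.default_rng(3).uniform(0.5, 0.9, size=(300, 40))

pp_acc = simulate_choices(p_samples, n_posterior_predictives=500)
print(pp_acc.shape)  # (300, 40): the request for 500 was capped at 300
```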
- get_posterior_predictives_df(n_posterior_predictives=500)
Calculates posterior predictives of choices and returns them as a data frame.
- Parameters
n_posterior_predictives (int) – Number of posterior samples to use for posterior predictives calculation. If n_posterior_predictives is larger than the number of posterior samples, then the calculation will continue with the total number of posterior samples.
- Returns
out – Data frame of shape (n_samples, n_trials).
- Return type
DataFrame
- get_posterior_predictives_summary(n_posterior_predictives=500)
Calculates summary of posterior predictives of choices.
The mean proportion of choices (in this case coded as accuracy) is calculated for each posterior sample across all trials.
- Parameters
n_posterior_predictives (int) – Number of posterior samples to use for posterior predictives calculation. If n_posterior_predictives is larger than the number of posterior samples, then the calculation will continue with the total number of posterior samples.
- Returns
out – Data frame, where every row corresponds to a posterior sample. The column contains the mean accuracy for each posterior sample over the whole dataset.
- Return type
DataFrame
- plot_mean_grouped_posterior_predictives(grouping_vars, n_posterior_predictives=500, **kwargs)
Plots the mean posterior predictives of choices, separately for either 1 or 2 grouping variables.
The first grouping variable is plotted on the x-axis. The second grouping variable, if provided, is shown with a different color for each of its levels.
- Parameters
grouping_vars (list of strings) – They should be existing grouping variables in the data. The list should be of length 1 or 2.
n_posterior_predictives (int) – Number of posterior samples to use for posterior predictives calculation. If n_posterior_predictives is larger than the number of posterior samples, then the calculation will continue with the total number of posterior samples.
x_order (list of strings) – Order to plot the levels of the first grouping variable in, otherwise the levels are inferred from the data objects.
hue_order (list of strings) – Order to plot the levels of the second grouping variable (when provided) in, otherwise the levels are inferred from the data objects.
hue_labels (list of strings) – Labels corresponding to hue_order in the legend. Advised to specify hue_order when using this to avoid confusion. Only makes sense when the second grouping variable is provided.
show_data (bool) – Whether to show a vertical line for the mean data. Set to False to not show it.
show_intervals (either "HDI", "BCI", or None) – HDI is preferable when the distribution is not symmetrical. If None, then no intervals are shown.
alpha_intervals (float) – Alpha level for the intervals. Default is 5 percent which gives 95 percent BCIs and HDIs.
palette (palette name, list, or dict) – Colors to use for the different levels of the second grouping variable (when provided). Should be something that can be interpreted by color_palette(), or a dictionary mapping hue levels to matplotlib colors.
color (matplotlib color) – Color for both the mean data and intervals. Only used when there is 1 grouping variable.
ax (matplotlib axis, optional) – If provided, plot on this axis. Default is set to current Axes.
intervals_kws (dictionary) – Additional arguments for the matplotlib fill_between function that shows shaded intervals. By default, they are 50 percent transparent.
- Returns
ax – Returns the matplotlib.axes.Axes object with the plot for further tweaking.
- Return type
matplotlib.axes.Axes
- plot_mean_posterior_predictives(n_posterior_predictives, **kwargs)
Plots the mean posterior predictives of choices.
The mean proportion of choices (in this case coded as accuracy) is calculated for each posterior sample across all trials, and then plotted as a distribution. The mean accuracy in the data is plotted as a vertical line. This makes it possible to compare the observed mean with the BCI or HDI of the predictions.
- Parameters
n_posterior_predictives (int) – Number of posterior samples to use for posterior predictives calculation. If n_posterior_predictives is larger than the number of posterior samples, then the calculation will continue with the total number of posterior samples.
show_data (bool) – Whether to show a vertical line for the mean data. Set to False to not show it.
color (matplotlib color) – Color for both the mean data and intervals.
ax (matplotlib axis, optional) – If provided, plot on this axis. Default is set to current Axes.
gridsize (int) – Resolution of the kernel density estimation function. Default is 100.
clip (tuple) – Range for the kernel density estimation function. Default is min and max values of the distribution.
show_intervals (either "HDI", "BCI", or None) – HDI is preferable when the distribution is not symmetrical. If None, then no intervals are shown.
alpha_intervals (float) – Alpha level for the intervals. Default is 5 percent which gives 95 percent BCIs and HDIs.
intervals_kws (dictionary) – Additional arguments for the matplotlib fill_between function that shows shaded intervals. By default, they are 50 percent transparent.
- Returns
ax – Returns the matplotlib.axes.Axes object with the plot for further tweaking.
- Return type
matplotlib.axes.Axes
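The comparison that this plot supports can also be done numerically. The sketch below, using NumPy on simulated data, computes one mean accuracy per posterior sample and checks whether the observed mean falls inside the 95 percent BCI of those predicted means. All arrays are hypothetical stand-ins for real data and real posterior predictives, and the interval is a plain equal-tailed quantile interval rather than the package's own routine.

```python
import numpy as np

rng = np.random.default_rng(4)

# Hypothetical observed accuracy (1 = correct) for 200 trials.
observed = rng.binomial(1, 0.75, size=200)
# Hypothetical posterior predictive accuracy, shape (n_samples, n_trials).
pp_acc = rng.binomial(1, 0.75, size=(500, 200))

# One mean accuracy per posterior sample: the distribution being plotted.
predicted_means = pp_acc.mean(axis=1)
# The value drawn as a vertical line in the plot.
observed_mean = observed.mean()

alpha = 0.05
low, high = np.quantile(predicted_means, [alpha / 2, 1 - alpha / 2])
inside = bool(low <= observed_mean <= high)
print(f"observed mean = {observed_mean:.3f}, "
      f"95% BCI = [{low:.3f}, {high:.3f}], inside = {inside}")
```

An observed mean that falls well outside the interval suggests that the model fails to reproduce the overall accuracy level in the data.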