
Plotting

This module contains functions for plotting results from Bayesian models.

plotting

Functions:

  • plot_posterior_forest

    Creates a forest plot of the Bayesian results with point estimates and credible intervals.

  • plot_posteriors

    Creates a summary of the Bayesian results and plots the posterior distributions.

plot_posterior_forest

plot_posterior_forest(results: InferenceData, var_names: List[str], titles: List[str], credible_interval: float = 0.95, file_name: Optional[str] = None, **kwargs) -> None

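Creates a forest plot of the Bayesian results with point estimates (posterior means) and credible intervals (highest-density intervals, HDIs), with variables displayed along the x-axis. Variables whose interval lies entirely above or entirely below zero are drawn in the second colour of the Matplotlib colour cycle, and a dashed horizontal line marks zero.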
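
In the source code, a point estimate is highlighted when `np.sign(hdi_low) == np.sign(hdi_high)`, that is, when its credible interval lies entirely above or entirely below zero. A minimal pure-Python sketch of the same test, assuming the lower bound is not greater than the upper bound; `interval_excludes_zero` is a hypothetical helper, not part of stats_utils:

```python
def interval_excludes_zero(low: float, high: float) -> bool:
    """Return True when [low, high] lies strictly above or strictly below zero.

    Hypothetical helper mirroring the np.sign comparison used for colouring;
    assumes low <= high.
    """
    return low > 0 or high < 0


# Intervals that straddle or touch zero are not highlighted
print(interval_excludes_zero(0.2, 1.4))    # True
print(interval_excludes_zero(-0.3, 0.8))   # False
print(interval_excludes_zero(-2.0, -0.1))  # True
```

The sign comparison additionally treats a degenerate interval [0, 0] as excluding zero; for every other interval the two tests agree.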

Parameters:

  • results

    (InferenceData) –

    The ArviZ InferenceData object produced by the Bayesian model fit.

  • var_names

    (List[str]) –

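    A list of names of scalar posterior variables to summarize and plot. Vector-valued variables are not supported, because az.summary reports their elements under indexed names such as eta[0].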

  • titles

    (List[str]) –

    A list of x-axis labels, one per variable, in the same order as var_names. Spaces in each title are replaced with line breaks in the tick labels.

  • credible_interval

    (float, default: 0.95 ) –

    The probability mass of the highest-density interval drawn as each error bar (passed to az.summary as hdi_prob). Defaults to 0.95.

  • file_name

    (Optional[str], default: None ) –

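    Optional. The name of the file to save the plot to. The file is written to a figures directory in the current working directory, which is created if it does not exist. If None, the plot is not saved.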

  • **kwargs

    (Any, default: {} ) –

    Additional keyword arguments. In the current implementation they are accepted but not passed to any plotting call.
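
az.summary names its interval columns after hdi_prob: with the default 0.95 they are `hdi_2.5%` and `hdi_97.5%`, while `credible_interval=0.9` gives `hdi_5%` and `hdi_95%`. Code that looks up the bounds by column name therefore has to derive the names from the same probability. A sketch, assuming ArviZ formats the tail percentages with the `:g` format spec; `hdi_column_names` is a hypothetical helper:

```python
from typing import Tuple


def hdi_column_names(hdi_prob: float) -> Tuple[str, str]:
    """Return the (lower, upper) HDI column names for a given interval mass.

    Hypothetical helper; assumes az.summary formats the tail percentages
    with ":g", e.g. 0.95 -> ("hdi_2.5%", "hdi_97.5%").
    """
    alpha = 1 - hdi_prob  # total probability mass outside the interval
    lower = f"hdi_{100 * alpha / 2:g}%"
    upper = f"hdi_{100 * (1 - alpha / 2):g}%"
    return lower, upper


print(hdi_column_names(0.95))  # ('hdi_2.5%', 'hdi_97.5%')
print(hdi_column_names(0.9))   # ('hdi_5%', 'hdi_95%')
```

Inside the plotting loop the bounds can then be read with `summary.loc[v, lower]` and `summary.loc[v, upper]` for any `credible_interval`.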

Returns:

  • None ( None ) –

    This function does not return anything.

Example
plot_posterior_forest(
    results_2,
    ["mu", "tau", "eta"],
    ["Mean", "Standard Deviation", "Noise"],
    file_name="my_model_forest_plot.svg",
)
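
Each error bar spans a highest-density interval (HDI): the narrowest interval that contains the requested share of the posterior draws. A pure-Python sketch of the common sorted-window method for computing it from samples; `hdi_from_samples` is a hypothetical helper, and ArviZ's own implementation may differ in details such as how the window size is rounded:

```python
import math
import random
from typing import List, Tuple


def hdi_from_samples(samples: List[float], prob: float = 0.95) -> Tuple[float, float]:
    """Narrowest interval containing roughly `prob` of the samples.

    Hypothetical sketch of the sorted-window method: slide a window spanning
    floor(prob * n) sorted positions and keep the narrowest one.
    """
    ordered = sorted(samples)
    n = len(ordered)
    width_idx = int(math.floor(prob * n))  # index offset between window ends
    n_windows = n - width_idx
    if n_windows < 1:
        raise ValueError("need more samples for this interval probability")
    # Width of every candidate window [ordered[j], ordered[j + width_idx]]
    widths = [ordered[j + width_idx] - ordered[j] for j in range(n_windows)]
    best = min(range(n_windows), key=widths.__getitem__)
    return ordered[best], ordered[best + width_idx]


# Draws from a standard normal give an interval close to (-1.96, 1.96)
random.seed(0)
draws = [random.gauss(0.0, 1.0) for _ in range(20_000)]
low, high = hdi_from_samples(draws, prob=0.95)
```

For a skewed posterior the HDI is shifted towards the mode, unlike an equal-tailed interval.
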
Source code in stats_utils/bayesian/plotting.py
def plot_posterior_forest(
    results: az.InferenceData,
    var_names: List[str],
    titles: List[str],
    credible_interval: float = 0.95,
    file_name: Optional[str] = None,
    **kwargs
) -> None:
    """
    Creates a forest plot of the Bayesian results with point estimates and
    credible intervals, with variables displayed along the x-axis.

    Args:
        results (az.InferenceData): The results object from the Bayesian model
            fit, typically an `InferenceData` object.
        var_names (List[str]): A list of variable names to summarize and plot.
        titles (List[str]): A list of titles corresponding to the variables.
            Each title corresponds to one variable name.
        credible_interval (float): The credible interval to use for the
            error bar. Defaults to `0.95`.
        file_name (Optional[str]): Name of the file to save the plot to,
            inside a `figures` directory. If `None`, the plot is not saved.
        **kwargs (Any): Accepted but currently not used by this
            function.

    Returns:
        None: This function does not return anything.

    Example:
        ```
        plot_posterior_forest(
            results_2,
            ["mu", "tau", "eta"],
            ["Mean", "Standard Deviation", "Noise"],
            file_name="my_model_forest_plot.svg",
        )
        ```
    """
    # Creating a summary for the specified variables
    summary = az.summary(
        results, var_names=var_names, hdi_prob=credible_interval, kind="stats"
    )

    # Get default colour palette from matplotlib
    palette = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    # Creating a figure
    fig, ax = plt.subplots(figsize=(3, 2.5))

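    # Initializing a colour list to store the colour of each point estimate
    colours = [palette[0]] * len(var_names)

    # az.summary names the HDI columns after hdi_prob, e.g. "hdi_2.5%" and
    # "hdi_97.5%" when credible_interval=0.95
    alpha = 1 - credible_interval
    hdi_low_col = f"hdi_{100 * alpha / 2:g}%"
    hdi_high_col = f"hdi_{100 * (1 - alpha / 2):g}%"

    # Plotting points with error bars for each variable
    for i, v in enumerate(var_names):
        # Extracting the point estimate and the HDI
        mean = summary.loc[v, "mean"]
        hdi_low = summary.loc[v, hdi_low_col]
        hdi_high = summary.loc[v, hdi_high_col]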

        # Get colour
        if np.sign(hdi_low) == np.sign(hdi_high):
            colours[i] = palette[1]

        # Plotting the point estimate
        ax.scatter(
            i,
            mean,
            color=colours[i],
            edgecolors="black",
            linewidths=1,
        )

        # Plotting the error bars (credible interval)
        ax.errorbar(
            x=[i, i],
            y=[hdi_low, hdi_high],
            color=colours[i],
            capsize=0,
            capthick=2,
            linewidth=2,
            ecolor=colours[i],
            zorder=-1,
            alpha=0.8,
        )

    # Setting the x-axis to show the variable titles
    ax.set_xticks(np.arange(len(var_names)))

    # Replace spaces with newlines in titles
    titles = [title.replace(" ", "\n") for title in titles]
    ax.set_xticklabels(titles, rotation=45, ha="right")

    ax.set_ylabel("Parameter estimate")
    ax.axhline(
        0, color="grey", linestyle="--", alpha=0.7
    )  # Add a horizontal line at zero for reference

    sns.despine()
    plt.tight_layout()

    # Save the figure if file_name is provided
    if file_name:
        if not os.path.exists("figures"):
            os.makedirs("figures")
        plt.savefig(os.path.join("figures", file_name), bbox_inches="tight")

    plt.show()

plot_posteriors

plot_posteriors(results: InferenceData, var_names: List[str], titles: List[str], file_name: Optional[str] = None, **kwargs) -> None

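Plots the posterior distribution of each variable as a kernel density estimate in its own panel, with a dashed vertical line at zero. Distributions whose 95% highest-density interval (HDI) excludes zero are drawn in the second colour of the Matplotlib colour cycle; this function always uses a 95% HDI. Optionally saves the figure to a file.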
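
Each panel is drawn with sns.kdeplot, which smooths the pooled posterior draws with a Gaussian kernel density estimate. By default seaborn sets the bandwidth with Scott's rule, as SciPy's gaussian_kde does. A stdlib-only sketch of that estimate, assuming a bandwidth equal to the sample standard deviation times n ** (-1/5); `gaussian_kde_1d` is a hypothetical helper:

```python
import math
import statistics
from typing import Callable, List


def gaussian_kde_1d(samples: List[float]) -> Callable[[float], float]:
    """Return a Gaussian KDE density function using Scott's bandwidth rule.

    Hypothetical sketch: bandwidth h = stdev(samples) * n ** (-1/5),
    where stdev uses the n - 1 denominator. Needs at least two samples.
    """
    n = len(samples)
    h = statistics.stdev(samples) * n ** (-1 / 5)
    norm = 1.0 / (n * h * math.sqrt(2 * math.pi))

    def density(x: float) -> float:
        # Sum one Gaussian bump per sample, each centred on that sample
        return norm * sum(math.exp(-0.5 * ((x - s) / h) ** 2) for s in samples)

    return density


# Evaluate the estimate on a grid, as a plotting function would
draws = [-1.2, -0.4, 0.0, 0.3, 0.9, 1.5]
pdf = gaussian_kde_1d(draws)
grid = [i / 10 for i in range(-50, 51)]
values = [pdf(x) for x in grid]
```

Seaborn additionally multiplies the bandwidth by bw_adjust (default 1) and chooses its own evaluation grid; those details are omitted here.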

Parameters:

  • results

    (InferenceData) –

    The ArviZ InferenceData object produced by the Bayesian model fit.

  • var_names

    (List[str]) –

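    A list of names of scalar posterior variables to summarize and plot. Vector-valued variables are not supported, because az.summary reports their elements under indexed names such as eta[0].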

  • titles

    (List[str]) –

    A list of panel titles, one per variable, in the same order as var_names.

  • file_name

    (Optional[str], default: None ) –

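    Optional. The name of the file to save the plot to. The file is written to a figures directory in the current working directory, which is created if it does not exist. If None, the plot is not saved.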

  • **kwargs

    (Any, default: {} ) –

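    Additional keyword arguments passed to seaborn.kdeplot for every panel (for example linewidth or linestyle). Do not pass color, fill or ax, which the function already sets.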

Returns:

  • None ( None ) –

    This function does not return anything.
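
Both functions save to a figures directory relative to the current working directory and create it when it is missing; Matplotlib infers the file format from the extension of file_name (for example .svg). A sketch of that path handling using os.makedirs with exist_ok=True, which avoids a separate existence check; `figure_path` and its `folder` parameter are hypothetical:

```python
import os


def figure_path(file_name: str, folder: str = "figures") -> str:
    """Return the save path for a figure, creating `folder` if needed.

    Hypothetical helper; the default folder matches the one used above.
    """
    os.makedirs(folder, exist_ok=True)  # no error if the folder already exists
    return os.path.join(folder, file_name)


# Usage with Matplotlib (format inferred from the ".svg" extension):
# plt.savefig(figure_path("my_model_posteriors.svg"), bbox_inches="tight")
```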

Example
plot_posteriors(
    results_2,
    ["mu", "tau", "eta"],
    ["Mean", "Standard Deviation", "Noise"],
    file_name="my_model_posteriors.svg",
    linewidth=2.5,
    linestyle="--",
)
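
Before estimating each density, plot_posteriors passes results.posterior[v].values.flatten() to sns.kdeplot, so the draws from all chains are pooled into a single sample. A stdlib sketch of that pooling on a chains-by-draws nested list; `pool_draws` and the draws below are hypothetical:

```python
from itertools import chain
from typing import List


def pool_draws(posterior: List[List[float]]) -> List[float]:
    """Flatten a (chain, draw) nested list into one list of draws.

    Hypothetical helper mirroring the row-major order of ndarray.flatten():
    all draws of chain 0, then all draws of chain 1, and so on.
    """
    return list(chain.from_iterable(posterior))


# Hypothetical posterior for one scalar variable: 2 chains x 3 draws
mu_draws = [[0.9, 1.1, 1.0], [1.2, 0.8, 1.05]]
pooled = pool_draws(mu_draws)
print(pooled)  # [0.9, 1.1, 1.0, 1.2, 0.8, 1.05]
```

Pooling is appropriate once the chains have converged; it discards which chain each draw came from.
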
Source code in stats_utils/bayesian/plotting.py
def plot_posteriors(
    results: az.InferenceData,
    var_names: List[str],
    titles: List[str],
    file_name: Optional[str] = None,
    **kwargs
) -> None:
    """
    Creates a summary of the Bayesian results and plots the posterior
    distributions with optional keyword arguments for plotting functions.
    Optionally saves the figure to a file.

    Args:
        results (az.InferenceData): The results object from the Bayesian
            model fit, typically an Arviz `InferenceData` object.
        var_names (List[str]): A list of variable names to summarize and plot.
        titles (List[str]): A list of titles corresponding to the variables.
            Each title corresponds to one variable name.
        file_name (Optional[str]): Name of the file to save the plot to,
            inside a `figures` directory. If `None`, the plot is not saved.
        **kwargs (Any): Optional keyword arguments passed to
            `seaborn.kdeplot`.

    Returns:
        None: This function does not return anything.

    Example:
        ```
        plot_posteriors(
            results_2,
            ["mu", "tau", "eta"],
            ["Mean", "Standard Deviation", "Noise"],
            file_name="my_model_posteriors.svg",
            linewidth=2.5,
            linestyle="--",
        )
        ```
    """

    # Creating a summary for the specified variables
    summary = az.summary(
        results,
        var_names=var_names,
        hdi_prob=0.95,
    )

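    # Creating a figure with one panel per variable; squeeze=False keeps
    # ax indexable even when only one variable is plotted
    f, ax = plt.subplots(
        1, len(var_names), figsize=(len(var_names) * 2, 2.2), squeeze=False
    )
    ax = ax[0]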

    # Retrieving the default colour palette from matplotlib
    palette = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    # Initializing a colour list to store colors for each distribution plot
    colours = [palette[0]] * len(var_names)

    # Determine the colour of the distribution based on 95% HDI
    for i, v in enumerate(var_names):
        var_summary = summary.loc[v]
        if np.sign(var_summary["hdi_2.5%"]) == np.sign(
            var_summary["hdi_97.5%"]
        ):
            colours[i] = palette[1]

    # Plotting distributions for each variable
    for i, v in enumerate(var_names):
        ax[i].set_xlabel(r"$\beta$ value")  # Label x-axis
        ax[i].axvline(
            0, color="black", linestyle="--", alpha=0.5
        )  # Add a vertical line at x=0
        sns.kdeplot(
            results.posterior[v].values.flatten(),
            fill=True,
            ax=ax[i],
            color=colours[i],
            **kwargs  # Pass additional kwargs to seaborn kdeplot
        )
        ax[i].set_ylabel("" if i > 0 else "Posterior\ndensity")
        ax[i].set_title(titles[i])  # Title

    sns.despine()

    plt.tight_layout()

    # Save the figure if file_name is provided
    if file_name:
        figure_folder = "figures"
        if not os.path.exists(figure_folder):
            os.makedirs(figure_folder)
        plt.savefig(os.path.join(figure_folder, file_name))