Skip to content

plot

add_cbar(ax, cmap, norm, **kwargs)

Truncate or expand cmap such that it covers the axes limits, and add a colorbar to the axes

Source code in pyhdx/plot.py
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
def add_cbar(ax, cmap, norm, **kwargs):
    """Truncate or expand cmap such that it covers the axes limits and add a colorbar to the axes.

    The colormap is resampled at ``cmap.N`` evenly spaced values between the
    lower and upper y-limit of ``ax``. Values falling outside the range of
    ``norm`` are clipped, so they take the color at the nearest end of the
    colormap.

    Args:
        ax: Axes whose y-limits define the colorbar range and to which the
            colorbar is added.
        cmap: Colormap to resample.
        norm: Norm mapping y-values to the colormap; a shallow copy with
            ``clip = True`` is used, the passed object is left untouched.
        **kwargs: Extra colorbar keyword arguments; these take precedence over
            the entries in ``CBAR_KWARGS``.

    Returns:
        The object returned by ``ax.colorbar``.
    """

    y_limits = ax.get_ylim()
    lower, upper = np.min(y_limits), np.max(y_limits)

    # Clip on a copy so the caller's norm keeps its own `clip` setting.
    clipped_norm = copy(norm)
    clipped_norm.clip = True
    sampled_colors = cmap(clipped_norm(np.linspace(lower, upper, num=cmap.N)))

    # Rebuild the colormap in the same flavour as the one passed in.
    listmode = "perceptual"
    if isinstance(cmap, pplt.DiscreteColormap):
        listmode = "discrete"
    elif isinstance(cmap, pplt.ContinuousColormap):
        listmode = "continuous"

    # A descending y-axis (first limit larger than second) flips the colorbar.
    is_reversed = np.diff(y_limits) < 0

    return ax.colorbar(
        pplt.Colormap(sampled_colors, listmode=listmode),
        norm=pplt.Norm("linear", vmin=lower, vmax=upper),  # todo allow log norms?
        reverse=is_reversed,
        **{**CBAR_KWARGS, **kwargs},
    )

add_mse_panels(axes, fit_result, cmap=None, norm=None, panel_kwargs=None, fig=None, cbar=False, cbar_kwargs=None)

adds linear bar panels to axes showing some property (usually MSE)

Source code in pyhdx/plot.py
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
def add_mse_panels(
    axes,
    fit_result,
    cmap=None,
    norm=None,
    panel_kwargs=None,
    fig=None,
    cbar=False,
    cbar_kwargs=None,
):
    """Add linear bar panels to axes showing some property (usually MSE).

    One panel is attached to every axes that can be paired with an entry of
    ``fit_result.hdxm_set``; each panel shows the per-residue values from
    ``fit_result.get_residue_mse()`` for that entry.

    Args:
        axes: Iterable of axes, paired in order with ``fit_result.hdxm_set``.
        fit_result: Object providing ``get_residue_mse()`` and ``hdxm_set``.
        cmap: Colormap for the bars. Defaults to ``CMAP_NORM_DEFAULTS.cmaps["mse"]``.
        norm: Norm for the bars. Defaults to a linear norm from 0 to the
            largest residue MSE value.
        panel_kwargs: Optional keyword arguments for ``ax.panel_axes``. The
            ``"loc"`` entry (default ``"b"``) selects the panel location;
            ``"width"`` and ``"space"`` default to ``"5mm"`` and ``0``.
            The passed dict is not modified.
        fig: Figure to which the global colorbar is added; required when
            ``cbar`` is true.
        cbar: Whether to add a global colorbar to ``fig``.
        cbar_kwargs: Optional keyword arguments for ``fig.colorbar``,
            overriding the defaults used here.

    Returns:
        Tuple ``(collections, cbar)``: the list of objects returned by
        ``single_linear_bar`` (one per panel) and the colorbar, or ``None``
        when ``cbar`` is false.

    Raises:
        ValueError: If ``cbar`` is true and ``fig`` is ``None``.
    """

    residue_mse = fit_result.get_residue_mse()
    vmax = residue_mse.to_numpy().max()

    cmap = cmap or CMAP_NORM_DEFAULTS.cmaps["mse"]
    norm = norm or pplt.Norm("linear", vmin=0, vmax=vmax)

    # Resolve the panel options once, on a copy: popping "loc" inside the loop
    # would apply the requested location to the first panel only and would
    # also mutate the caller's dict.
    panel_options = {"width": "5mm", "space": 0, **(panel_kwargs or {})}
    loc = panel_options.pop("loc", "b")

    collections = []
    for hdxm, ax in zip(fit_result.hdxm_set, axes):
        pn = ax.panel_axes(loc, **panel_options)
        residue_values = residue_mse[hdxm.name]

        collection = single_linear_bar(
            pn,
            hdxm.coverage.r_number,
            residue_values.to_numpy().squeeze(),
            cmap=cmap,
            norm=norm,
        )
        collections.append(collection)

    if cbar:
        if fig is None:
            raise ValueError("Must pass 'fig' keyword argument to add a global colorbar")
        cbar_kwargs = cbar_kwargs or {}
        cbar_kwargs = {
            "width": CBAR_KWARGS["width"],
            "loc": "b",
            "length": 0.3,
            **cbar_kwargs,
        }

        cbar = fig.colorbar(cmap, norm=norm, **cbar_kwargs)
    else:
        cbar = None

    return collections, cbar

linear_bars(data, norm, cmap, cbar_sclf=1.0, **figure_kwargs)

Generate a linear bar plot with multiple subplots.

Parameters:

Name Type Description Default
data dict[str, dict[str, Series]]

A dictionary containing the data for each subplot. The keys are the top-level labels, and the values are dictionaries containing the data for each subplot. The data is represented as a pandas Series object.

required
norm

A normalization object to be applied to the color scale.

required
cmap

A colormap object to be used for coloring the bars.

required
cbar_sclf

A scaling factor to be applied to the color scale. Default is 1.0.

1.0
**figure_kwargs

Additional keyword arguments to be passed to the figure.

{}

Returns:

Name Type Description
fig

The figure object containing the plot.

axes

The axes object containing the subplots.

cbar

The colorbar object.

Source code in pyhdx/plot.py
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
def linear_bars(
    data: dict[str, dict[str, pd.Series]],
    norm,
    cmap,
    cbar_sclf=1.0,
    **figure_kwargs,
):
    """
    Generate a linear bar plot with multiple subplots.

    Every series gets its own single-row subplot. Subplots belonging to the
    same top-level key are stacked without vertical space; the first subplot
    of each group is titled with the top-level key and every subplot is
    labelled with its own key to the right of the axes.

    Args:
        data: A dictionary containing the data for each subplot. The keys are the top-level labels, and the values are dictionaries containing the data for each subplot. The data is represented as a pandas Series object.
        norm: A normalization object to be applied to the color scale. Only its `vmin` and `vmax` are used.
        cmap: A colormap object to be used for coloring the bars.
        cbar_sclf: A scaling factor applied to `vmin`/`vmax` of the colorbar norm only (the bars themselves use the unscaled limits). Default is 1.0.
        **figure_kwargs: Additional keyword arguments to be passed to the figure. `width`, `refaspect` and `cbar_width` are consumed here (falling back to the values in `cfg.plotting`); everything else is forwarded to `pplt.subplots`.

    Returns:
        fig: The figure object containing the plot.
        axes: The axes object containing the subplots.
        cbar: The colorbar object.
    """

    # One entry per gap between consecutive rows: 0 inside a group, None
    # (left for the subplots call to decide) between groups.
    hspace = [elem for v in data.values() for elem in [0] * (len(v) - 1) + [None]][:-1]
    ncols = 1
    nrows = len(hspace) + 1

    # Sizes are divided by 25.4 (presumably mm -> inch; confirm against cfg units).
    figure_width = figure_kwargs.pop("width", cfg.plotting.page_width) / 25.4
    refaspect = figure_kwargs.pop("refaspect", cfg.plotting.linear_bars_aspect)
    cbar_width = figure_kwargs.pop("cbar_width", cfg.plotting.cbar_width) / 25.4

    # Forward the remaining keyword arguments as documented; previously they
    # were silently discarded.
    fig, axes = pplt.subplots(
        nrows=nrows,
        ncols=ncols,
        refaspect=refaspect,
        width=figure_width,
        hspace=hspace,
        **figure_kwargs,
    )
    axes_iter = iter(axes)
    y_edges = [0, 1]
    for top_level, subdict in data.items():
        for i, (label, values) in enumerate(subdict.items()):
            ax = next(axes_iter)
            # Cell edges sit half a unit on either side of each index value.
            rmin, rmax = values.index.min(), values.index.max()
            r_edges = pplt.arange(rmin - 0.5, rmax + 0.5, 1)
            ax.pcolormesh(
                r_edges,
                y_edges,
                values.to_numpy().reshape(1, -1),
                cmap=cmap,
                vmin=norm.vmin,
                vmax=norm.vmax,
                levels=256,
            )
            ax.format(yticks=[])

            # Series label just outside the right edge of the axes.
            ax.text(
                1.02,
                0.5,
                label,
                horizontalalignment="left",
                verticalalignment="center",
                transform=ax.transAxes,
            )

            if i == 0:
                ax.format(title=top_level)

    axes.format(xlabel=r_xlabel)

    # Scale the colorbar limits on a copy so the caller's norm is untouched.
    cmap_norm = copy(norm)
    cmap_norm.vmin *= cbar_sclf
    cmap_norm.vmax *= cbar_sclf

    cbar = fig.colorbar(cmap, norm=cmap_norm, loc="b", width=cbar_width)

    return fig, axes, cbar

linear_bars_figure(data, reference=None, groupby=None, field='dG', norm=None, cmap=None, **figure_kwargs)

Generate a linear bars figure based on the provided data.

Parameters:

Name Type Description Default
data DataFrame

A pandas DataFrame containing the data to be plotted.

required
reference Optional[str]

An optional string representing the reference value for subtraction.

None
groupby Optional[str]

An optional string representing the column to group the data by.

None
field str

A string representing the field to be plotted. Default is "dG".

'dG'
norm

An optional normalization function.

None
cmap

An optional colormap.

None
**figure_kwargs

Additional keyword arguments to be passed to the figure.

{}

Returns:

Name Type Description
fig

The generated figure.

axes

The axes of the figure.

cbar

The colorbar of the figure.

Source code in pyhdx/plot.py
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
def linear_bars_figure(
    data: pd.DataFrame,
    reference: Optional[str] = None,
    groupby: Optional[str] = None,
    field: str = "dG",
    norm=None,
    cmap=None,
    **figure_kwargs,
):
    """
    Generate a linear bars figure based on the provided data.

    The columns of `data` matching `field` in the last column level are
    selected and nested into groups and bars, optionally differenced against
    a reference bar, and handed to `linear_bars`.

    Args:
        data: A pandas DataFrame containing the data to be plotted.
        reference: Optional name of the bar whose values are subtracted from all other bars within each group. The reference bar itself is removed from the plot.
        groupby: Optional name of the column level to group the data by. When omitted, the first column level is used for grouping.
        field: A string representing the field to be plotted. Default is "dG".
        norm: An optional norm; defaults to the entry in `CMAP_NORM_DEFAULTS` for "dG"/"ddG"/"rfu"/"drfu", chosen from `field` and whether a reference is given.
        cmap: An optional colormap; defaulted in the same way as `norm`.
        **figure_kwargs: Additional keyword arguments passed on to `linear_bars`.

    Returns:
        fig: The generated figure.
        axes: The axes of the figure.
        cbar: The colorbar of the figure.

    Raises:
        ValueError: If no colormap or no norm is given and no default exists for `field`.
    """

    # Differential plots (a reference is subtracted) use their own defaults.
    is_differential = reference is not None

    if field == "dG":
        cmap_default, norm_default = CMAP_NORM_DEFAULTS["ddG" if is_differential else "dG"]
        ylabel = ddG_ylabel if is_differential else dG_ylabel
        sclf = 1e-3
    elif field == "rfu":
        cmap_default, norm_default = CMAP_NORM_DEFAULTS["drfu" if is_differential else "rfu"]
        ylabel = "ΔRFU" if is_differential else "RFU"
        sclf = 1.0
    else:
        cmap_default = norm_default = None
        ylabel = ""
        sclf = 1.0

    cmap = cmap or cmap_default
    norm = norm or norm_default

    if cmap is None:
        raise ValueError("No valid Colormap found")
    if norm is None:
        raise ValueError("No valid Norm found")

    reduced = data.xs(level=-1, key=field, axis=1)

    # The level not used for grouping supplies the individual bars.
    grp_level = reduced.columns.names.index(groupby) if groupby else 0
    bar_level = 1 - grp_level

    # nest the individual pandas series in a dict according to grp / bar level
    result = defaultdict(dict)
    column_keys = reduced.columns.to_flat_index().tolist()
    for key, column in zip(column_keys, reduced.columns):
        result[key[grp_level]][key[bar_level]] = reduced[column]

    # subtract reference values if given
    if reference is not None:
        for bars in result.values():
            ref_values = bars.pop(reference)
            bars.update({name: values - ref_values for name, values in bars.items()})

    fig, axes, cbar = linear_bars(
        result,
        norm=norm,
        cmap=cmap,
        cbar_sclf=sclf,
        **figure_kwargs,
    )
    cbar.set_label(ylabel)

    return fig, axes, cbar

plot_fitresults(fitresult_path, reference=None, plots='all', renew=False, cmap_and_norm=None, output_path=None, output_type='.png', **save_kwargs)

Parameters

fitresult_path, plots, renew: not documented. cmap_and_norm (dict, optional): Dictionary with cmap and norms to use. If None, reverts to defaults. Dict format: {'dG': (cmap, norm), 'ddG': (cmap, norm)}

output_type: list or str

Returns

Source code in pyhdx/plot.py
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
def plot_fitresults(
    fitresult_path,
    reference=None,
    plots="all",
    renew=False,
    cmap_and_norm=None,
    output_path=None,
    output_type=".png",
    **save_kwargs,
):
    """
    Deprecated: always raises; use ``FitResultPlot.plot_all`` instead.

    The signature is retained for backward compatibility, but none of the
    arguments are used.

    Parameters
    ----------
    fitresult_path
        Unused.
    reference
        Unused.
    plots
        Unused.
    renew
        Unused.
    cmap_and_norm: :obj:`dict`, optional
        Unused. Formerly a dictionary with cmap and norms to use, with
        format: {'dG': (cmap, norm), 'ddG': (cmap, norm)}
    output_path
        Unused.
    output_type: list or str
        Unused.
    **save_kwargs
        Unused.

    Raises
    ------
    DeprecationWarning
        Always.

    """

    # The former implementation that followed this statement was unreachable
    # and has been removed.
    raise DeprecationWarning("This function is deprecated, use FitResultPlot.plot_all instead")

residue_time_scatter_figure(hdxm, field='rfu', cmap='turbo', norm=None, scatter_kwargs=None, cbar_kwargs=None, **figure_kwargs)

Per-residue, per-exposure values for the field `field`, obtained by weighted averaging

Source code in pyhdx/plot.py
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
def residue_time_scatter_figure(
    hdxm,
    field="rfu",
    cmap="turbo",
    norm=None,
    scatter_kwargs=None,
    cbar_kwargs=None,
    **figure_kwargs,
):
    """Per-residue, per-exposure values for field `field` by weighted averaging.

    Creates a grid with one subplot per item obtained by iterating ``hdxm``
    (``hdxm.Nt`` subplots are planned for), draws each via
    ``residue_time_scatter`` and adds a colorbar to every used subplot.
    Surplus subplots in the grid are switched off.

    Args:
        hdxm: Iterable measurement object exposing ``Nt``; each item must
            expose ``exposure`` (used in the subplot title).
        field: Name of the field to plot; also used as y-axis label.
        cmap: Colormap specification passed to ``pplt.Colormap``.
        norm: Norm for the colors. Defaults to a linear norm from 0 to 1.
        scatter_kwargs: Optional keyword arguments for ``residue_time_scatter``.
        cbar_kwargs: Optional keyword arguments for ``add_cbar``.
        **figure_kwargs: ``ncols``, ``nrows``, ``width``, ``refaspect`` and
            ``cbar_width`` are consumed here (with defaults from
            ``cfg.plotting``); the rest is forwarded to ``pplt.subplots``.

    Returns:
        Tuple ``(fig, axes, cbars)`` where ``cbars`` is the list of colorbars
        of the subplots that are in use.
    """

    n_panels = hdxm.Nt
    ncols = figure_kwargs.pop("ncols", min(cfg.plotting.ncols, n_panels))
    nrows = figure_kwargs.pop("nrows", int(np.ceil(n_panels / ncols)))
    width_inches = figure_kwargs.pop("width", cfg.plotting.page_width) / 25.4
    aspect = figure_kwargs.pop("refaspect", cfg.plotting.residue_scatter_aspect)
    # Popped so it is not forwarded to `pplt.subplots`.
    # NOTE(review): the resulting value is not used anywhere below - confirm
    # whether it was meant to set the colorbar width.
    _unused_cbar_width = figure_kwargs.pop("cbar_width", cfg.plotting.cbar_width) / 25.4

    cmap = pplt.Colormap(cmap)  # todo allow None as cmap
    norm = norm or pplt.Norm("linear", vmin=0, vmax=1)
    scatter_options = scatter_kwargs or {}
    cbar_options = cbar_kwargs or {}

    fig, axes = pplt.subplots(
        ncols=ncols,
        nrows=nrows,
        width=width_inches,
        refaspect=aspect,
        sharey=4,
        **figure_kwargs,
    )

    remaining_axes = iter(axes)
    for timepoint in hdxm:
        panel = next(remaining_axes)
        # Colorbars are added afterwards, once axes formatting is done.
        residue_time_scatter(
            panel, timepoint, field=field, cmap=cmap, norm=norm, cbar=False, **scatter_options
        )  # todo cbar kwargs? (check with other methods)
        panel.format(title=f"exposure: {timepoint.exposure:.1f}")

    # Whatever is left in the grid has no data: hide it.
    for spare in remaining_axes:
        spare.axis("off")

    axes.format(xlabel=r_xlabel, ylabel=field)

    # Hidden axes are skipped; they have `axison` set to False.
    cbars = [add_cbar(panel, cmap, norm, **cbar_options) for panel in axes if panel.axison]

    return fig, axes, cbars

single_linear_bar(ax, x, z, cmap, norm, **kwargs)

makes a linear bar plot on supplied axis with values z and corresponding x values x

Source code in pyhdx/plot.py
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
def single_linear_bar(ax, x, z, cmap, norm, **kwargs):
    """Draw a one-row colored bar on the supplied axis with values `z` at positions `x`.

    Args:
        ax: Axes to draw on; its y-ticks are removed.
        x: Positions of the values; cell edges are derived via ``pplt.edges``.
        z: Values to color by. A pandas Series or single-column DataFrame is
            converted to a numpy array first.
        cmap: Colormap for the bar.
        norm: Norm whose ``vmin`` and ``vmax`` set the color limits.
        **kwargs: Additional keyword arguments for ``ax.pcolormesh``.

    Returns:
        The object returned by ``ax.pcolormesh``.
    """

    if isinstance(z, pd.DataFrame):
        assert len(z.columns) == 1, "Can only plot dataframes with 1 column"
        z = z.to_numpy().squeeze()
    elif isinstance(z, pd.Series):
        z = z.to_numpy()

    # pcolormesh wants 2D color data: a single row spanning y = 0..1.
    row = np.expand_dims(z, axis=0)
    x_edges = pplt.edges(x)
    y_edges = np.array([0, 1])

    mesh = ax.pcolormesh(
        x_edges,
        y_edges,
        row,
        cmap=cmap,
        vmin=norm.vmin,
        vmax=norm.vmax,
        levels=256,
        **kwargs,
    )
    ax.format(yticks=[])

    return mesh