1149 lines
39 KiB
Python
1149 lines
39 KiB
Python
# This code is licensed under the Apache License, Version 2.0. You may
|
|
# obtain a copy of this license in the LICENSE.txt file in the root directory
|
|
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
|
|
#
|
|
# Any modifications or derivative works of this code must retain this
|
|
# copyright notice, and modified files need to carry a notice indicating
|
|
# that they have been altered from the originals.
|
|
#
|
|
# NetworkX is distributed with the 3-clause BSD license.
|
|
#
|
|
# Copyright (C) 2004-2020, NetworkX Developers
|
|
# Aric Hagberg <hagberg@lanl.gov>
|
|
# Dan Schult <dschult@colgate.edu>
|
|
# Pieter Swart <swart@lanl.gov>
|
|
# All rights reserved.
|
|
#
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are
|
|
# met:
|
|
#
|
|
# * Redistributions of source code must retain the above copyright
|
|
# notice, this list of conditions and the following disclaimer.
|
|
#
|
|
# * Redistributions in binary form must reproduce the above
|
|
# copyright notice, this list of conditions and the following
|
|
# disclaimer in the documentation and/or other materials provided
|
|
# with the distribution.
|
|
#
|
|
# * Neither the name of the NetworkX Developers nor the names of its
|
|
# contributors may be used to endorse or promote products derived
|
|
# from this software without specific prior written permission.
|
|
#
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
|
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
|
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
|
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
|
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
|
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
|
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
|
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
|
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
#
|
|
# This code is forked from networkx's networkx_pylab.py module and adapted to
|
|
# work with rustworkx instead. The original source can be found at:
|
|
#
|
|
# https://github.com/networkx/networkx/blob/80b1afa2ae50314a8312998c214a8c1a356adcf1/networkx/drawing/nx_pylab.py
|
|
|
|
"""Draw a rustworkx graph with matplotlib."""
|
|
|
|
from collections.abc import Iterable
|
|
from itertools import islice, cycle
|
|
from numbers import Number
|
|
|
|
import numpy as np
|
|
|
|
import rustworkx
|
|
|
|
|
|
# Only the top-level drawing entry point is exported by ``from ... import *``;
# the lower-level draw_* helpers remain importable by explicit name.
__all__ = ["mpl_draw"]
|
|
|
|
|
|
def mpl_draw(graph, pos=None, ax=None, arrows=True, with_labels=False, **kwds):
    r"""Draw a graph with Matplotlib.

    .. note::

        Matplotlib is an optional dependency and will not be installed with
        rustworkx by default. If you intend to use this function make sure that
        you install matplotlib with either ``pip install matplotlib`` or
        ``pip install 'rustworkx[mpl]'``

    :param graph: A rustworkx graph, either a :class:`~rustworkx.PyGraph` or a
        :class:`~rustworkx.PyDiGraph`.
    :param dict pos: An optional dictionary (or
        a :class:`~rustworkx.Pos2DMapping` object) with nodes as keys and
        positions as values. If not specified a spring layout positioning will
        be computed. See `layout_functions` for functions that compute
        node positions.
    :param matplotlib.Axes ax: An optional Matplotlib Axes object to draw the
        graph in.
    :param bool arrows: For :class:`~rustworkx.PyDiGraph` objects if ``True``
        draw arrowheads. (defaults to ``True``) Note that the arrows will
        be the same color as the edges.
    :param str arrowstyle: An optional string for directed graphs to choose
        the style of the arrowheads. See
        :class:`matplotlib.patches.ArrowStyle` for more options. By default the
        value is set to ``'-\|>'``.
    :param int arrow_size: For directed graphs, choose the size of the arrow
        head's length and width. See
        :class:`matplotlib.patches.FancyArrowPatch` attribute and constructor
        kwarg ``mutation_scale`` for more info. Defaults to 10.
    :param bool with_labels: Set to ``True`` to draw labels on the nodes. Edge
        labels will only be drawn if the ``edge_labels`` parameter is set to a
        function. Defaults to ``False``.
    :param list node_list: An optional list of node indices in the graph to
        draw. If not specified all nodes will be drawn.
    :param list edge_list: An optional list of edges in the graph to draw. If
        not specified all edges will be drawn.
    :param int|list node_size: Optional size of nodes. If an array is
        specified it must be the same length as node_list. Defaults to 300
    :param node_color: Optional node color. Can be a single color or
        a sequence of colors with the same length as node_list. Color can be
        string or rgb (or rgba) tuple of floats from 0-1. If numeric values
        are specified they will be mapped to colors using the ``cmap`` and
        ``vmin``,``vmax`` parameters. See :func:`matplotlib.scatter` for more
        details. Defaults to ``'#1f78b4'``
    :param str node_shape: The optional shape node. The specification is the
        same as the :func:`matplotlib.pyplot.scatter` function's ``marker``
        kwarg, valid options are one of
        ``['s', 'o', '^', '>', 'v', '<', 'd', 'p', 'h', '8']``. Defaults to
        ``'o'``
    :param float alpha: Optional value for node and edge transparency
    :param matplotlib.colors.Colormap cmap: An optional Matplotlib colormap
        object for mapping intensities of nodes
    :param float vmin: Optional minimum value for node colormap scaling
    :param float vmax: Optional maximum value for node colormap scaling
    :param float|sequence linewidths: An optional line width for symbol
        borders. If a sequence is specified it must be the same length as
        node_list. Defaults to 1.0
    :param edgecolors: Optional color(s) of the node borders.
    :param float|sequence width: An optional width to use for edges. Can
        either be a float or sequence of floats. If a sequence is specified
        it must be the same length as edge_list. Defaults to 1.0
    :param str|sequence edge_color: color or array of colors (default='k')
        Edge color. Can be a single color or a sequence of colors with the same
        length as edge_list. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the ``edge_cmap`` and ``edge_vmin``,
        ``edge_vmax`` parameters.
    :param matplotlib.colors.Colormap edge_cmap: An optional Matplotlib
        colormap for mapping intensities of edges.
    :param float edge_vmin: Optional minimum value for edge colormap scaling
    :param float edge_vmax: Optional maximum value for edge colormap scaling
    :param str style: An optional string to specify the edge line style.
        For example, ``'-'``, ``'--'``, ``'-.'``, ``':'`` or words like
        ``'solid'`` or ``'dashed'``. See the
        :class:`matplotlib.patches.FancyArrowPatch` attribute and kwarg
        ``linestyle`` for more details. Defaults to ``'solid'``.
    :param int min_source_margin: The minimum gap at the start of each edge
        at its source node. Defaults to 0.
    :param int min_target_margin: The minimum gap at the end of each edge at
        its target node. Defaults to 0.
    :param func labels: An optional callback function that will be passed a
        node payload and return a string label for the node. For example::

            labels=str

        could be used to just return a string cast of the node's data payload.
        Or something like::

            labels=lambda node: node['label']

        could be used if the node payloads are dictionaries.
    :param func edge_labels: An optional callback function that will be passed
        an edge payload and return a string label for the edge. For example::

            edge_labels=str

        could be used to just return a string cast of the edge's data payload.
        Or something like::

            edge_labels=lambda edge: edge['label']

        could be used if the edge payloads are dictionaries. If this is set
        edge labels will be drawn in the visualization.
    :param bool rotate: Whether edge labels are rotated to lie parallel to
        their edges. Defaults to ``True``.
    :param int font_size: An optional fontsize to use for text labels, By
        default a value of 12 is used for nodes and 10 for edges.
    :param str font_color: An optional font color for strings. By default
        ``'k'`` (ie black) is set.
    :param str font_weight: An optional string used to specify the font weight.
        By default a value of ``'normal'`` is used.
    :param str font_family: An optional font family to use for strings. By
        default ``'sans-serif'`` is used.
    :param str label: An optional string label to use for the graph legend.
    :param str connectionstyle: An optional value used to create a curved arc
        of rounding radius rad. For example,
        ``connectionstyle='arc3,rad=0.2'``. See
        :class:`matplotlib.patches.ConnectionStyle` and
        :class:`matplotlib.patches.FancyArrowPatch` for more info. By default
        this is set to ``"arc3"``.

    :returns: The matplotlib figure for the visualization when Matplotlib is
        not in interactive mode (see :func:`matplotlib.pyplot.isinteractive`);
        ``None`` when it is (e.g. in jupyter).
    :rtype: matplotlib.figure.Figure

    For Example:

    .. jupyter-execute::

        import matplotlib.pyplot as plt

        import rustworkx as rx
        from rustworkx.visualization import mpl_draw

        G = rx.generators.directed_path_graph(25)
        mpl_draw(G)
        plt.draw()
    """
    try:
        import matplotlib.pyplot as plt  # type: ignore
    except ImportError as e:
        raise ImportError(
            "matplotlib needs to be installed prior to running "
            "rustworkx.visualization.mpl_draw(). You can install "
            "matplotlib with:\n'pip install matplotlib'"
        ) from e

    if ax is None:
        cf = plt.gcf()
    else:
        cf = ax.get_figure()
    cf.set_facecolor("w")
    if ax is None:
        # Reuse the figure's current axes when it has any; otherwise add a
        # new axes spanning the whole figure.
        ax = cf.gca() if cf.axes else cf.add_axes((0, 0, 1, 1))

    draw_graph(graph, pos=pos, ax=ax, arrows=arrows, with_labels=with_labels, **kwds)
    ax.set_axis_off()
    plt.draw_if_interactive()
    # ``ax`` is always set at this point, so only the interactive mode decides
    # whether the figure is returned. NOTE(review): presumably interactive
    # front ends (e.g. Jupyter) display the figure automatically and
    # returning it as well would render it twice.
    if not plt.isinteractive():
        return cf
    return None
|
|
|
|
|
|
def draw_graph(graph, pos=None, arrows=True, with_labels=False, **kwds):
    r"""Draw the graph using Matplotlib.

    Draw the graph with Matplotlib with options for node positions,
    labeling, titles, and many other drawing features. This is the
    implementation used by :func:`mpl_draw`, which additionally prepares the
    figure and axes before calling it.

    Parameters
    ----------
    graph: A rustworkx :class:`~rustworkx.PyDiGraph` or
        :class:`~rustworkx.PyGraph`

    pos : dictionary, optional
        A dictionary with nodes as keys and positions as values.
        If not specified a spring layout positioning will be computed with
        :func:`rustworkx.spring_layout`.

    arrows : bool, optional (default=True)
        Forwarded to :func:`draw_edges`; for directed graphs it controls
        whether arrowheads are drawn by default.

    with_labels : bool, optional (default=False)
        If ``True`` node labels are drawn with :func:`draw_labels`.

    **kwds
        Keyword arguments dispatched to :func:`draw_nodes`,
        :func:`draw_edges`, :func:`draw_labels` and :func:`draw_edge_labels`
        (each function only receives the keywords it accepts). ``labels`` and
        ``edge_labels`` are callbacks applied to node and edge payloads
        respectively; passing ``edge_labels`` makes edge labels be drawn.
        ``ax`` selects the Matplotlib axes to draw in.

    Raises
    ------
    ValueError
        If ``kwds`` contains a keyword that none of the drawing functions
        accept.

    Notes
    -----
    For directed graphs, arrows are drawn at the head end. Arrows can be
    turned off with keyword arrows=False.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as e:
        raise ImportError(
            "matplotlib needs to be installed prior to running "
            "rustworkx.visualization.mpl_draw(). You can install "
            "matplotlib with:\n'pip install matplotlib'"
        ) from e

    valid_node_kwds = {
        "node_list",
        "node_size",
        "node_color",
        "node_shape",
        "alpha",
        "cmap",
        "vmin",
        "vmax",
        "ax",
        "linewidths",
        "edgecolors",
        "label",
    }

    valid_edge_kwds = {
        "edge_list",
        "width",
        "edge_color",
        "style",
        "alpha",
        "arrowstyle",
        "arrow_size",
        "edge_cmap",
        "edge_vmin",
        "edge_vmax",
        "ax",
        "label",
        "node_size",
        "node_list",
        "node_shape",
        "connectionstyle",
        "min_source_margin",
        "min_target_margin",
    }

    valid_label_kwds = {
        "labels",
        "font_size",
        "font_color",
        "font_family",
        "font_weight",
        "alpha",
        "bbox",
        "ax",
        "horizontalalignment",
        "verticalalignment",
    }

    valid_edge_label_kwds = {
        "edge_labels",
        "font_size",
        "font_color",
        "font_family",
        "font_weight",
        "alpha",
        "bbox",
        "ax",
        "rotate",
        "horizontalalignment",
        "verticalalignment",
    }

    valid_kwds = valid_node_kwds | valid_edge_kwds | valid_label_kwds | valid_edge_label_kwds

    # Collect the offending keywords once and report them all together.
    invalid_args = [k for k in kwds if k not in valid_kwds]
    if invalid_args:
        raise ValueError(f"Received invalid argument(s): {', '.join(invalid_args)}")

    # The public API takes label *callbacks*; the draw_* helpers expect
    # dictionaries, so evaluate the callbacks against the graph payloads here.
    label_fn = kwds.pop("labels", None)
    if label_fn:
        kwds["labels"] = {x: label_fn(graph[x]) for x in graph.node_indices()}
    edge_label_fn = kwds.pop("edge_labels", None)
    if edge_label_fn:
        # Keyed by (source, target): for parallel edges the last one listed by
        # weighted_edge_list() wins.
        kwds["edge_labels"] = {
            (x[0], x[1]): edge_label_fn(x[2]) for x in graph.weighted_edge_list()
        }

    node_kwds = {k: v for k, v in kwds.items() if k in valid_node_kwds}
    edge_kwds = {k: v for k, v in kwds.items() if k in valid_edge_kwds}
    # A list of alphas is applied per node by draw_nodes; don't forward it to
    # the edges.
    if isinstance(edge_kwds.get("alpha"), list):
        del edge_kwds["alpha"]
    label_kwds = {k: v for k, v in kwds.items() if k in valid_label_kwds}
    edge_label_kwds = {k: v for k, v in kwds.items() if k in valid_edge_label_kwds}

    if pos is None:
        pos = rustworkx.spring_layout(graph)  # default to spring layout

    draw_nodes(graph, pos, **node_kwds)
    draw_edges(graph, pos, arrows=arrows, **edge_kwds)
    if with_labels:
        draw_labels(graph, pos, **label_kwds)
    if edge_label_fn:
        draw_edge_labels(graph, pos, **edge_label_kwds)
    plt.draw_if_interactive()
|
|
|
|
|
|
def draw_nodes(
    graph,
    pos,
    node_list=None,
    node_size=300,
    node_color="#1f78b4",
    node_shape="o",
    alpha=None,
    cmap=None,
    vmin=None,
    vmax=None,
    ax=None,
    linewidths=None,
    edgecolors=None,
    label=None,
):
    """Draw only the nodes of a graph as a Matplotlib scatter plot.

    :param graph: A rustworkx graph, either a :class:`~rustworkx.PyGraph` or a
        :class:`~rustworkx.PyDiGraph`.
    :param dict pos: A mapping from node index to position; each position
        should be a sequence of length 2.
    :param list node_list: Node indices to draw. Defaults to every node in
        ``graph`` (``graph.node_indices()``).
    :param float|array node_size: Marker size(s), passed to ``scatter`` as
        ``s``. An array must have the same length as ``node_list``.
        Defaults to 300.
    :param node_color: A single color or a sequence of colors with the same
        length as ``node_list``, passed to ``scatter`` as ``c``. Numeric values
        are mapped through ``cmap``/``vmin``/``vmax``. Defaults to
        ``'#1f78b4'``.
    :param str node_shape: Matplotlib marker, e.g. one of ``'so^>v<dph8'``.
        Defaults to ``'o'``.
    :param float|iterable alpha: Node transparency. A scalar is passed to
        ``scatter``; an iterable is folded into the node colors with
        ``apply_alpha`` instead.
    :param cmap: Matplotlib colormap for numeric ``node_color`` values.
    :param float vmin: Lower bound for colormap scaling.
    :param float vmax: Upper bound for colormap scaling.
    :param Axes ax: Axes to draw in; defaults to ``plt.gca()``.
    :param linewidths: Width(s) of the marker borders (``None`` uses the
        Matplotlib default).
    :param edgecolors: Color(s) of the marker borders.
    :param str label: Legend label for the node collection.

    :returns: The node collection returned by ``ax.scatter``, or an empty
        ``PathCollection`` when there are no nodes to draw.
    :rtype: matplotlib.collections.PathCollection

    :raises IndexError: If a node in ``node_list`` has no entry in ``pos``
        (translated from the ``KeyError`` raised by the lookup).
    """
    try:
        import matplotlib as mpl
        import matplotlib.collections  # type: ignore
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError(
            "matplotlib needs to be installed prior to running "
            "rustworkx.visualization.mpl_draw(). You can install "
            "matplotlib with:\n'pip install matplotlib'"
        ) from err

    axes = plt.gca() if ax is None else ax
    nodes = graph.node_indices() if node_list is None else node_list

    if not len(nodes):
        # Nothing to draw: hand back an empty collection instead of scattering.
        return mpl.collections.PathCollection(None)

    try:
        coords = np.asarray([pos[node] for node in nodes])
    except KeyError as missing:
        raise IndexError(f"Node {missing} has no position.") from missing

    colors = node_color
    scatter_alpha = alpha
    if isinstance(alpha, Iterable):
        # Per-node alphas are baked into the colors; scatter gets no alpha.
        colors = apply_alpha(node_color, alpha, nodes, cmap, vmin, vmax)
        scatter_alpha = None

    collection = axes.scatter(
        coords[:, 0],
        coords[:, 1],
        s=node_size,
        c=colors,
        marker=node_shape,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        alpha=scatter_alpha,
        linewidths=linewidths,
        edgecolors=edgecolors,
        label=label,
    )
    # Hide ticks and tick labels on both axes.
    axes.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )

    # zorder 2 keeps the markers above lower-zorder artists such as edges.
    collection.set_zorder(2)
    return collection
|
|
|
|
|
|
def draw_edges(
    graph,
    pos,
    edge_list=None,
    width=1.0,
    edge_color="k",
    style="solid",
    alpha=None,
    arrowstyle=None,
    arrow_size=10,
    edge_cmap=None,
    edge_vmin=None,
    edge_vmax=None,
    ax=None,
    arrows=True,
    label=None,
    node_size=300,
    node_list=None,
    node_shape="o",
    connectionstyle="arc3",
    min_source_margin=0,
    min_target_margin=0,
):
    r"""Draw the edges of the graph.

    This draws only the edges of the graph. Every distinct pair of endpoint
    positions is drawn once, as a :class:`matplotlib.patches.FancyArrowPatch`.

    Parameters
    ----------
    graph: A rustworkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    edge_list : collection of edge tuples (default=graph.edge_list())
        Draw only specified edges. An edge whose endpoint positions are the
        same as those of an earlier edge (e.g. a parallel edge) is drawn only
        once, using the earlier edge.

    width : float or array of floats (default=1.0)
        Line width of edges. An array may have one entry per edge in
        ``edge_list`` or one per distinct drawn edge; other lengths are
        cycled.

    edge_color : color or array of colors (default='k')
        Edge color. Can be a single color or a sequence of colors with the same
        length as edge_list. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the edge_cmap and edge_vmin,edge_vmax
        parameters.

    style : string (default=solid line)
        Edge line style e.g.: '-', '--', '-.', ':'
        or words like 'solid' or 'dashed'.
        (See `matplotlib.patches.FancyArrowPatch`: `linestyle`)

    alpha : float or None (default=None)
        The edge transparency

    edge_cmap : Matplotlib colormap, optional
        Colormap for mapping intensities of edges

    edge_vmin,edge_vmax : floats, optional
        Minimum and maximum for edge colormap scaling

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    arrows : bool, optional (default=True)
        For directed graphs, if True set default to drawing arrowheads.
        Otherwise set default to no arrowheads. Ignored if `arrowstyle` is set.

        Note: Arrows will be the same color as edges.

    arrowstyle : str (default='-\|>' if directed else '-')
        For directed graphs and `arrows==True` defaults to '-\|>',
        otherwise defaults to '-'.

        See `matplotlib.patches.ArrowStyle` for more options.

    arrow_size : int (default=10)
        For directed graphs, choose the size of the arrow head's length and
        width. See `matplotlib.patches.FancyArrowPatch` for attribute
        ``mutation_scale`` for more info.

    node_size : scalar or array (default=300)
        Size of nodes. Though the nodes are not drawn with this function, the
        node size is used in determining edge positioning.

    node_list : list, optional (default=graph.node_indices())
        This provides the node order for the `node_size` array (if it is an
        array).

    node_shape : string (default='o')
        The marker used for nodes, used in determining edge positioning.
        Specification is as a `matplotlib.markers` marker, e.g. one of
        'so^>v<dph8'.

    label : None or string
        Accepted for API compatibility; it is not applied to the drawn
        patches.

    connectionstyle : string (default="arc3")
        Connection style for the edges, e.g. ``"arc3,rad=0.2"`` (see
        `matplotlib.patches.ConnectionStyle`). For styles that take a ``rad``
        argument (``arc3``, ``angle``, ``arc``) and don't already specify one,
        ``rad=0.25`` is used for an edge whose reverse edge is also drawn
        (so the two don't overlap) and ``rad=0.0`` otherwise. An edge whose
        two endpoints share a position (a self-loop) is drawn as a small loop
        above the node instead.

    min_source_margin : int (default=0)
        The minimum margin (gap) at the beginning of the edge at the source.

    min_target_margin : int (default=0)
        The minimum margin (gap) at the end of the edge at the target.

    Returns
    -------
    list of matplotlib.patches.FancyArrowPatch
        `FancyArrowPatch` instances of the directed edges

    Raises
    ------
    ValueError
        If ``connectionstyle`` is not a valid Matplotlib connection style, or
        if ``node_size`` is an array and an edge endpoint is missing from
        ``node_list``.

    Notes
    -----
    For directed graphs, arrows are drawn at the head end. Arrows can be
    turned off with keyword arrows=False or by passing an arrowstyle without
    an arrow on the end.

    Be sure to include `node_size` as a keyword argument; arrows are
    drawn considering the size of nodes.
    """
    try:
        import matplotlib as mpl
        import matplotlib.colors  # type: ignore
        import matplotlib.patches  # type: ignore
        import matplotlib.path  # type: ignore
        import matplotlib.pyplot as plt
    except ImportError as e:
        raise ImportError(
            "matplotlib needs to be installed prior to running "
            "rustworkx.visualization.mpl_draw(). You can install "
            "matplotlib with:\n'pip install matplotlib'"
        ) from e

    if arrowstyle is None:
        if isinstance(graph, rustworkx.PyDiGraph) and arrows:
            arrowstyle = "-|>"
        else:
            arrowstyle = "-"

    if ax is None:
        ax = plt.gca()

    if edge_list is None:
        edge_list = graph.edge_list()

    if len(edge_list) == 0:  # no edges!
        return []

    if node_list is None:
        node_list = list(graph.node_indices())

    # FancyArrowPatch handles color=None different from LineCollection
    if edge_color is None:
        edge_color = "k"

    # Deduplicate edges by their endpoint positions (insertion-ordered),
    # remembering the edge_list index of the first edge for each distinct
    # position pair so per-edge data stays aligned with the drawn edge.
    first_edge_index = {}
    for idx, e in enumerate(edge_list):
        first_edge_index.setdefault((tuple(pos[e[0]]), tuple(pos[e[1]])), idx)
    edge_pos = list(first_edge_index)
    edge_ids = list(first_edge_index.values())
    num_drawn = len(edge_pos)
    num_edges = len(edge_list)

    def _per_edge(values, i):
        """Pick the entry of a per-edge sequence for drawn edge ``i``."""
        if len(values) == num_drawn:
            return values[i]
        if len(values) == num_edges:
            return values[edge_ids[i]]
        return values[i % len(values)]  # cycle through shorter sequences

    # Check if edge_color is an array of floats and map to edge_cmap.
    # This is the only case handled differently from matplotlib
    if (
        np.iterable(edge_color)
        and len(edge_color) in (num_drawn, num_edges)
        and np.all([isinstance(c, Number) for c in edge_color])
    ):
        if edge_cmap is not None:
            assert isinstance(edge_cmap, mpl.colors.Colormap)
        else:
            edge_cmap = plt.get_cmap()
        if edge_vmin is None:
            edge_vmin = min(edge_color)
        if edge_vmax is None:
            edge_vmax = max(edge_color)
        color_normal = mpl.colors.Normalize(vmin=edge_vmin, vmax=edge_vmax)
        edge_color = [edge_cmap(color_normal(e)) for e in edge_color]

    # Map node -> position in node_list once (first occurrence wins, like
    # list.index) instead of scanning node_list for every edge.
    if np.iterable(node_size):
        node_position = {}
        for idx, node in enumerate(node_list):
            node_position.setdefault(node, idx)

    def _node_size_of(node):
        """Return the node_size entry for ``node`` (array node_size only)."""
        if node not in node_position:
            raise ValueError(f"{node} is not in list")
        return node_size[node_position[node]]

    # Note: Waiting for someone to implement arrow to intersection with
    # marker. Meanwhile, this works well for polygons with more than 4
    # sides and circle.

    def to_marker_edge(marker_size, marker):
        if marker in "s^>v<d":  # `large` markers need extra space
            return np.sqrt(2 * marker_size) / 2
        else:
            return np.sqrt(marker_size) / 2

    # Validate the connection style up front so unknown styles fail early.
    mpl.patches.ConnectionStyle(connectionstyle)
    # Only these ConnectionStyle classes accept ``rad``; appending it to
    # others (e.g. ``angle3``, ``bar``) raises, and a ``rad`` the caller set
    # explicitly must not be overridden.
    style_parts = connectionstyle.replace(" ", "").split(",")
    style_keys = {part.split("=")[0] for part in style_parts[1:]}
    add_rad = style_parts[0].lower() in {"arc3", "angle", "arc"} and "rad" not in style_keys

    def _edge_connectionstyle(rad):
        """Connection style string for a non-loop edge with curvature ``rad``."""
        if add_rad:
            return f"{connectionstyle}, rad = {rad}"
        return connectionstyle

    # View extent; also used to scale self-loops. Computed before any patch is
    # added because Axes.add_patch already evaluates the connection style.
    edge_pos_array = np.asarray(tuple(edge_pos))
    minx = np.amin(np.ravel(edge_pos_array[:, :, 0]))
    maxx = np.amax(np.ravel(edge_pos_array[:, :, 0]))
    miny = np.amin(np.ravel(edge_pos_array[:, :, 1]))
    maxy = np.amax(np.ravel(edge_pos_array[:, :, 1]))
    w = maxx - minx
    h = maxy - miny

    # Fallback for self-loop scale, computed once.
    max_nodesize = np.array(node_size).max()

    def _self_loop_connectionstyle(posA, posB, *args, **kwargs):
        """Build a small loop above the node for an edge whose ends coincide."""
        # Self-loops are scaled by view extent, except in cases the extent
        # is 0, e.g. for a single node. In this case, fall back to scaling
        # by the maximum node size
        selfloop_ht = 0.005 * max_nodesize if h == 0 else h
        # this is called with _screen space_ values so convert back
        # to data space
        data_loc = ax.transData.inverted().transform(posA)
        v_shift = 0.1 * selfloop_ht
        h_shift = v_shift * 0.5
        # put the top of the loop first so arrow is not hidden by node
        path = [
            # 1
            data_loc + np.asarray([0, v_shift]),
            # 4 4 4
            data_loc + np.asarray([h_shift, v_shift]),
            data_loc + np.asarray([h_shift, 0]),
            data_loc,
            # 4 4 4
            data_loc + np.asarray([-h_shift, 0]),
            data_loc + np.asarray([-h_shift, v_shift]),
            data_loc + np.asarray([0, v_shift]),
        ]
        return mpl.path.Path(ax.transData.transform(path), [1, 4, 4, 4, 4, 4, 4])

    # Draw arrows with `matplotlib.patches.FancyArrowPatch`
    arrow_collection = []
    mutation_scale = arrow_size  # scale factor of arrow head

    # FancyArrowPatch doesn't handle color strings
    arrow_colors = mpl.colors.to_rgba_array(edge_color, alpha)
    for i, edge in enumerate(edge_pos):
        x1, y1 = edge[0][0], edge[0][1]
        x2, y2 = edge[1][0], edge[1][1]
        if np.iterable(node_size):  # many node sizes
            source, target = edge_list[edge_ids[i]][:2]
            shrink_source = to_marker_edge(_node_size_of(source), node_shape)
            shrink_target = to_marker_edge(_node_size_of(target), node_shape)
        else:
            shrink_source = shrink_target = to_marker_edge(node_size, node_shape)

        if shrink_source < min_source_margin:
            shrink_source = min_source_margin
        if shrink_target < min_target_margin:
            shrink_target = min_target_margin

        arrow_color = _per_edge(arrow_colors, i)
        line_width = _per_edge(width, i) if np.iterable(width) else width

        if edge[0] == edge[1]:
            edge_connectionstyle = _self_loop_connectionstyle
        else:
            # Curve an edge whose reverse is also drawn so the pair is visible.
            rad = 0.25 if tuple(reversed(edge)) in first_edge_index else 0.0
            edge_connectionstyle = _edge_connectionstyle(rad)

        arrow = mpl.patches.FancyArrowPatch(
            (x1, y1),
            (x2, y2),
            arrowstyle=arrowstyle,
            shrinkA=shrink_source,
            shrinkB=shrink_target,
            mutation_scale=mutation_scale,
            color=arrow_color,
            linewidth=line_width,
            connectionstyle=edge_connectionstyle,
            linestyle=style,
            zorder=1,
        )  # arrows go behind nodes

        arrow_collection.append(arrow)
        ax.add_patch(arrow)

    # update view
    padx, pady = 0.05 * w, 0.05 * h
    corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady)
    ax.update_datalim(corners)
    ax.autoscale_view()

    ax.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )

    return arrow_collection
|
|
|
|
|
|
def draw_labels(
    graph,
    pos,
    labels=None,
    font_size=12,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    clip_on=True,
):
    """Write a text label at the position of each labelled node.

    Parameters
    ----------
    graph: A rustworkx graph

    pos : dictionary
        Mapping of node index to a position sequence of length 2.

    labels : dictionary (default={n: n for n in graph.node_indices()})
        Text for each node, keyed by node index. Every key must also be a key
        of ``pos``. Non-string values are converted with ``str``.

    font_size : int (default=12)
        Font size of the labels.

    font_color : string (default='k' black)
        Font color.

    font_weight : string (default='normal')
        Font weight.

    font_family : string (default='sans-serif')
        Font family.

    alpha : float or None (default=None)
        Text transparency.

    bbox : Matplotlib bbox (default is Matplotlib's ax.text default)
        Box properties (shape, color, ...) drawn behind each label.

    horizontalalignment : string (default='center')
        One of {'center', 'right', 'left'}.

    verticalalignment : string (default='center')
        One of {'center', 'top', 'bottom', 'baseline', 'center_baseline'}.

    ax : Matplotlib Axes object, optional
        Axes to draw in; defaults to ``plt.gca()``.

    clip_on : bool (default=True)
        Whether labels are clipped at the axes boundary.

    Returns
    -------
    dict
        The created ``Text`` artists keyed by node.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError(
            "matplotlib needs to be installed prior to running "
            "rustworkx.visualization.mpl_draw(). You can install "
            "matplotlib with:\n'pip install matplotlib'"
        ) from err

    if ax is None:
        ax = plt.gca()

    if labels is None:
        labels = {node: node for node in graph.node_indices()}

    # Shared keyword arguments for every label; Matplotlib has no text
    # collection, so one Text artist is created per node.
    text_style = {
        "size": font_size,
        "color": font_color,
        "family": font_family,
        "weight": font_weight,
        "alpha": alpha,
        "horizontalalignment": horizontalalignment,
        "verticalalignment": verticalalignment,
        "transform": ax.transData,
        "bbox": bbox,
        "clip_on": clip_on,
    }

    text_items = {}
    for node, raw_label in labels.items():
        x, y = pos[node]
        # Stringify so "1" and 1 produce the same label.
        text = raw_label if isinstance(raw_label, str) else str(raw_label)
        text_items[node] = ax.text(x, y, text, **text_style)

    # Hide ticks and tick labels on both axes.
    ax.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )

    return text_items
|
|
|
|
|
|
def draw_edge_labels(
    graph,
    pos,
    edge_labels=None,
    label_pos=0.5,
    font_size=10,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    rotate=True,
    clip_on=True,
):
    """Draw edge labels.

    Parameters
    ----------
    graph : A rustworkx graph
        Only consulted when ``edge_labels`` is ``None``; each edge returned
        by ``graph.weighted_edge_list()`` is then labeled with its weight.

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    edge_labels : dictionary or None (default=None)
        Edge labels in a dictionary of labels keyed by edge two-tuple.
        Only labels for the keys in the dictionary are drawn. When ``None``,
        every edge of ``graph`` is labeled with its weight. Non-string
        labels are converted with ``str()``.

    label_pos : float (default=0.5)
        Position of edge label along edge (0=head, 0.5=center, 1=tail)

    font_size : int (default=10)
        Font size for text labels

    font_color : string (default='k' black)
        Font color string

    font_weight : string (default='normal')
        Font weight

    font_family : string (default='sans-serif')
        Font family

    alpha : float or None (default=None)
        The text transparency

    bbox : Matplotlib bbox, optional
        Specify text box properties (e.g. shape, color etc.) for edge labels.
        Default is {boxstyle='round', ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0)}.

    horizontalalignment : string (default='center')
        Horizontal alignment {'center', 'right', 'left'}

    verticalalignment : string (default='center')
        Vertical alignment {'center', 'top', 'bottom', 'baseline',
        'center_baseline'}

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes. Defaults to the
        current axes (``plt.gca()``).

    rotate : bool (default=True)
        Rotate edge labels to lie parallel to edges

    clip_on : bool (default=True)
        Turn on clipping of edge labels at axis boundaries

    Returns
    -------
    dict
        `dict` of the matplotlib ``Text`` objects created, keyed by edge
        two-tuple ``(n1, n2)``.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as e:
        raise ImportError(
            "matplotlib needs to be installed prior to running "
            "rustworkx.visualization.mpl_draw(). You can install "
            "matplotlib with:\n'pip install matplotlib'"
        ) from e

    if ax is None:
        ax = plt.gca()
    if edge_labels is None:
        labels = {(u, v): d for u, v, d in graph.weighted_edge_list()}
    else:
        labels = edge_labels

    # Default box: white fill with a white border. It is the same for every
    # label, so resolve it once instead of inside the loop.
    if bbox is None:
        bbox = dict(boxstyle="round", ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0))

    text_items = {}
    for (n1, n2), label in labels.items():
        (x1, y1) = pos[n1]
        (x2, y2) = pos[n2]
        # Linear interpolation: label_pos=0 puts the label on n2, 1 on n1.
        (x, y) = (
            x1 * label_pos + x2 * (1.0 - label_pos),
            y1 * label_pos + y2 * (1.0 - label_pos),
        )

        # Both directions of this edge are labeled (reciprocal pair): shift
        # the two labels apart vertically so they do not overlap. The shift
        # is proportional to the edge's vertical extent, so horizontal pairs
        # (and self-loops, whose key matches itself) are not moved.
        if (n2, n1) in labels:
            dy = np.abs(y2 - y1)
            if n2 > n1:
                y -= 0.25 * dy
            else:
                y += 0.25 * dy

        if rotate:
            # Edge direction in degrees, data coordinates.
            angle = np.degrees(np.arctan2(y2 - y1, x2 - x1))
            # Make the label orientation "right-side-up".
            if angle > 90:
                angle -= 180
            if angle < -90:
                angle += 180
            # Transform the data-coordinate angle to a screen-coordinate angle
            # so the text stays parallel to the edge under non-square aspect.
            xy = np.array((x, y))
            trans_angle = ax.transData.transform_angles(np.array((angle,)), xy.reshape((1, 2)))[0]
        else:
            trans_angle = 0.0

        if not isinstance(label, str):
            label = str(label)  # this makes "1" and 1 labeled the same

        t = ax.text(
            x,
            y,
            label,
            size=font_size,
            color=font_color,
            family=font_family,
            weight=font_weight,
            alpha=alpha,
            horizontalalignment=horizontalalignment,
            verticalalignment=verticalalignment,
            rotation=trans_angle,
            transform=ax.transData,
            bbox=bbox,
            zorder=1,
            clip_on=clip_on,
        )
        text_items[(n1, n2)] = t

    # Hide ticks and tick labels on both axes.
    ax.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )

    return text_items
|
|
|
|
|
|
def apply_alpha(colors, alpha, elem_list, cmap=None, vmin=None, vmax=None):
    """Apply an alpha (or list of alphas) to the colors provided.

    Parameters
    ----------

    colors : color string or array of floats (default='r')
        Color of element. Can be a single color format string,
        or a sequence of colors with the same length as node_list.
        If numeric values are specified they will be mapped to
        colors using the cmap and vmin,vmax parameters. See
        matplotlib.scatter for more details.
        Note: any sequence whose length equals ``len(elem_list)`` and whose
        first entry is a number is treated as colormap values, so a single
        RGB(A) tuple is ambiguous when it has that many entries.

    alpha : float or array of floats
        Alpha values for elements. This can be a single alpha value, in
        which case it will be applied to all the elements of color. Otherwise,
        if it is an array, the elements of alpha will be applied to the colors
        in order (cycling through alpha multiple times if necessary).

    elem_list : array of rustworkx objects
        The list of elements which are being colored. These could be nodes,
        edges or labels. Only its length is used.

    cmap : matplotlib colormap
        Color map for use if colors is a list of floats corresponding to points
        on a color mapping.

    vmin, vmax : float
        Minimum and maximum values for normalizing colors if a colormap is used

    Returns
    -------

    rgba_colors : numpy ndarray
        Two-dimensional ``(k, 4)`` array containing RGBA format values.
        When ``alpha`` is a sequence longer than the number of colors (or the
        color array has exactly ``len(elem_list)`` entries), the colors are
        repeated cyclically to ``len(elem_list)`` rows.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    """
    try:
        import matplotlib as mpl
        import matplotlib.colors  # call as mpl.colors
        import matplotlib.cm  # type: ignore
    except ImportError as e:
        raise ImportError(
            "matplotlib needs to be installed prior to running "
            "rustworkx.visualization.mpl_draw(). You can install "
            "matplotlib with:\n'pip install matplotlib'"
        ) from e

    num_elems = len(elem_list)

    # If we have been provided with a list of numbers as long as elem_list,
    # apply the color mapping. The ``num_elems`` guard prevents ``colors[0]``
    # from raising IndexError when both sequences are empty.
    if num_elems and len(colors) == num_elems and isinstance(colors[0], Number):
        mapper = mpl.cm.ScalarMappable(cmap=cmap)
        mapper.set_clim(vmin, vmax)
        rgba_colors = mapper.to_rgba(colors)
    # Otherwise, convert colors to matplotlib's RGB using the colorConverter
    # object. These are converted to numpy ndarrays to be consistent with the
    # to_rgba method of ScalarMappable.
    else:
        try:
            # A single color specification -> one row.
            rgba_colors = np.array([mpl.colors.colorConverter.to_rgba(colors)])
        except ValueError:
            # A sequence of colors -> one row per color. reshape keeps the
            # result 2-D, i.e. (0, 4) for an empty sequence, so the column
            # indexing below works.
            rgba_colors = np.array(
                [mpl.colors.colorConverter.to_rgba(color) for color in colors]
            ).reshape(-1, 4)

    # Set the final column of the rgba_colors to have the relevant alpha values
    try:
        # If alpha is longer than the number of colors, resize to the number of
        # elements. Also, if rgba_colors.size (the number of elements of
        # rgba_colors) is the same as the number of elements, resize the array,
        # to avoid it being interpreted as a colormap by scatter().
        # np.resize repeats whole rows cyclically (each row has 4 entries), so
        # a single color is broadcast to every element while a list of colors
        # keeps its per-element values instead of collapsing to the first one.
        if len(alpha) > len(rgba_colors) or rgba_colors.size == num_elems:
            rgba_colors = np.resize(rgba_colors, (num_elems, 4))
        rgba_colors[:, 3] = list(islice(cycle(alpha), len(rgba_colors)))
    except TypeError:
        # alpha has no len(): a scalar applied to every color.
        rgba_colors[:, -1] = alpha
    return rgba_colors
|