# This code is licensed under the Apache License, Version 2.0. You may # obtain a copy of this license in the LICENSE.txt file in the root directory # of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. # # Any modifications or derivative works of this code must retain this # copyright notice, and modified files need to carry a notice indicating # that they have been altered from the originals. # # NetworkX is distributed with the 3-clause BSD license. # # Copyright (C) 2004-2020, NetworkX Developers # Aric Hagberg # Dan Schult # Pieter Swart # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are # met: # # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # # * Redistributions in binary form must reproduce the above # copyright notice, this list of conditions and the following # disclaimer in the documentation and/or other materials provided # with the distribution. # # * Neither the name of the NetworkX Developers nor the names of its # contributors may be used to endorse or promote products derived # from this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # # This code is forked from networkx's networkx_pylab.py module and adapted to # work with rustworkx instead. The original source can be found at: # # https://github.com/networkx/networkx/blob/80b1afa2ae50314a8312998c214a8c1a356adcf1/networkx/drawing/nx_pylab.py """Draw a rustworkx graph with matplotlib.""" from collections.abc import Iterable from itertools import islice, cycle from numbers import Number import numpy as np import rustworkx __all__ = [ "mpl_draw", ] def mpl_draw(graph, pos=None, ax=None, arrows=True, with_labels=False, **kwds): r"""Draw a graph with Matplotlib. .. note:: Matplotlib is an optional dependency and will not be installed with rustworkx by default. If you intend to use this function make sure that you install matplotlib with either ``pip install matplotlib`` or ``pip install 'rustworkx[mpl]'`` :param graph: A rustworkx graph, either a :class:`~rustworkx.PyGraph` or a :class:`~rustworkx.PyDiGraph`. :param dict pos: An optional dictionary (or a :class:`~rustworkx.Pos2DMapping` object) with nodes as keys and positions as values. If not specified a spring layout positioning will be computed. See `layout_functions` for functions that compute node positions. :param matplotlib.Axes ax: An optional Matplotlib Axes object to draw the graph in. :param bool arrows: For :class:`~rustworkx.PyDiGraph` objects if ``True`` draw arrowheads. (defaults to ``True``) Note, that the Arrows will be the same color as edges. :param str arrowstyle: An optional string for directed graphs to choose the style of the arrowsheads. See :class:`matplotlib.patches.ArrowStyle` for more options. By default the value is set to ``'-\|>'``. :param int arrow_size: For directed graphs, choose the size of the arrow head's length and width. See :class:`matplotlib.patches.FancyArrowPatch` attribute and constructor kwarg ``mutation_scale`` for more info. Defaults to 10. :param bool with_labels: Set to ``True`` to draw labels on the nodes. Edge labels will only be drawn if the ``edge_labels`` parameter is set to a function. Defaults to ``False``. :param list node_list: An optional list of node indices in the graph to draw. If not specified all nodes will be drawn. :param list edge_list: An option list of edges in the graph to draw. If not specified all edges will be drawn :param int|list node_size: Optional size of nodes. If an array is specified it must be the same length as node_list. Defaults to 300 :param node_color: Optional node color. Can be a single color or a sequence of colors with the same length as node_list. Color can be string or rgb (or rgba) tuple of floats from 0-1. If numeric values are specified they will be mapped to colors using the ``cmap`` and ``vmin``,``vmax`` parameters. See :func:`matplotlib.scatter` for more details. Defaults to ``'#1f78b4'``) :param str node_shape: The optional shape node. The specification is the same as the :func:`matplotlib.pyplot.scatter` function's ``marker`` kwarg, valid options are one of ``['s', 'o', '^', '>', 'v', '<', 'd', 'p', 'h', '8']``. Defaults to ``'o'`` :param float alpha: Optional value for node and edge transparency :param matplotlib.colors.Colormap cmap: An optional Matplotlib colormap object for mapping intensities of nodes :param float vmin: Optional minimum value for node colormap scaling :param float vmax: Optional minimum value for node colormap scaling :param float|sequence linewidths: An optional line width for symbol borders. If a sequence is specified it must be the same length as node_list. Defaults to 1.0 :param float|sequence width: An optional width to use for edges. Can either be a float or sequence of floats. If a sequence is specified it must be the same length as node_list. Defaults to 1.0 :param str|sequence edge_color: color or array of colors (default='k') Edge color. Can be a single color or a sequence of colors with the same length as edge_list. Color can be string or rgb (or rgba) tuple of floats from 0-1. If numeric values are specified they will be mapped to colors using the ``edge_cmap`` and ``edge_vmin``, ``edge_vmax`` parameters. :param matplotlib.colors.Colormap edge_cmap: An optional Matplotlib colormap for mapping intensities of edges. :param float edge_vmin: Optional minimum value for edge colormap scaling :param float edge_vmax: Optional maximum value for node colormap scaling :param str style: An optional string to specify the edge line style. For example, ``'-'``, ``'--'``, ``'-.'``, ``':'`` or words like ``'solid'`` or ``'dashed'``. See the :class:`matplotlib.patches.FancyArrowPatch` attribute and kwarg ``linestyle`` for more details. Defaults to ``'solid'``. :param func labels: An optional callback function that will be passed a node payload and return a string label for the node. For example:: labels=str could be used to just return a string cast of the node's data payload. Or something like:: labels=lambda node: node['label'] could be used if the node payloads are dictionaries. :param func edge_labels: An optional callback function that will be passed an edge payload and return a string label for the edge. For example:: edge_labels=str could be used to just return a string cast of the edge's data payload. Or something like:: edge_labels=lambda edge: edge['label'] could be used if the edge payloads are dictionaries. If this is set edge labels will be drawn in the visualization. :param int font_size: An optional fontsize to use for text labels, By default a value of 12 is used for nodes and 10 for edges. :param str font_color: An optional font color for strings. By default ``'k'`` (ie black) is set. :param str font_weight: An optional string used to specify the font weight. By default a value of ``'normal'`` is used. :param str font_family: An optional font family to use for strings. By default ``'sans-serif'`` is used. :param str label: An optional string label to use for the graph legend. :param str connectionstyle: An optional value used to create a curved arc of rounding radius rad. For example, ``connectionstyle='arc3,rad=0.2'``. See :class:`matplotlib.patches.ConnectionStyle` and :class:`matplotlib.patches.FancyArrowPatch` for more info. By default this is set to ``"arc3"``. :returns: A matplotlib figure for the visualization if not running with an interactive backend (like in jupyter) or if ``ax`` is not set. :rtype: matplotlib.figure.Figure For Example: .. jupyter-execute:: import matplotlib.pyplot as plt import rustworkx as rx from rustworkx.visualization import mpl_draw G = rx.generators.directed_path_graph(25) mpl_draw(G) plt.draw() """ try: import matplotlib.pyplot as plt # type: ignore except ImportError as e: raise ImportError( "matplotlib needs to be installed prior to running " "rustworkx.visualization.mpl_draw(). You can install " "matplotlib with:\n'pip install matplotlib'" ) from e if ax is None: cf = plt.gcf() else: cf = ax.get_figure() cf.set_facecolor("w") if ax is None: if cf.axes: ax = cf.gca() else: ax = cf.add_axes((0, 0, 1, 1)) draw_graph(graph, pos=pos, ax=ax, arrows=arrows, with_labels=with_labels, **kwds) ax.set_axis_off() plt.draw_if_interactive() if not plt.isinteractive() or ax is None: return cf def draw_graph(graph, pos=None, arrows=True, with_labels=False, **kwds): r"""Draw the graph using Matplotlib. Draw the graph with Matplotlib with options for node positions, labeling, titles, and many other drawing features. See draw() for simple drawing without labels or axes. Parameters ---------- graph: A rustworkx :class:`~rustworkx.PyDiGraph` or :class:`~rustworkx.PyGraph` pos : dictionary, optional A dictionary with nodes as keys and positions as values. If not specified a spring layout positioning will be computed. See :mod:`rustworkx.drawing.layout` for functions that compute node positions. Notes ----- For directed graphs, arrows are drawn at the head end. Arrows can be turned off with keyword arrows=False. """ try: import matplotlib.pyplot as plt except ImportError as e: raise ImportError( "matplotlib needs to be installed prior to running " "rustworkx.visualization.mpl_draw(). You can install " "matplotlib with:\n'pip install matplotlib'" ) from e valid_node_kwds = { "node_list", "node_size", "node_color", "node_shape", "alpha", "cmap", "vmin", "vmax", "ax", "linewidths", "edgecolors", "label", } valid_edge_kwds = { "edge_list", "width", "edge_color", "style", "alpha", "arrowstyle", "arrow_size", "edge_cmap", "edge_vmin", "edge_vmax", "ax", "label", "node_size", "node_list", "node_shape", "connectionstyle", "min_source_margin", "min_target_margin", } valid_label_kwds = { "labels", "font_size", "font_color", "font_family", "font_weight", "alpha", "bbox", "ax", "horizontalalignment", "verticalalignment", } valid_edge_label_kwds = { "edge_labels", "font_size", "font_color", "font_family", "font_weight", "alpha", "bbox", "ax", "rotate", "horizontalalignment", "verticalalignment", } valid_kwds = valid_node_kwds | valid_edge_kwds | valid_label_kwds | valid_edge_label_kwds if any([k not in valid_kwds for k in kwds]): invalid_args = ", ".join([k for k in kwds if k not in valid_kwds]) raise ValueError(f"Received invalid argument(s): {invalid_args}") label_fn = kwds.pop("labels", None) if label_fn: kwds["labels"] = {x: label_fn(graph[x]) for x in graph.node_indices()} edge_label_fn = kwds.pop("edge_labels", None) if edge_label_fn: kwds["edge_labels"] = { (x[0], x[1]): edge_label_fn(x[2]) for x in graph.weighted_edge_list() } node_kwds = {k: v for k, v in kwds.items() if k in valid_node_kwds} edge_kwds = {k: v for k, v in kwds.items() if k in valid_edge_kwds} if isinstance(edge_kwds.get("alpha"), list): del edge_kwds["alpha"] label_kwds = {k: v for k, v in kwds.items() if k in valid_label_kwds} edge_label_kwds = {k: v for k, v in kwds.items() if k in valid_edge_label_kwds} if pos is None: pos = rustworkx.spring_layout(graph) # default to spring layout draw_nodes(graph, pos, **node_kwds) draw_edges(graph, pos, arrows=arrows, **edge_kwds) if with_labels: draw_labels(graph, pos, **label_kwds) if edge_label_fn: draw_edge_labels(graph, pos, **edge_label_kwds) plt.draw_if_interactive() def draw_nodes( graph, pos, node_list=None, node_size=300, node_color="#1f78b4", node_shape="o", alpha=None, cmap=None, vmin=None, vmax=None, ax=None, linewidths=None, edgecolors=None, label=None, ): """Draw the nodes of the graph. This draws only the nodes of the graph. :param graph: A rustworkx graph, either a :class:`~rustworkx.PyGraph` or a :class:`~rustworkx.PyDiGraph`. :param dict pos: A dictionary with nodes as keys and positions as values. Positions should be sequences of length 2. :param Axes ax: An optional Matplotlib Axes object, if specified it will draw the graph in the specified Matplotlib axes. :param list node_list: If specified only draw the specified node indices. If not specified all nodes in the graph will be drawn. :param float|array node_size: Size of nodes. If an array it must be the same length as node_list. Defaults to 300 node_color : color or array of colors (default='#1f78b4') Node color. Can be a single color or a sequence of colors with the same length as node_list. Color can be string or rgb (or rgba) tuple of floats from 0-1. If numeric values are specified they will be mapped to colors using the cmap and vmin,vmax parameters. See matplotlib.scatter for more details. node_shape : string (default='o') The shape of the node. Specification is as matplotlib.scatter marker, one of 'so^>v', otherwise defaults to '-'. See `matplotlib.patches.ArrowStyle` for more options. arrow_size : int (default=10) For directed graphs, choose the size of the arrow head's length and width. See `matplotlib.patches.FancyArrowPatch` for attribute ``mutation_scale`` for more info. node_size : scalar or array (default=300) Size of nodes. Though the nodes are not drawn with this function, the node size is used in determining edge positioning. node_list : list, optional (default=graph.node_indices()) This provides the node order for the `node_size` array (if it is an array). node_shape : string (default='o') The marker used for nodes, used in determining edge positioning. Specification is as a `matplotlib.markers` marker, e.g. one of 'so^>vv