# Source code for node_scatter

# This file is part of Sympathy for Data.
# Copyright (c) 2013, 2017, Combine Control Systems AB
#
# Sympathy for Data is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# Sympathy for Data is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Sympathy for Data.  If not, see <http://www.gnu.org/licenses/>.
import os
import numpy as np
import sys
import warnings

from sympathy.api import node as synode
from sympathy.api.nodeconfig import Port, Ports, Tag, Tags, adjust
from sympathy.api.exceptions import sywarn
from sympathy.api import qt2 as qt_compat

from matplotlib.backends.backend_qt5agg import (
    FigureCanvasQTAgg as FigureCanvas)
from matplotlib.backends.backend_agg import (
    FigureCanvasAgg as FigureCanvasNonInteractive)
from matplotlib.backends.backend_qt5agg import (
    NavigationToolbar2QT as NavigationToolbar)
from matplotlib.backends.backend_qt import NavigationToolbar2
from matplotlib.figure import Figure

# For 3D plot
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D

QtCore = qt_compat.QtCore
QtGui = qt_compat.import_module('QtGui')
QtWidgets = qt_compat.import_module('QtWidgets')
qt_compat.backend.use_matplotlib_qt()


def reset_plot_parameters(parameters, table):
    """Restore the 3D plot parameters to their initial choices.

    Intended e.g. for when the input data source has changed and the stored
    selections may no longer be valid. The X/Y/Z column lists and the file
    extension list are emptied. The line style and plot type lists are reset
    to a single default entry each.

    :param parameters: Parameter mapping whose entries expose a writable
        ``value_names`` attribute.
    :param table: Unused; accepted so callers can pass the input table.
    """
    for key, names in (('x_axis', []),
                       ('y_axis', []),
                       ('z_axis', []),
                       ('line_style', ['o']),
                       ('plot_func', ['scatter']),
                       ('filename_extension', [])):
        parameters[key].value_names = names


def create_filenames_from_parameters(parameters):
    """Return the output file path described by the export parameters.

    The result is ``<directory>/<filename>.<extension>``. An empty (or None)
    directory falls back to the current directory ``'.'``.

    :param parameters: Mapping with ``'directory'`` and ``'filename'``
        entries (read via ``.value``) and a ``'filename_extension'`` entry
        (read via ``.selected``).
    """
    directory = parameters['directory'].value
    if not directory:
        directory = '.'
    stem = parameters['filename'].value
    suffix = parameters['filename_extension'].selected
    return os.path.join(directory, '{}.{}'.format(stem, suffix))


def _create_and_add_axes_3d(fig):
    axes = Axes3D(fig, auto_add_to_figure=False)
    fig.add_axes(axes)
    return axes


class Super3dNode(synode.Node):
    """
    Plot three table columns against each other and export the figure to a
    file.

    In the configuration, table columns are selected for the X, Y and Z axes
    of the plot. It is also possible to change the plot type and, for scatter
    plots, the marker style. The navigation toolbar under the plot lets the
    user zoom and pan the plot window.

    Available plot types:

    - scatter
    - surf
    - wireframe
    - plot
    - contour
    - heatmap (a 2D hexbin plot where the Z column is shown as colour)

    When executed, the figure is saved in the selected output directory using
    the selected filename and file extension. The path of the written file is
    given as the output datasource. If no filename is set, a warning is shown
    and no file is written.
    """

    name = 'Scatter 3D Table'
    description = 'Create a three-dimensional plot'
    nodeid = 'org.sysess.sympathy.visualize.scatter3dnode'
    author = 'Helena Olen'
    version = '1.0'
    icon = 'scatter3d.svg'
    tags = Tags(Tag.Visual.Plot)

    parameters = synode.parameters()
    parameters.set_list(
        'tb_names', label='Time basis',
        description='Combo of all timebasis',
        editor=synode.editors.combo_editor())
    parameters.set_list(
        'x_axis', label='X axis',
        description='X axis selection for plot',
        editor=synode.editors.combo_editor())
    parameters.set_list(
        'y_axis', label='Y axis',
        description='Y axis selection for plot',
        editor=synode.editors.combo_editor())
    parameters.set_list(
        'z_axis', label='Z axis',
        description='Z axis selection for plot',
        editor=synode.editors.combo_editor())
    parameters.set_list(
        'line_style', label='Line style',
        plist=['o', '^', '*'],
        description='Selectable line styles',
        editor=synode.editors.combo_editor())
    parameters.set_list(
        'plot_func', label='Plot type',
        plist=['scatter', 'surf', 'wireframe', 'plot', 'contour',
               'heatmap'],
        description='Selectable plot types',
        editor=synode.editors.combo_editor())
    parameters.set_list(
        'filename_extension', label='File extension',
        description='Filename extension',
        editor=synode.editors.combo_editor())
    parameters.set_string(
        "directory", label="Output directory",
        description="Select the directory where to export the files.",
        editor=synode.editors.directory_editor())
    parameters.set_string(
        "filename", label="Filename",
        description="Filename without extension.")

    inputs = Ports(
        [Port.Table('Input Table', name='port1')])
    outputs = Ports([Port.Datasource('Output file', name='port2')])

    def update_parameters(self, params):
        """Drop parameters that are no longer used by this node."""
        # Remove old parameters.
        for param in ['azim', 'elev']:
            if param in params:
                del params[param]

    def exec_parameter_view(self, node_context):
        """Create the parameter view.

        The axis combos are adjusted to the input table columns and the
        extension combo to the supported export formats. If building the
        widget raises, the plot parameters are reset and the widget is built
        once more.
        """
        table = node_context.input['port1']
        parameters = node_context.parameters
        adjust(parameters['x_axis'], table)
        adjust(parameters['y_axis'], table)
        adjust(parameters['z_axis'], table)
        parameters['filename_extension'].adjust(['svg', 'pdf', 'eps', 'png'])
        try:
            return Scatter3dWidget(parameters, table)
        except Exception:
            # Stored selections may not match the current input; start over.
            reset_plot_parameters(parameters, table)
            return Scatter3dWidget(parameters, table)

    def execute(self, node_context):
        """Render the plot off-screen and save it to the configured file."""
        parameters = node_context.parameters
        if not parameters['filename'].value:
            sywarn('No output filename selected, '
                   'the output file will be empty.')
            return
        fq_filename = create_filenames_from_parameters(parameters)
        fig = Figure()
        # Attach a non-interactive Agg canvas so savefig works without Qt.
        FigureCanvasNonInteractive(fig)
        axes = _create_and_add_axes_3d(fig)
        table = node_context.input['port1']
        plot_widget = Plot3d(table, parameters, fig, axes)
        plot_widget.update_figure()
        fig.savefig(fq_filename)
        node_context.output['port2'].encode_path(fq_filename)
class FigureCanvasCustom(FigureCanvas):
    """Qt matplotlib canvas that emits ``canvasResized`` after each resize."""

    canvasResized = qt_compat.Signal()

    def resizeEvent(self, event):
        FigureCanvas.resizeEvent(self, event)
        self.canvasResized.emit()


class Scatter3dWidget(QtWidgets.QWidget):
    """Configuration widget showing a live preview of the 3D plot.

    Holds combo boxes for the X/Y/Z columns, the marker style and the plot
    type. It also holds the matplotlib canvas with a navigation toolbar and
    the output directory/filename/extension editors. Changing a combo
    redraws the preview through a :class:`Plot3d` helper.
    """

    def __init__(self, parameters, table):
        super().__init__()
        self._table = table
        self._parameters = parameters
        # GUI elements, created in _init_gui().
        self._x_axis_combobox = None
        self._y_axis_combobox = None
        self._z_axis_combobox = None
        self._line_style_combobox = None
        self._plot_combobox = None
        self._file_extension_combo = None
        self._outputs_hlayout = None
        self._background = None
        self._figure = None
        self._axes = None
        self._canvas = None
        self._toolbar = None
        self._plot = None
        # Projection used for the initial subplot.
        self._projection = '3d'
        self._init_gui()

    def _init_gui(self):
        """Build the layout, initialise the plot and connect the signals."""
        # Create plot window.
        self._create_figure_gui()
        vlayout = QtWidgets.QVBoxLayout()
        axes_hlayout = QtWidgets.QHBoxLayout()
        axes_hlayout.setSpacing(20)
        self.setMinimumWidth(640)
        self.setMinimumHeight(480)

        self._x_axis_combobox = self._parameters['x_axis'].gui()
        self._y_axis_combobox = self._parameters['y_axis'].gui()
        self._z_axis_combobox = self._parameters['z_axis'].gui()
        self._line_style_combobox = self._parameters['line_style'].gui()
        self._plot_combobox = self._parameters['plot_func'].gui()

        axes_hlayout.addWidget(self._x_axis_combobox)
        axes_hlayout.addWidget(self._y_axis_combobox)
        axes_hlayout.addWidget(self._z_axis_combobox)
        axes_hlayout.addWidget(self._line_style_combobox)
        axes_hlayout.addWidget(self._plot_combobox)
        axes_hlayout.insertStretch(-1)

        # Create outputlayout
        self._create_output_layout()

        # Nested layouts are added with addLayout (as for the outputs row) so
        # Qt registers them as child layouts.
        vlayout.addLayout(axes_hlayout)
        vlayout.addWidget(self._canvas)
        vlayout.addWidget(self._toolbar)
        vlayout.addLayout(self._outputs_hlayout)
        self.setLayout(vlayout)

        self._init_gui_from_parameters()

        self._x_axis_combobox.editor().currentIndexChanged[int].connect(
            self._x_axis_change)
        self._y_axis_combobox.editor().currentIndexChanged[int].connect(
            self._y_axis_change)
        self._z_axis_combobox.editor().currentIndexChanged[int].connect(
            self._z_axis_change)
        self._line_style_combobox.editor().currentIndexChanged[int].connect(
            self._line_style_changed)
        self._plot_combobox.editor().currentIndexChanged.connect(
            self._plot_func_changed)
        self._figure.canvas.mpl_connect(
            'button_release_event', self._update_view)

    def _init_gui_from_parameters(self):
        """Create the plot helper and draw the initial figure."""
        self._plot = Plot3d(
            self._table, self._parameters, self._figure, self._axes)
        self._plot_func_changed()
        self._update_figure()

    def _create_figure_gui(self):
        """Create the figure, its initial subplot and the canvas/toolbar.

        On macOS a fixed light-grey background is used; elsewhere the
        widget palette's window colour is used.
        """
        if sys.platform == 'darwin':
            backgroundcolor = '#ededed'
        else:
            backgroundcolor = self.palette().color(
                QtGui.QPalette.ColorRole.Window).name()
        self._figure = Figure(facecolor=backgroundcolor)
        self._create_subplot()
        self._create_canvas_tool()

    def _create_subplot(self):
        """Add the initial subplot using ``self._projection``.

        If matplotlib raises ValueError, ``self._axes`` is left unchanged
        (None after __init__).
        """
        try:
            with warnings.catch_warnings():
                warnings.simplefilter('ignore', UserWarning)
                self._axes = self._figure.add_subplot(
                    111, projection=self._projection)
        except ValueError:
            pass

    def _create_canvas_tool(self):
        """Create canvas and navigation toolbar."""
        self._canvas = FigureCanvasCustom(self._figure)
        policy = QtWidgets.QSizePolicy()
        policy.setHorizontalStretch(1)
        policy.setVerticalStretch(1)
        policy.setHorizontalPolicy(QtWidgets.QSizePolicy.Policy.Expanding)
        policy.setVerticalPolicy(QtWidgets.QSizePolicy.Policy.Expanding)
        self._canvas.setSizePolicy(policy)
        self._toolbar = NavigationToolbarCustom(self._canvas, self)

    def _create_output_layout(self):
        """Create output layout with directory edit, file editor and file
        extension combo.
        """
        self._outputs_hlayout = QtWidgets.QHBoxLayout()
        self._outputs_hlayout.addWidget(
            self._parameters['directory'].gui())
        self._outputs_hlayout.addWidget(self._parameters['filename'].gui())
        self._file_extension_combo = (
            self._parameters['filename_extension'].gui())
        self._outputs_hlayout.addWidget(self._file_extension_combo)

    def _x_axis_change(self, index):
        """Update figure with new x_axis value"""
        self._update_figure()

    def _y_axis_change(self, index):
        """Update figure with new y_axis value"""
        self._update_figure()

    def _z_axis_change(self, index):
        """Update figure with new z_axis value"""
        self._update_figure()

    def _line_style_changed(self, index):
        """Update figure with new line_style"""
        self._update_figure()

    def _enable_line(self, state):
        """Enable or disable line style combo"""
        self._line_style_combobox.editor().setEnabled(state)

    def _plot_func_changed(self):
        """Update GUI and figure when plot function changed.

        The marker style combo only applies to scatter plots, so it is
        enabled for 'scatter' and disabled for every other plot type.
        """
        plot_func = self._parameters['plot_func'].selected
        self._enable_line(plot_func == 'scatter')
        self._update_figure()

    def _update_figure(self):
        """Redraw the preview from the current parameters."""
        self._plot.update_figure()
        # Plot3d may replace the axes; re-hook the toolbar to the canvas.
        self._toolbar.connect_canvas()
        try:
            self._canvas.draw()
        except ValueError:
            # Drawing failed; clear the axes and show an empty plot instead.
            self._axes.clear()
            self._canvas.draw()

    def _update_view(self, event):
        """Update view when figure rotated"""
        self._plot.update_view()


class Plot3d:
    """Draw the selected table columns into a matplotlib figure.

    Used both by the configuration preview (Scatter3dWidget) and by
    Super3dNode.execute. It tracks the current axes, whether that axes is the
    2D axes created for the heatmap, and the current colorbar, if any.
    """

    def __init__(self, table, parameters, fig, axes):
        self._fig = fig
        self._axes = axes
        self._parameters = parameters
        self._table = table
        # Colorbar currently shown in the figure, or None.
        self._cb = None
        # Maximum samples per column used by the gridded plot types.
        self._nbr_points = 100
        # True while self._axes is the 2D axes created by _heatmap_plot.
        self._2d_axes = False
        self._rotation = None

    def update_figure(self):
        """Clear the axes and redraw using the selected columns and plot type.

        Nothing is plotted unless all three axis columns are selected and all
        their values are real numbers. Axis labels are set in both cases.
        """
        x_data = []
        y_data = []
        z_data = []
        plot_func = self._parameters['plot_func'].selected
        x_column_name = self._parameters['x_axis'].selected or ''
        y_column_name = self._parameters['y_axis'].selected or ''
        z_column_name = self._parameters['z_axis'].selected or ''
        if x_column_name and y_column_name and z_column_name:
            x_data = self._table[x_column_name]
            y_data = self._table[y_column_name]
            z_data = self._table[z_column_name]

        if self._axes is not None:
            self._axes.clear()
            if not self._2d_axes:
                with warnings.catch_warnings():
                    warnings.simplefilter('ignore', UserWarning)
                    self._axes.mouse_init()

        # Get units for axis
        x_units = self._get_units(x_column_name)
        y_units = self._get_units(y_column_name)
        z_units = self._get_units(z_column_name)

        # Only plot real numbers
        if (np.all(np.isreal(x_data)) and np.all(np.isreal(y_data)) and
                np.all(np.isreal(z_data))):
            if plot_func == 'surf':
                self._axes_2d_to_3d()
                if len(x_data) >= 1 and len(y_data) >= 1 and len(z_data) >= 1:
                    self._surf_plot(x_data, y_data, z_data)
            elif plot_func == 'contour':
                self._axes_2d_to_3d()
                if len(x_data) >= 1 and len(y_data) >= 1 and len(z_data) >= 1:
                    self._contour_plot(x_data, y_data, z_data)
            elif plot_func == 'scatter':
                self._axes_2d_to_3d()
                self._check_cb()
                if self._axes is not None:
                    self._axes.scatter(
                        x_data, y_data, z_data,
                        marker=self._parameters['line_style'].selected)
                self._cb = None
            elif plot_func == 'plot':
                self._axes_2d_to_3d()
                self._check_cb()
                if self._axes is not None:
                    self._axes.plot(x_data, y_data, z_data)
                self._cb = None
            elif plot_func == 'wireframe':
                self._axes_2d_to_3d()
                self._wireframe_plot(x_data, y_data, z_data)
            elif plot_func == 'heatmap':
                self._heatmap_plot(
                    x_data, y_data, z_data, z_column_name, z_units)

        if self._axes is not None:
            self._axes.set_xlabel(self._get_label(x_column_name, x_units))
            self._axes.set_ylabel(self._get_label(y_column_name, y_units))
        if not self._2d_axes and self._axes is not None:
            self._axes.set_zlabel(self._get_label(z_column_name, z_units))

    def update_view(self):
        """Update figure rotation if 3d figure (currently a no-op)."""
        pass

    def _get_label(self, column_name, unit):
        """Return ``'<column> [<unit>]'``, or just the column name when
        *unit* is empty."""
        if unit:
            label = str(column_name) + ' [' + unit + ']'
        else:
            label = str(column_name)
        return label

    def _interpolate_data(self, data):
        """Linearly resample *data* to ``self._nbr_points`` evenly spaced
        samples when it is longer than that; otherwise return it unchanged.
        """
        len_data = len(data)
        if len_data > self._nbr_points:
            x_interp = range(0, len_data)
            new_x_points = np.linspace(0, len_data - 1, self._nbr_points)
            new_data = np.interp(new_x_points, x_interp, data)
            return new_data
        return data

    def _interpolate_all(self, x_data, y_data, z_data):
        """Interpolate x_data, y_data and z_data"""
        x_data_new = self._interpolate_data(x_data)
        y_data_new = self._interpolate_data(y_data)
        z_data_new = self._interpolate_data(z_data)
        return x_data_new, y_data_new, z_data_new

    def _surf_plot(self, x_data, y_data, z_data):
        """3d surface plot, with a colorbar when there is >1 sample."""
        x_data_new, y_data_new, z_data_new = self._interpolate_all(
            x_data, y_data, z_data)
        X, Y, Z = self._get_xyz(
            x_data_new, y_data_new, z_data_new)
        self._check_cb()
        if self._axes is not None:
            surf = self._axes.plot_surface(
                X, Y, Z, cmap=cm.coolwarm, rstride=1, cstride=1,
                linewidth=0, antialiased=False)
            if len(x_data) > 1 and len(y_data) > 1 and len(z_data) > 1:
                # NOTE(review): '%d' truncates fractional tick values, so
                # data with a small Z range may get repeated labels.
                self._cb = self._fig.colorbar(surf, format='%d')
            else:
                self._cb = None

    def _contour_plot(self, x_data, y_data, z_data):
        """3d contour plot"""
        x_data_new, y_data_new, z_data_new = self._interpolate_all(
            x_data, y_data, z_data)
        X, Y, Z = self._get_xyz(
            x_data_new, y_data_new, z_data_new)
        if self._axes is not None:
            cset = self._axes.contour(X, Y, Z, cmap=cm.coolwarm)
            self._axes.clabel(cset, fontsize=9, inline=1)
        self._check_cb()
        self._cb = None

    def _wireframe_plot(self, x_data, y_data, z_data):
        """3d wireframe plot"""
        x_data_new, y_data_new, z_data_new = self._interpolate_all(
            x_data, y_data, z_data)
        X, Y, Z = self._get_xyz(
            x_data_new, y_data_new, z_data_new)
        if self._axes is not None:
            self._axes.plot_wireframe(
                X, Y, Z, rstride=1, cstride=1, alpha=0.4)
        self._check_cb()
        self._cb = None

    def _heatmap_plot(self, x_data, y_data, z_data, z_column_name, z_units):
        """2d heatmap (hexbin) plot with z_axis as colour.

        Replaces the current axes with a plain 2D subplot and sets
        ``self._2d_axes``.
        """
        x_data_new, y_data_new, z_data_new = self._interpolate_all(
            x_data, y_data, z_data)
        X, Y, Z = self._get_xyz(
            x_data_new, y_data_new, z_data_new)
        x = X.ravel()
        y = Y.ravel()
        z = Z.ravel()
        gridsize = 30
        self._check_cb()
        # Guard like _axes_2d_to_3d: delaxes(None) would raise.
        if self._axes is not None:
            self._fig.delaxes(self._axes)
        self._axes = self._fig.add_subplot(111)
        self._2d_axes = True
        if (len(x_data) > 1 and len(y_data) > 1 and len(z_data) > 1 and
                self._axes is not None):
            heat = self._axes.hexbin(
                x, y, C=z, gridsize=gridsize, cmap=cm.jet, bins=None)
            self._axes.axis([x.min(), x.max(), y.min(), y.max()])
            self._cb = self._fig.colorbar(heat)
            z_label = self._get_label(z_column_name, z_units)
            self._cb.set_label(z_label)
        else:
            self._cb = None

    def _axes_2d_to_3d(self):
        """Remove any colorbar and the current axes, then add a fresh 3D
        axes."""
        self._check_cb()
        if self._axes is not None:
            self._fig.delaxes(self._axes)
        self._axes = _create_and_add_axes_3d(self._fig)
        self._2d_axes = False

    def _check_cb(self):
        """Check if colorbar exists, and then delete it.

        NOTE(review): assumes the colorbar axes is ``figure.axes[1]``; any
        failure (e.g. no second axes) is deliberately ignored. ``self._cb``
        is not reset here; callers do that.
        """
        if self._cb:
            try:
                self._fig.delaxes(self._fig.axes[1])
            except Exception:
                pass

    def _get_units(self, column_name):
        """Return the column's ``'unit'`` attribute, or '' when it is
        missing or cannot be read."""
        try:
            unit = self._table.col(column_name).attr('unit') or ''
        except Exception:
            unit = ''
        return unit

    def _get_xyz(self, x_data, y_data, z_data):
        """Get matrices X, Y, Z from 1D arrays.

        X and Y come from ``np.meshgrid``. Z repeats *z_data* as every row,
        so Z only varies along the X direction.
        """
        x, y = np.meshgrid(x_data, y_data)
        z = np.tile(z_data, (len(z_data), 1))
        return x, y, z


class NavigationToolbarCustom(NavigationToolbar):
    """Navigation toolbar tracking home/forward/back state for ``draw``."""

    _home = True

    def __init__(self, canvas, parent):
        super().__init__(canvas, parent)
        self._forward = False
        self._back = False
        self._parent = parent

    def draw(self):
        """Redraw the canvas and report the current x/y view limits.

        The limits are emitted through a ``zoomChanged`` signal only if one
        is present; this class does not define it.
        """
        # Newer matplotlib NavigationToolbar2 has no draw(); fall back to an
        # idle redraw of the canvas instead of raising AttributeError.
        base_draw = getattr(super(), 'draw', None)
        if base_draw is not None:
            base_draw()
        else:
            self.canvas.draw_idle()
        limits = self._current_view_limits()
        zoom_changed = getattr(self, 'zoomChanged', None)
        if limits is not None and zoom_changed is not None:
            zoom_changed.emit(*limits)

    def _current_view_limits(self):
        """Return ``(x_min, x_max, y_min, y_max)`` of the shown view, or
        None when the figure has no axes."""
        # NOTE(review): ``_views`` is the legacy matplotlib navigation stack;
        # when it is absent, fall back to reading the axes limits.
        views = getattr(self, '_views', None)
        if views is not None and self._forward:
            views.back()
            forward_view = views.forward()
            x_min, x_max = forward_view[0][0:2]
            y_min, y_max = forward_view[0][2:4]
            return x_min, x_max, y_min, y_max
        if views is not None and self._back:
            views.forward()
            back_view = views.back()
            x_min, x_max = back_view[0][0:2]
            y_min, y_max = back_view[0][2:4]
            return x_min, x_max, y_min, y_max
        # get_axes() returns a list; read the limits from the first axes.
        all_axes = self.canvas.figure.get_axes()
        if not all_axes:
            return None
        x_min, x_max = all_axes[0].get_xlim()
        y_min, y_max = all_axes[0].get_ylim()
        return x_min, x_max, y_min, y_max

    def home(self, *args):
        self._home = True
        super().home()
        self._home = False

    def release_zoom(self, event):
        super().release_zoom(event)
        self._home = False

    def forward(self, *args):
        self._forward = True
        super().forward(*args)
        self._forward = False

    def back(self, *args):
        self._back = True
        super().back(*args)
        self._back = False

    def release_pan(self, event):
        super().release_pan(event)
        self._home = False

    def drag_pan(self, event):
        super().drag_pan(event)

    def connect_canvas(self):
        """ Workaround to restore lost canvas interactivity. """
        # Re-runs the base toolbar initialisation against the current canvas.
        # NOTE(review): depending on the matplotlib version this may also
        # reset the toolbar's navigation history.
        NavigationToolbar2.__init__(self, self.canvas)