Source code for node_interpolation

# This file is part of Sympathy for Data.
# Copyright (c) 2013, 2017, Combine Control Systems AB
#
# Sympathy for Data is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# Sympathy for Data is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Sympathy for Data.  If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations
import collections
from typing import Callable, Union

import numpy as np
from scipy.interpolate import CubicSpline, make_interp_spline
import dask.array as da
import dask

from sympathy.api import node_helper
from sympathy.api import node as synode
from sympathy.api.nodeconfig import Port, Ports, Tag, Tags, adjust
from sympathy.api.exceptions import SyDataError, SyConfigurationError, sywarn
from sympathy.api import dtypes
from sympathy.api import qt2 as qt_compat
from sympathy.api import masked
from sympathy.api import table, adaf
from sylib.array import monotonic_array, monotonic_increasing_array
QtGui = qt_compat.import_module('QtGui')
QtWidgets = qt_compat.import_module('QtWidgets')
QtCore = qt_compat.import_module('QtCore')


_METHODS = ['zero', 'nearest', 'linear', 'quadratic', 'cubic']


def _method_wrapper(f):
    def inner(self, signal):
        dtype = signal.dtype
        if not len(self._source_basis):
            return np.ma.masked_all(self._target_basis.shape, dtype=dtype)
        if dtype.kind in ('m', 'M'):
            signal = signal.astype(float)
        result = f(self, signal)
        if dtype.kind in ('m', 'M'):
            return result.astype(dtype)
        return result
    return inner


class ProgressCounter:
    def __init__(self, progress, total_signals_count):
        self._progress = progress
        self._total_signals_count = total_signals_count
        self._count = 0

    def send_progress(self):
        self._progress(100 * min(1, self._count / self._total_signals_count))

    def increment(self):
        self._count += 1


class RasterResampler:
    def __init__(self, source_basis, target_basis):
        self._has_warned_about_missing_values = False
        # np.interp can't handle some types. Notably it can't handle datetimes.
        self._time_reference = None
        self._source_basis = self._to_float(source_basis)
        self._target_basis = self._to_float(target_basis)

        # Base cache variables used for all interpolation methods:
        self._i = None
        self._left = None
        self._right = None
        self._mask = None
        # Cache variables used for specific interpolation methods:
        self._nearest_i = None
        self._zero_i = None
        self._higher = False
        self._higher_mask = None

    def _mask_or_none(self, mask):
        if not np.any(mask):
            mask = None
        return mask

    def _to_float(self, basis):
        """Convert datetime bases to float."""
        if basis.dtype.kind == 'M':
            if self._time_reference is None:
                self._time_reference = basis[0]
            timeunit = np.timedelta64(1, 'us')
            return (basis - self._time_reference) / timeunit
        return basis

    def get_method(self, selected_method: str
                   ) -> Callable[[np.ndarray], np.ndarray]:
        """
        Return a function that can be used to interpolate signal to the current
        raster.
        """
        methods = {
            'zero': self.zero,
            'nearest': self.nearest,
            'linear': self.linear,
            'quadratic': self.quadratic,
            'cubic': self.cubic,
        }
        try:
            return methods[selected_method]
        except KeyError:
            raise ValueError(
                f"Unknown interpolation method '{selected_method}'") from None

    def _interp_indices(self):
        """
        Calculate and cache the common interpolation steps for the current
        raster. Each interpolation method will refine and use these results.
        """
        if self._i is not None:
            return

        LEFT, RIGHT = -1, -2
        old_i = np.arange(len(self._source_basis))
        self._i = np.interp(self._target_basis, self._source_basis, old_i,
                            left=LEFT, right=RIGHT)
        self._left = self._i == LEFT
        self._right = self._i == RIGHT
        self._mask = self._mask_or_none(np.logical_or(self._left, self._right))

        # Set extrapolated values to what is needed in methods zero and
        # nearest. Other methods mask these values anyways.
        self._i[self._left] = 0
        self._i[self._right] = -1

        # Make it easier to check if any values are masked.
        self._left = self._mask_or_none(self._left)

    @_method_wrapper
    def zero(self, signal):
        """Nearest preceding samples."""
        if self._zero_i is None:
            self._interp_indices()
            self._zero_i = self._i.astype(int)
        new_signal = signal[self._zero_i]
        return self._mask_array(new_signal, self._left)

    @_method_wrapper
    def nearest(self, signal):
        """Nearest neighbor."""
        if self._nearest_i is None:
            self._interp_indices()
            self._nearest_i = self._i.astype(int)
            if self._mask is not None:
                self._nearest_i[~self._mask] = (
                    self._i[~self._mask] + 0.5).astype(int)
        return signal[self._nearest_i]

    def _extrapolated_mask(self):
        if self._higher is False:
            self._higher_mask = self._mask_or_none(np.logical_or(
                self._target_basis > masked.max(self._source_basis),
                self._target_basis < masked.min(self._source_basis)))
            self._higher = True
        return self._higher_mask

    def _mask_array(self, signal, mask):
        if mask is None:
            return signal
        else:
            return np.ma.MaskedArray(signal, mask=mask)

    def _mask_extrapolated_values(self, signal):
        mask = self._extrapolated_mask()
        return self._mask_array(signal, mask)

    def _combined_mask(self, *masks):
        masks = [mask for mask in masks if mask is not None]
        mask = None
        if masks:
            if len(masks) == 1:
                mask = masks[0]
            else:
                mask = np.logical_and.reduce(masks)
        return mask

    def _mask_when_not_all(self, mask):
        return None if np.all(mask) else mask

    def _separate_signal_mask(self, signal):
        mask = None
        if isinstance(signal, np.ma.MaskedArray):
            mask = self._mask_when_not_all(~np.ma.getmaskarray(signal))
            signal = signal.data
        return signal, mask

    def _last_consecutive_index(self, array):
        return np.concatenate((np.diff(array) != 0, [True]))

    def _finite_index(self, signal):
        return self._mask_when_not_all(np.isfinite(signal))

    def _higher_order(self, signal, spline_factory, min_points):
        """
        Interpolations based on scipy.interpolation spline methods.

        Masks and NaN values are separately interpolated using the same method
        and finally applied to the interpolated signal.

        All non-finite values are treated as NaN.

        Reverts to nearest if the number of usable samples are less than
        `min_points`.
        """
        # Use scipy interpolation function

        def remove_duplicates(basis, signal):
            index = self._last_consecutive_index(basis)
            if not np.all(index):
                basis = basis[index]
                signal = signal[index]
            return basis, signal

        def interpolate_mask(basis, mask):
            spline = spline_factory(basis, ~mask)
            return spline(self._target_basis, extrapolate=False) > 0.0

        def apply_mask(basis, signal, mask):
            if mask is not None:
                basis = basis[mask]
                signal = signal[mask]
            return basis, signal

        self._warn_if_missing_values(signal)
        basis = self._source_basis
        # Interpolation needs to be strictly monotonic basis
        clean_basis, new_signal = remove_duplicates(basis, signal)
        # Signal values needs to be finite.
        finite_mask = self._finite_index(new_signal)
        # Handle mask values separately.
        new_signal, masked_mask = self._separate_signal_mask(new_signal)
        new_basis, new_signal = apply_mask(
            clean_basis, new_signal,
            self._combined_mask(finite_mask, masked_mask))

        if len(clean_basis) < min_points:
            return self.nearest(signal)

        spline = spline_factory(new_basis, new_signal)
        new_signal = spline(self._target_basis, extrapolate=False)
        # Calculate final mask.
        if masked_mask is not None:
            masked_mask = interpolate_mask(clean_basis, masked_mask)
            extrapolated_mask = self._extrapolated_mask()
            if extrapolated_mask is not None:
                masked_mask |= extrapolated_mask
        else:
            masked_mask = self._extrapolated_mask()

        if finite_mask is not None:
            # Restore non-finite values.
            new_signal[interpolate_mask(clean_basis, finite_mask)] = np.nan
        return self._mask_array(new_signal, masked_mask)

    @_method_wrapper
    def linear(self, signal):
        """Piecewise linear interpolation."""
        return self._higher_order(
            signal,
            lambda basis, signal: make_interp_spline(
                basis, signal, k=1, bc_type='not-a-knot'),
            min_points=2)

    @_method_wrapper
    def quadratic(self, signal):
        """Quadratic interpolation."""
        return self._higher_order(
            signal,
            lambda basis, signal: make_interp_spline(
                basis, signal, k=2, bc_type='not-a-knot'),
            min_points=3)

    @_method_wrapper
    def cubic(self, signal):
        """Cubic interpolation."""
        return self._higher_order(signal, CubicSpline, min_points=4)

    def _warn_if_missing_values(self, signal):
        """Warn about using quadratic/cubic interpolation with NaN values."""
        if self._has_warned_about_missing_values:
            return

        if isinstance(signal, np.ma.MaskedArray):
            if np.any(np.ma.getmaskarray(signal) | np.isnan(signal)):
                sywarn("Masked values/NaNs are not supported when "
                       "interpolating with quadratic or cubic method.")
                self._has_warned_about_missing_values = True


class TableResampler:
    def __init__(self, raster_table: table.Table, basis_name: str,
                 target_basis: np.ndarray, progress_counter: ProgressCounter):
        super().__init__()
        self._basis_name = basis_name
        self._table = raster_table
        self._target_basis = target_basis
        self._progress_counter = progress_counter
        self._raster_resampler: RasterResampler
        self._initialized = False

    def _get_signal(self, signal_name: str) -> np.ndarray:
        return self._table.get_column_to_array(signal_name)

    def _get_selected_method_name(
            self, signal: np.ndarray, methods: tuple[str, str, str]
    ) -> str:
        bool_method, int_method, float_method = methods

        kind = signal.dtype.kind
        if kind in ('b', 'S', 'U'):
            selected_method = bool_method
        elif kind in ('i', 'u'):
            selected_method = int_method
        else:
            selected_method = float_method

        if self._table.number_of_rows() == 1 and selected_method not in (
                'zero', 'nearest'):
            # Signals with one sample can only be resampled using 'nearest' or
            # 'zero'.
            selected_method = 'nearest'
        return selected_method

    def _initialize(self):
        if not self._initialized:
            self._raster_resampler = RasterResampler(
                self._get_signal(self._basis_name), self._target_basis)
            del self._target_basis
            self._initialized = True

    def _resample(self, signal_name: str, selected_method: str) -> np.ndarray:
        self._initialize()
        self._progress_counter.send_progress()
        signal = self._get_signal(signal_name)
        interp = self._raster_resampler.get_method(selected_method)
        res = interp(signal)
        self._progress_counter.increment()
        return res

    def resample(self, signal_name: str, methods: tuple[str, str, str]
                 ) -> Union[np.ndarray, da.Array]:
        signal = self._get_signal(signal_name)
        selected_method = self._get_selected_method_name(signal, methods)
        return self._resample(signal_name, selected_method)

    @classmethod
    def from_raster(cls, raster: adaf.RasterN, target_basis: np.ndarray,
                    progress_counter: ProgressCounter
                    ) -> TableResampler:
        """
        Create and return TableResampler from raster.
        """
        basis_name = '__sylib_interpolation_basis__'
        raster_table = raster.to_table(basis_name=basis_name)
        return cls(
            raster_table, basis_name, target_basis, progress_counter)


class LazyTableResampler(TableResampler):

    def _output_dtype(self, signal: np.ndarray, selected_method: str
                      ) -> np.dtype:
        dtype = signal.dtype
        if (selected_method not in ('zero', 'nearest')
                and dtype.kind not in ('m', 'M')):
            dtype = np.dtype(
                complex if np.issubdtype(dtype, complex) else float)
        return dtype

    def resample(self, signal_name: str, methods: tuple[str, str, str]
                 ) -> Union[np.ndarray, da.Array]:
        """
        Return lazily resampled signal as dask array.
        """
        signal = self._get_signal(signal_name)
        selected_method = self._get_selected_method_name(signal, methods)
        delayed = dask.delayed(self._resample)(signal_name, selected_method)
        return da.from_delayed(
            delayed, self._target_basis.shape,
            dtype=self._output_dtype(signal, selected_method))


def _table_resampler_from_raster(
        raster: adaf.RasterN, target_basis: np.ndarray,
        progress_counter: ProgressCounter, lazy: bool) -> TableResampler:
    cls = LazyTableResampler if lazy else TableResampler
    return cls.from_raster(raster, target_basis, progress_counter)


def resample_file_with_spec(parameter_root, spec, in_adaffile, out_adaffile,
                            progress):
    signals_col = parameter_root['signals_colname']._require().selected
    dt_col = parameter_root['dt_colname'].selected
    tb_col = parameter_root['tb_colname'].selected
    raster_param = parameter_root['raster_colname']
    signals = spec.get_column_to_array(signals_col)
    lazy = _output_lazy(out_adaffile)

    if dt_col is not None:
        dts = spec.get_column_to_array(dt_col)
    else:
        dts = np.zeros(signals.shape, dtype=float) * np.nan
    if tb_col is not None:
        to_tbs = spec.get_column_to_array(tb_col)
    else:
        to_tbs = np.zeros(signals.shape, dtype=str)

    if raster_param.selected:
        raster_names = spec._require_column(raster_param, types='U')
    else:
        raster_names = np.zeros(signals.shape, dtype=str)

    dt_to_signals = collections.defaultdict(list)
    tbname_to_signals = collections.defaultdict(list)
    for i, (dt, to_tb, signal, raster_name) in enumerate(
            zip(dts, to_tbs, signals, raster_names)):
        if dt and not np.isnan(dt):
            dt_to_signals[(dt, raster_name)].append(signal)
        elif to_tb:
            tbname_to_signals[(to_tb, raster_name)].append(signal)
        else:
            raise SyDataError("Row {} in specification table specifies "
                              "neither dt nor a target time basis.".format(i))

    new_timebases = []
    raster_names_dt = collections.defaultdict(list)
    total_signals_count = 0
    for (dt, raster_name), dt_signals in dt_to_signals.items():
        # Create output system/raster
        new_timebasis, unit = get_new_timebasis(in_adaffile, dt, dt_signals)
        if raster_name:
            pass
        else:
            raster_name = 'Resampled raster {}'.format(dt)
        new_timebases.append(
            (raster_name, new_timebasis, unit, dt_signals))
        raster_names_dt[raster_name].append(dt)
        total_signals_count += len(dt_signals)
    for raster_name, dts in raster_names_dt.items():
        if len(dts) > 1:
            raise SyConfigurationError('Repeated raster names.')

    for (tbname, raster_name), tb_signals in tbname_to_signals.items():
        raster_dict = get_raster_dict(in_adaffile)
        try:
            old_system_name, old_raster_name = raster_dict[tbname]
        except KeyError:
            sywarn('No such system/raster: {}'.format(tbname))
            continue
        old_tb_column = (
            in_adaffile.sys[old_system_name][old_raster_name].basis_column())

        raster_name = raster_name or old_raster_name

        new_timebasis = old_tb_column.value()
        unit = old_tb_column.attr['unit'] or 's'
        new_timebases.append((raster_name, new_timebasis, unit, tb_signals))
        total_signals_count += len(tb_signals)

    selected_methods = (parameter_root['bool_interp_method'].selected,
                        parameter_root['int_interp_method'].selected,
                        parameter_root['interpolation_method'].selected)

    new_system = out_adaffile.sys.create("Resampled system")
    progress_counter = ProgressCounter(progress, total_signals_count)
    for raster_name, new_timebasis, unit, signals in new_timebases:
        new_raster = new_system.get_or_create(raster_name)
        raster_table = table.Table()

        # Loop over all rasters and resample selected signals in them
        for origin_system_name in in_adaffile.sys.keys():
            origin_system = in_adaffile.sys[origin_system_name]
            for origin_raster_name in origin_system.keys():
                origin_raster = origin_system[origin_raster_name]
                signals_in_this_raster = [
                    signal_name for signal_name in origin_raster.keys()
                    if signal_name in signals]
                if not signals_in_this_raster:
                    continue

                resampler = _table_resampler_from_raster(
                    origin_raster, new_timebasis, progress_counter, lazy)

                # Loop over all signals and resample them
                for signal_name in signals_in_this_raster:
                    signal = origin_raster[signal_name]
                    new_y = resampler.resample(signal_name, selected_methods)
                    attrs = dict(signal.signal().attr.items())
                    raster_table.set_column_from_array(
                        signal.name, new_y, attrs)
        new_raster.from_table(raster_table, use_basis_name=False)
        basis_attr = None
        if unit:
            basis_attr = {'unit': unit}
        new_raster.create_basis(new_timebasis, attributes=basis_attr)


def _warn_if_not_monotonic_increasing(what, signal):
    if not monotonic_increasing_array(signal):
        sywarn(f"{what} should be a monotonic increasing signal.")


def _warn_if_not_monotonic(what, signal):
    if not monotonic_array(signal):
        sywarn(f"{what} should be a monotonic signal.")


def resample_file(parameter_root, in_adaffile, out_adaffile, progress):
    dt = parameter_root['dt'].value
    use_dt = parameter_root['use_dt'].value
    resample_all = parameter_root['resample_all_rasters'].value
    chosen_signals = parameter_root['ts'].value_names
    lazy = _output_lazy(out_adaffile)

    if resample_all:
        signals = in_adaffile.ts.keys()
    else:
        signals = chosen_signals

    # Create output system/raster
    if use_dt:
        system_name = 'Resampled system'
        raster_name = 'Resampled raster'
        new_timebasis, unit = get_new_timebasis(in_adaffile, dt, signals)
    else:
        raster_dict = get_raster_dict(in_adaffile)
        new_tb = parameter_root['new_tb']._require().selected
        system_name, raster_name = raster_dict[new_tb]
        old_tb_column = (in_adaffile.sys[system_name][raster_name]
                         .basis_column())
        new_timebasis = old_tb_column.value()
        _warn_if_not_monotonic('Target timebasis', new_timebasis)
        unit = old_tb_column.attr['unit']
    new_system = out_adaffile.sys.create(system_name)
    new_raster = new_system.create(raster_name)
    raster_table = table.Table()

    attributes = {}
    if unit:
        attributes['unit'] = unit
    if use_dt:
        attributes['time step'] = dt

    if parameter_root['only_timebasis'].value:
        new_raster.create_basis(new_timebasis, attributes=attributes)
        return

    selected_methods = (parameter_root['bool_interp_method'].selected,
                        parameter_root['int_interp_method'].selected,
                        parameter_root['interpolation_method'].selected)

    # Loop over all rasters and resample selected signals in them
    progress_counter = ProgressCounter(progress, len(signals))
    for origin_system_name in in_adaffile.sys.keys():
        origin_system = in_adaffile.sys[origin_system_name]
        for origin_raster_name in origin_system.keys():
            origin_raster = origin_system[origin_raster_name]
            signals_in_this_raster = [
                signal_name for signal_name in origin_raster.keys()
                if signal_name in signals]
            if not signals_in_this_raster:
                continue

            origin_basis = origin_raster.basis_column().value()
            new_basis = new_timebasis

            _warn_if_not_monotonic_increasing(
                f'Origin timebasis {origin_system_name}/{origin_raster_name}',
                origin_basis)

            resampler = _table_resampler_from_raster(
                origin_raster, new_basis, progress_counter, lazy)

            # Loop over all signals and resample them
            for signal_name in signals_in_this_raster:
                signal = origin_raster[signal_name]
                new_y = resampler.resample(signal_name, selected_methods)
                attrs = dict(signal.signal().attr.items())
                raster_table.set_column_from_array(signal_name, new_y, attrs)

    new_raster.from_table(raster_table, use_basis_name=False)
    new_raster.create_basis(new_timebasis, attributes=attributes)


def get_new_timebasis(in_adaffile, dt, signals):
    """
    Get new timebasis covering the same range as all the old timebases using
    step size dt.
    """
    if dt <= 0:
        raise SyConfigurationError('Time step must be positive.')

    t_range = None
    basis_kind = None
    basis_units = set()

    # The range of the new time basis should be the superset of all the
    # times in the resampled rasters.
    # TODO: What about reference times here?
    # TODO: I removed the warning when some signals are missing. Does that
    # matter?
    for system_name in in_adaffile.sys.keys():
        system = in_adaffile.sys[system_name]
        for raster_name in system.keys():
            raster = system[raster_name]
            if not set(raster.keys()) & set(signals):
                # No selected signals in this raster
                continue

            old_basis = raster.basis_column()
            basis = old_basis.value()
            unit = old_basis.attr.get('unit', '')

            if basis_kind:
                if basis_kind != basis.dtype.kind and not (
                        basis_kind in 'fiu' and basis.dtype.kind in 'fiu'):
                    raise SyDataError(
                        "All time bases must be of the same type. "
                        "Found both {} and {}.".format(
                            dtypes.typename_from_kind(basis_kind),
                            dtypes.typename_from_kind(basis.dtype.kind)))
            else:
                basis_kind = basis.dtype.kind

            # 'unknown' was the previous default unit so we can treat that as
            # empty.
            if unit and unit != 'unknown':
                basis_units.add(unit)

            if len(basis):
                if t_range is None:
                    t_range = basis[0], basis[-1]
                else:
                    t_range = (min(t_range[0], basis[0]),
                               max(t_range[1], basis[-1]))

    basis_unit = ''
    if len(basis_units) == 1:
        basis_unit = basis_units.pop()
    elif len(basis_units) > 1:
        sywarn("The time bases have different units: {}".format(
            ", ".join(basis_units)))
    if t_range is None:
        # None of the selected signals were present, return empty target basis
        return np.array([]), ''
    if basis_kind == 'M':
        t_range = t_range[0].astype(float), t_range[1].astype(float)
        dt *= 1000000

    # If number of samples fit perfectly, use linspace to always catch both end
    # points, at the expense of sometimes changing dt very slightly.
    sample_count = np.around((t_range[1] - t_range[0]) / dt, 8)
    if sample_count == int(sample_count):
        timebasis_new = np.linspace(
            t_range[0], t_range[1], int(sample_count + 1))
    else:
        timebasis_new = np.arange(t_range[0], t_range[1], dt)
    if basis_kind == 'M':
        timebasis_new = timebasis_new.astype('M8[us]')

    return timebasis_new, basis_unit


def get_raster_dict(adaffile):
    if adaffile.is_valid():
        return dict(
            [('/'.join([system_name, raster_name]), (system_name, raster_name))
             for system_name, system in adaffile.sys.items()
             for raster_name in system.keys()])
    else:
        return {}


def timebasis_only_parameter(parameters):
    parameters.set_boolean(
        'only_timebasis', label='Export time basis only', value=False,
        description='Choose to only export the time basis')
    return parameters


class SuperNode(synode.Node):
    author = 'Helena Olen & Magnus Sandén'
    icon = 'interpolate.svg'
    tags = Tags(Tag.Analysis.SignalProcessing)

    parameters = synode.parameters()
    parameters.set_boolean(
        'resample_all_rasters', value=True, label="Resample all signals",
        description='Apply resampling to all signals')
    ts_editor = synode.editors.multilist_editor(edit=True)
    ts_editor.set_attribute('filter', True)
    parameters.set_list(
        'ts', label="Choose signals",
        description='Choose signals to interpolate', editor=ts_editor)
    parameters.set_float(
        'dt', label='Time step',
        description=('Time step in new timebasis. If old timebasis is of '
                     'type datetime this is considered to be in seconds.'),
        editor=synode.editors.spinbox_editor(step=1e-3))
    parameters.set_list(
        'new_tb', label='Timebasis to use for interpolation',
        description=('Timebasis to use as new timebasis '
                     'for selected timeseries'),
        editor=synode.editors.combo_editor(filter=True))
    parameters.set_boolean(
        'use_dt', label='Time step approach', value=True,
        description='Choose between a custom time step and using an existing.')
    parameters = timebasis_only_parameter(parameters)
    parameters.set_list(
        'bool_interp_method', plist=['zero', 'nearest'], value=[1],
        description=('Method used to interpolate boolean, text, and '
                     'byte string data'),
        editor=synode.editors.combo_editor())
    parameters.set_list(
        'int_interp_method', plist=_METHODS, value=[_METHODS.index('nearest')],
        description='Method used to interpolate integer data',
        editor=synode.editors.combo_editor())
    parameters.set_list(
        'interpolation_method', plist=_METHODS,
        value=[_METHODS.index('linear')],
        description='Method used to interpolate other data types',
        editor=synode.editors.combo_editor())


def _output_lazy(output_file) -> bool:
    """
    Return True if the file can be written lazily.
    """
    return output_file.can_write()


[docs]class InterpolateADAF(SuperNode): """ The new timebasis can either be an existing timebasis in the ADAF-file or a timebasis with a timestep defined by the user. The timeseries that will be interpolated are selected in a list. The output file will contain a single system and raster with all the chosen timeseries. Signals in the configuration GUI corresponds to timeseries in the ADAF. Interpolation methods are chosen separately for different groups of data types: non-numeric, integer, other numeric. Methods: :zero: Zero-order hold (preceding sample) :nearest: Nearest neighbor interpolation (nearest sample) :linear: Linear interpolation :quadratic: Quadratic spline interpolation :cubic: Cubic spline interpolation Notable behaviors: - Extrapolated values are masked. - Non-finite values are treated as NaN. - Fallback to `nearest` when usable samples too few for the selected method (fallback might change in future versions). """ name = 'Interpolate ADAF' description = 'Interpolate timeseries to a single timebasis.' nodeid = 'org.sysess.sympathy.data.adaf.interpolateadaf' related = ['org.sysess.sympathy.data.adaf.interpolateadafs'] inputs = Ports([Port.ADAF('Input ADAF', name='port1')]) outputs = Ports([Port.ADAF('Interpolated ADAF', name='port1')]) def adjust_parameters(self, node_context): adjust(node_context.parameters['new_tb'], node_context.input['port1'], kind='rasters') adjust(node_context.parameters['ts'], node_context.input['port1'], kind='ts') def exec_parameter_view(self, node_context): parameter_root = node_context.parameters return InterpolationWidget(parameter_root) def execute(self, node_context): in_adaffile = node_context.input['port1'] out_adaffile = node_context.output['port1'] parameter_root = node_context.parameters if in_adaffile.is_empty(): return out_adaffile.meta.from_table(in_adaffile.meta.to_table()) out_adaffile.res.from_table(in_adaffile.res.to_table()) resample_file(parameter_root, in_adaffile, out_adaffile, progress=self.set_progress)
[docs]@node_helper.list_node_decorator(['port1'], ['port1']) class InterpolateADAFs(InterpolateADAF): name = 'Interpolate ADAFs' nodeid = 'org.sysess.sympathy.data.adaf.interpolateadafs'
[docs]class InterpolateADAFsFromTable(synode.Node): """ The specification table should have two to three columns. It must have a column with the names of the signals that should be interpolated. Furthermore it should have either a column with resampling rate for each signal or a column with the names of the signals to whose time basis it should interpolate each signal. It can also have both columns and if both of them have values for the same row it will use the resample rate. """ name = 'Interpolate ADAFs with Table' description = 'Interpolate timeseries using a specification table.' nodeid = 'org.sysess.sympathy.data.adaf.interpolateadafswithtable' author = 'Magnus Sandén' icon = 'interpolate.svg' tags = Tags(Tag.Analysis.SignalProcessing) related = ['org.sysess.sympathy.data.adaf.interpolateadaf'] inputs = Ports([Port.Table('Specification Table', name='spec'), Port.ADAFs('Input ADAFs', name='port1')]) outputs = Ports([Port.ADAFs('Interpolated ADAFs', name='port1')]) parameters = synode.parameters() parameters.set_list( 'signals_colname', label='Column with signal names', description='Resample the timeseries in this column.', editor=synode.editors.combo_editor(edit=True, filter=True)) parameters.set_list( 'dt_colname', label='Column with sample rates', description=('The selected column should contain sample rates to ' 'which the selected signals will be resampled. ' 'At least one of this parameter and the time bases ' 'parameter must be specified.'), editor=synode.editors.combo_editor( include_empty=True, edit=True, filter=True)) parameters.set_list( 'tb_colname', label='Column with time bases', description=('The selected column should contain existsing time bases ' 'to which the selected signals will be resampled. ' 'At least one of this parameter and the time bases ' 'parameter must be specified.'), editor=synode.editors.combo_editor( include_empty=True, edit=True, filter=True)) parameters.set_list( 'raster_colname', label='Column with output raster names', description=( 'The selected column should contain raster names ' 'for the resampled rasters, OPTIONAL.\n' 'When not specified, names are determined from sample rates or ' 'time bases as usual.' ), editor=synode.editors.combo_editor( include_empty=True, edit=True, filter=True)) parameters = timebasis_only_parameter(parameters) parameters.set_list( 'bool_interp_method', plist=['zero', 'nearest'], value=[1], description=('Method used to interpolate boolean, text, and ' 'byte string data'), editor=synode.editors.combo_editor()) parameters.set_list( 'int_interp_method', plist=_METHODS, value=[_METHODS.index('nearest')], description='Method used to interpolate integer data', editor=synode.editors.combo_editor()) parameters.set_list( 'interpolation_method', plist=_METHODS, value=[_METHODS.index('linear')], description='Method used to interpolate other data types', editor=synode.editors.combo_editor()) def update_parameters(self, parameters): # Older versions (before sympathy 1.4.4) specified the empty element # explicitly in adjust_parameters. In 1.4.4 the empty element went # missing. From 1.4.5 we add it with 'include_empty' instead. parameters['dt_colname'].editor['include_empty'] = True parameters['tb_colname'].editor['include_empty'] = True def adjust_parameters(self, node_context): pnames = ['signals_colname', 'dt_colname', 'tb_colname', 'raster_colname'] for pname in pnames: adjust(node_context.parameters[pname], node_context.input['spec']) def exec_parameter_view(self, node_context): parameter_root = node_context.parameters tab_widget = QtWidgets.QTabWidget() methods_tab = MethodsTab(parameter_root) choose_spec_cols_tab = ChooseSpecColsTab(parameter_root) tab_widget.addTab(choose_spec_cols_tab, 'Specification table') tab_widget.addTab(methods_tab, 'Interpolation method') return tab_widget def execute(self, node_context): input_adafs = node_context.input['port1'] spec = node_context.input['spec'] if spec.is_empty(): return output_adafs = node_context.output['port1'] parameter_root = node_context.parameters def interpolate_adafs_w_table(input_adaf, output_adaf, set_progress): output_adaf.meta.from_table(input_adaf.meta.to_table()) output_adaf.res.from_table(input_adaf.res.to_table()) resample_file_with_spec(parameter_root, spec, input_adaf, output_adaf, progress=set_progress) synode.map_list_node(interpolate_adafs_w_table, input_adafs, output_adafs, self.set_progress)
class ChooseSpecColsTab(QtWidgets.QWidget): def __init__(self, parameter_root): super().__init__() layout = QtWidgets.QVBoxLayout() layout.addWidget(parameter_root['signals_colname'].gui()) layout.addWidget(parameter_root['dt_colname'].gui()) layout.addWidget(parameter_root['tb_colname'].gui()) layout.addWidget(parameter_root['raster_colname'].gui()) layout.addStretch(0) self.setLayout(layout) class MethodsTab(QtWidgets.QWidget): def __init__(self, parameter_root): super().__init__() layout = QtWidgets.QVBoxLayout() layout.addWidget(QtWidgets.QLabel( "Choose interpolation methods for different data types.")) self._form_layout = QtWidgets.QFormLayout() categories = ['Boolean, text, and byte string', 'Integer', 'Float, datetime, and timedelta'] parameters = ['bool_interp_method', 'int_interp_method', 'interpolation_method'] for category, parameter in zip(categories, parameters): widget = parameter_root[parameter].gui() self._form_layout.addRow(category, widget) layout.addLayout(self._form_layout) self._only_timebasis = parameter_root['only_timebasis'].gui() layout.addWidget(self._only_timebasis) layout.addStretch(0) self.setLayout(layout) self._only_timebasis.stateChanged.connect( self._only_timebasis_changed) self._only_timebasis_changed() def _only_timebasis_changed(self): disabled = ( self._only_timebasis.editor().checkState() != QtCore.Qt.CheckState.Checked) for i in range(self._form_layout.count()): self._form_layout.itemAt(i).widget().setEnabled(disabled) class InterpolationWidget(QtWidgets.QWidget): """A widget containing a TimeBasisWidget and a ListSelectorWidget.""" def __init__(self, parameter_root, parent=None): super().__init__() self._parameter_root = parameter_root self._init_gui() def _init_gui(self): # Create widgets and add to layout tab_widget = QtWidgets.QTabWidget() tb_tab = self._init_tb_tab() tab_widget.addTab(tb_tab, 'New time basis') methods_tab = MethodsTab(self._parameter_root) tab_widget.addTab(methods_tab, 'Interpolation method') signals_tab = self._init_signals_tab() tab_widget.addTab(signals_tab, 'Signals') layout = QtWidgets.QVBoxLayout() layout.addWidget(tab_widget) self.setLayout(layout) self._init_gui_from_parameters() def _init_tb_tab(self): # Create radio button group self._dt_or_tb = QtWidgets.QButtonGroup() self._dt_or_tb.setExclusive(True) self._custom_dt_button = QtWidgets.QRadioButton('Use custom timestep') self._use_tb_button = QtWidgets.QRadioButton( 'Interpolate using existing timebasis') # Add buttons to group self._dt_or_tb.addButton(self._custom_dt_button) self._dt_or_tb.addButton(self._use_tb_button) self._new_tb = self._parameter_root['new_tb'].gui() self._dt = self._parameter_root['dt'].gui() tb_ts_vlayout = QtWidgets.QVBoxLayout() tb_ts_vlayout.addWidget(self._custom_dt_button) tb_ts_vlayout.addWidget(self._dt) tb_ts_vlayout.addWidget(self._use_tb_button) tb_ts_vlayout.addWidget(self._new_tb) tb_ts_vlayout.addStretch(0) tb_tab = QtWidgets.QWidget() tb_tab.setLayout(tb_ts_vlayout) self._dt_or_tb.buttonClicked.connect(self._button_changed) return tb_tab def _init_signals_tab(self): signals_tab = QtWidgets.QWidget() self._ts_selection = self._parameter_root['ts'].gui() self._resample_all_signals = ( self._parameter_root['resample_all_rasters'].gui()) signals_vlayout = QtWidgets.QVBoxLayout() signals_vlayout.addWidget(self._resample_all_signals) signals_vlayout.addWidget(self._ts_selection) self._resample_all_signals.editor().toggled.connect( self._ts_selection.set_disabled) self._ts_selection.set_disabled( self._parameter_root['resample_all_rasters'].value) signals_tab.setLayout(signals_vlayout) return signals_tab def _init_gui_from_parameters(self): dt_checked = self._parameter_root['use_dt'].value self._custom_dt_button.setChecked(dt_checked) self._use_tb_button.setChecked(not dt_checked) self._use_tb_button.setChecked(not dt_checked) self._dt.setEnabled(dt_checked) self._new_tb.setEnabled(not dt_checked) def _button_changed(self, button): """ Radiobuttton clicked. Enable/disable custom coefficient edits or predefined filter widgets depedning on which button that is pressed. """ if button == self._custom_dt_button: self._dt.setEnabled(True) self._new_tb.setEnabled(False) self._parameter_root['use_dt'].value = True else: self._dt.setEnabled(False) self._new_tb.setEnabled(True) self._parameter_root['use_dt'].value = False