# This file is part of Sympathy for Data.
# Copyright (c) 2021 Combine Control Systems
#
# SYMPATHY FOR DATA COMMERCIAL LICENSE
# You should have received a link to the License with Sympathy for Data.
from sympathy.api import node, exceptions
from sympathy.api.nodeconfig import Ports, Tag, Tags, adjust
from sylib_aml.dataset import DatasetPort, prefix_rename
import sklearn.exceptions
import sklearn.model_selection
import dask.dataframe as dd
from sympathy.api import ParameterView
from sympathy.api import qt2 as qt
from sympathy.utils.pip_util import import_optional
QtWidgets = qt.QtWidgets
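# torch is an optional dependency; import it lazily so that the module can
# be loaded even when torch is not installed.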
def _torch():
return import_optional("torch", group="torch")
class SplitWidget(ParameterView):
"""
Creates a configuration GUI for the split method of the connected ML model
"""
def __init__(self, parameters, output=None, parent=None):
super().__init__(parent=parent)
self._parameters = parameters
gui_list = []
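        # Show the editor matching the connected dataset type: a column
        # chooser for tables, a label expression editor for image datasets.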
if output == "output_col":
col_edit = self._parameters["output_col"].gui()
gui_list.append(col_edit)
if output == "labels":
label_edit = self._parameters["labels"].gui()
gui_list.append(label_edit)
test_edit = self._parameters["test_size"].gui()
gui_list.append(test_edit)
strat_edit = self._parameters["stratify"].gui()
gui_list.append(strat_edit)
layout = QtWidgets.QVBoxLayout()
for gui in gui_list:
layout.addWidget(gui)
self.setLayout(layout)
class SplitDataset(node.Node):
name = "Split Dataset (Experimental)"
nodeid = "com.sympathyfordata.advancedmachinelearning.splitdataset"
author = "Jannes Germishuys"
    description = (
        "Splits the input dataset into a training and a test dataset "
        "lazily. The node only adds an indices attribute to the dataset "
        "structure, indicating which entries belong to the training and "
        "test datasets, without actually splitting the data in two. As a "
        "result, the actual splits are not visible on the node's output "
        "ports.")
icon = "split_image_ds.svg"
tags = Tags(Tag.MachineLearning.Partitioning)
inputs = Ports(
[
DatasetPort("Input dataset", "input_ds"),
]
)
outputs = Ports(
[
DatasetPort("Training dataset", "train_ds"),
DatasetPort("Test dataset", "test_ds"),
]
)
parameters = node.parameters()
    # Output column chooser (used for tabular datasets)
editor = node.editors.combo_editor(edit=False, mode=False)
parameters.set_string(
"output_col",
label="Choose an output column",
description="Label column",
value="Column_name",
editor=editor,
)
parameters.set_string(
"labels",
label="Create image labels from filepaths (img_path).",
value='lambda item: item',
editor=node.editors.code_editor(),
        description=(
            "Expression that creates the label column for image datasets, "
            "or the name of the label column for tabular datasets"),
)
parameters.set_float(
"test_size",
label="Test set proportion",
value=0.2,
description="Test size for train/test split",
)
parameters.set_boolean(
"stratify",
label="Stratify",
value=True,
description="Stratify data using Y as class labels")
def exec_parameter_view(self, node_context):
data = node_context.input["input_ds"]
try:
data.load()
except Exception:
return SplitWidget(node_context.parameters)
ds = data.get_ds()
if ds is None:
return SplitWidget(node_context.parameters)
if ds["dstype"] == "table":
return SplitWidget(node_context.parameters, output="output_col")
else:
return SplitWidget(node_context.parameters, output="labels")
def adjust_parameters(self, node_context):
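        # Fill the output column combo with the input dataset's columns;
        # failures (e.g. no input connected yet) are non-fatal.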
try:
adjust(node_context.parameters["output_col"],
node_context.input[0])
except Exception:
pass
def execute(self, node_context):
from sylib_aml.amlnets import ImgDataSet, TabDataSet, SyConcatDataset
data = node_context.input["input_ds"]
data.load()
ds = data.get_ds()
if ds is None:
raise exceptions.SyDataError("Empty dataset")
if ds["dstype"] == "table":
ds["output_col"] = (
node_context.parameters["output_col"].value
if len(node_context.parameters["output_col"].value) >= 1
else None)
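            # Build one lazy TabDataSet per CSV path and concatenate them
            # into a single dataset.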
reader_config = dict(ds["reader_config"])
            datasets = [
                TabDataSet(
                    path,
                    None,
                    reader_config,
                    ds["column_config"],
                    label_col=ds["output_col"],
                )
                for path in ds["paths"]
            ]
cds = SyConcatDataset(datasets)
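            # Read only the label column eagerly (via dask); its values are
            # needed for stratified splitting.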
prefix = reader_config.pop('prefix', None)
            output_col_vals = dd.read_csv(
                ds["paths"],
                **reader_config,
                include_path_column=False,
                assume_missing=True,
                blocksize=1000e6,
                usecols=[ds["output_col"]],
            ).compute().reset_index(drop=True)
prefix_rename(output_col_vals, prefix)
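            # Split index lists only; the data itself is never materialized.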
            train_idx, test_idx = sklearn.model_selection.train_test_split(
                range(len(cds)),
                test_size=node_context.parameters["test_size"].value,
                random_state=None,
                stratify=(output_col_vals.values
                          if node_context.parameters["stratify"].value
                          else None),
            )
train_ds, test_ds = ds.copy(), ds.copy()
train_ds.update({"indices": train_idx}), test_ds.update(
{"indices": test_idx}
)
elif ds["dstype"] == "image":
ds["labels"] = node_context.parameters["labels"].value
cds = ImgDataSet(
images=ds["paths"], transforms=ds["transforms"],
transforms_values=ds["transforms_values"], labels=ds["labels"]
)
train_size = max(
int((1 - node_context.parameters["test_size"].value) *
len(cds)),
1)
test_size = len(cds) - train_size
tr_ds, te_ds = _torch().utils.data.random_split(
cds, [train_size, test_size])
train_ds, test_ds = ds.copy(), ds.copy()
train_ds.update({"indices": tr_ds.indices}), test_ds.update(
{"indices": te_ds.indices}
)
else:
raise exceptions.SyDataError("This data type is not supported.")
train_out = node_context.output["train_ds"]
test_out = node_context.output["test_ds"]
train_out.set_ds(train_ds)
test_out.set_ds(test_ds)
train_out.save()
test_out.save()