Coverage for src / mafw / processor_library / sns_plotter.py: 99%
242 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-09 09:08 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-09 09:08 +0000
1# Copyright 2025 European Union
2# Author: Bulgheroni Antonio (antonio.bulgheroni@ec.europa.eu)
3# SPDX-License-Identifier: EUPL-1.2
4"""
5Module implements a Seaborn plotter processor with a mixin structure to generate seaborn plots.
7This module implements the :mod:`.abstract_plotter` functionalities using :link:`seaborn` and :link:`pandas`.
9These two packages are not installed in the default installation of MAFw, unless the user decided to include the
10optional feature `seaborn`.
12Along with the :class:`SNSPlotter`, it includes a set of standard data retrievers specific to pandas data frames.
13"""
15import logging
16import re
17import typing
18import warnings
19from collections.abc import Callable, Iterable, Mapping, MutableMapping, Sequence
20from pathlib import Path
21from typing import Any, TypeAlias
23import peewee
25from mafw.decorators import class_depends_on_optional, processor_depends_on_optional
26from mafw.mafw_errors import MissingOptionalDependency, PlotterMixinNotInitialized
27from mafw.processor_library.abstract_plotter import DataRetriever, FigurePlotter, GenericPlotter
29log = logging.getLogger(__name__)
31try:
32 import matplotlib.colors
33 import matplotlib.pyplot as plt
34 import pandas as pd
35 import seaborn as sns
36 from matplotlib.typing import ColorType
37 from seaborn._core.typing import ColumnName
39 from mafw.tools.pandas_tools import group_and_aggregate_data_frame, slice_data_frame
41 # noinspection PyProtectedMember
42 _Palette: TypeAlias = str | Sequence[ColorType] | Mapping[Any, ColorType]
@class_depends_on_optional('pandas')
class PdDataRetriever(DataRetriever):
    """
    Base data retriever mixin for plotters working on a pandas data frame.

    Every hook defined here is a placeholder: retrieving does nothing, patching is delegated to the next class in
    the MRO and the attribute validation always succeeds.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward all positional and keyword arguments to the next class in the MRO."""
        super().__init__(*args, **kwargs)
        # Declaration only: no value is assigned here.
        self.data_frame: pd.DataFrame

    def get_data_frame(self) -> None:
        """Placeholder implementation doing nothing."""
        return None

    def patch_data_frame(self) -> None:
        """Delegate the data frame patching to the next class in the MRO."""
        super().patch_data_frame()  # type: ignore[safe-super]

    def _attributes_valid(self) -> bool:
        """Always report the attributes as valid, since this base class has none to check."""
        return True
@class_depends_on_optional('pandas')
class FromDatasetDataRetriever(PdDataRetriever):
    """
    Data retriever loading the data frame from one of the standard seaborn datasets.
    """

    def __init__(self, dataset_name: str | None = None, *args: Any, **kwargs: Any) -> None:
        """
        Constructor parameters:

        :param dataset_name: The name of the seaborn dataset. When None, an empty string is stored instead.
        :type dataset_name: str, Optional
        """
        super().__init__(*args, **kwargs)
        if dataset_name is None:
            dataset_name = ''
        self.dataset_name = dataset_name

    def _attributes_valid(self) -> bool:
        """Tell whether the dataset name is set and listed among the seaborn dataset names."""
        # The empty name short-circuits, so seaborn is queried only for an actual name.
        return self.dataset_name != '' and self.dataset_name in sns.get_dataset_names()

    def get_data_frame(self) -> None:
        """
        Load the selected seaborn dataset and store it in the ``data_frame`` attribute.

        :raise PlotterMixinNotInitialized: if the dataset name is empty or not among the seaborn dataset names.
        """
        if self._attributes_valid():
            self.data_frame = sns.load_dataset(self.dataset_name)
            return
        raise PlotterMixinNotInitialized()
@class_depends_on_optional('pandas')
class SQLPdDataRetriever(PdDataRetriever):
    """
    A specialized data retriever to get a data frame from a database table.

    The idea is to implement an interface to the pandas ``read_sql``. The user has to provide the :attr:`table name
    <.table_name>`, the :attr:`list of required columns <.required_columns>` and an optional :attr:`where clause
    <.where_clause>`.

    .. warning::

        Table name, column names and where clause are concatenated verbatim in the SQL statement, because they
        cannot be bound as statement parameters. They must come from trusted code and never from untrusted input.
    """

    database: peewee.Database
    """The database instance. It comes from the main class"""

    def __init__(
        self,
        table_name: str | None = None,
        required_cols: Iterable[str] | str | None = None,
        where_clause: str | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Constructor parameters:

        :param table_name: The name of the table from where to get the data
        :type table_name: str, Optional
        :param required_cols: A list of columns to be selected from the table and transferred as column in the
            dataframe. A single string is interpreted as one column name.
        :type required_cols: Iterable[str] | str | None, Optional
        :param where_clause: The where clause used in the select SQL statement, with or without the leading
            ``WHERE`` keyword. If None is provided, then all rows will be selected.
        :type where_clause: str, Optional
        """
        super().__init__(*args, **kwargs)

        self.table_name: str = '' if table_name is None else table_name
        """The table from where the data should be taken."""

        self.required_columns: Iterable[str]
        """
        The iterable of columns.

        Those are the column names to be selected from the :attr:`~.table_name` and included in the dataframe.
        """
        if required_cols is None:
            self.required_columns = []
        elif isinstance(required_cols, str):
            # a bare string would otherwise be iterated character by character
            self.required_columns = [required_cols]
        else:
            self.required_columns = required_cols

        self.where_clause: str = '1' if where_clause is None else where_clause
        """The where clause of the SQL statement"""

    def get_data_frame(self) -> None:
        """
        Retrieve the dataframe from a database table.

        The SELECT statement is assembled from the required columns, the table name and the where clause, and then
        executed via pandas ``read_sql`` on the connection of :attr:`database`.

        :raise PlotterMixinNotInitialized: If some of the required attributes are missing.
        """
        if not self._attributes_valid():
            raise PlotterMixinNotInitialized

        # the attribute may have been reassigned to a plain string after construction
        if isinstance(self.required_columns, str):
            self.required_columns = [self.required_columns]

        where_clause = (self.where_clause or '').strip()
        # Remove only a leading WHERE keyword: the word boundary protects identifiers that merely start with
        # 'where', and slicing (rather than str.replace) leaves later occurrences in the condition untouched.
        leading_keyword = re.match(r'where\b', where_clause, re.I)
        if leading_keyword:
            where_clause = where_clause[leading_keyword.end() :].strip()
        if not where_clause:
            # nothing left to filter on: select all rows
            where_clause = '1'

        # todo:
        #   sqlite does not allow to have column and table names parametrized.
        #   so we need to concatenate the strings.
        #   see https://www.sqlite.org/cintro.html#binding_parameters_and_reusing_prepared_statements
        #   we should actually do a check for SQL injection for those elements, but I have no idea at the moment how
        #   this could be implemented.
        #
        # The where clause is a SQL expression, not a value: binding it to a '?' placeholder would have it
        # evaluated as a string literal instead of a condition, so it has to be concatenated as well.
        sql = f'SELECT {", ".join(self.required_columns)} FROM {self.table_name} WHERE {where_clause}'

        self.data_frame = pd.read_sql(sql, con=self.database.connection())  # type: ignore[no-untyped-call]

    def _attributes_valid(self) -> bool:
        """Check if all required parameters are provided and valid."""
        # both a table name and at least one column are needed to build the SELECT statement
        return self.table_name != '' and bool(self.required_columns)
@class_depends_on_optional('pandas')
class HDFPdDataRetriever(DataRetriever):
    """
    Data retriever reading the data frame from a HDF file.

    The data frame is identified by the name of the HDF file and by the key of the object stored in it.
    """

    # NOTE(review): the other retrievers in this module derive from PdDataRetriever, this one directly from
    # DataRetriever - confirm whether this is intentional.

    def __init__(
        self, hdf_filename: str | Path | None = None, key: str | None = None, *args: Any, **kwargs: Any
    ) -> None:
        """
        Constructor parameters:

        :param hdf_filename: The filename of the HDF file. When None, an empty path is stored.
        :type hdf_filename: str | Path, Optional
        :param key: The key of the HDF store with the dataframe. When None, an empty string is stored.
        :type key: str, Optional
        """
        super().__init__(*args, **kwargs)
        self.hdf_filename: Path = Path() if hdf_filename is None else Path(hdf_filename)
        self.key: str = '' if key is None else key

    def get_data_frame(self) -> None:
        """
        Read the object stored under :attr:`key` in the HDF file and assign it to the ``data_frame`` attribute.

        :raise PlotterMixinNotInitialized: if some of the required attributes are not initialised or invalid.
        """
        if not self._attributes_valid():
            raise PlotterMixinNotInitialized

        stored_object = pd.read_hdf(self.hdf_filename, self.key)
        # cast is for the type checker only, no runtime conversion takes place
        self.data_frame = typing.cast(pd.DataFrame, stored_object)

    def patch_data_frame(self) -> None:
        """Delegate the data frame patching to the next class in the MRO."""
        super().patch_data_frame()  # type: ignore[safe-super]

    def _attributes_valid(self) -> bool:
        """Tell whether the filename points to an existing file and the key is not empty."""
        if self.hdf_filename == Path():
            # the filename was never set
            return False

        if not self.hdf_filename.is_file():
            # hdf is not a file or it does not exist.
            log.warning('%s is not a valid HDF file' % self.hdf_filename)
            return False

        return self.key != ''
@class_depends_on_optional('seaborn;pandas')
class SNSFigurePlotter(FigurePlotter):
    """
    Common base of the mixin classes generating a seaborn figure-level plot.

    The plot method is a placeholder and the attribute validation always succeeds.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Declare the shared attributes and forward all arguments to the next class in the MRO."""
        # Both attributes are declarations only: no value is assigned here.
        self.data_frame: pd.DataFrame
        """The dataframe instance shared with the main class"""
        self.facet_grid: sns.FacetGrid
        """The facet grid instance shared with the main class"""
        super().__init__(*args, **kwargs)

    def plot(self) -> None:
        """Placeholder implementation doing nothing."""
        return None

    def _attributes_valid(self) -> bool:
        """Always report the attributes as valid."""
        return True
@class_depends_on_optional('seaborn;pandas')
class RelPlot(SNSFigurePlotter):
    """
    Mixin generating a figure-level relational plot.

    According to the value of :attr:`kind`, the figure is either a scatter or a line plot.

    All parameters are handed over to `seaborn.relplot <https://seaborn.pydata.org/generated/seaborn.relplot.html>`_,
    whose documentation describes them in full detail.
    """

    def __init__(
        self,
        x: ColumnName | Iterable[float | complex | int] | None = None,
        y: ColumnName | Iterable[float | complex | int] | None = None,
        hue: ColumnName | Iterable[float | complex | int] | None = None,
        row: ColumnName | Iterable[float | complex | int] | None = None,
        col: ColumnName | Iterable[float | complex | int] | None = None,
        palette: _Palette | matplotlib.colors.Colormap | None = None,
        kind: typing.Literal['scatter', 'line'] = 'scatter',
        legend: typing.Literal['auto', 'brief', 'full'] | bool = 'auto',
        plot_kws: Mapping[str, Any] | None = None,
        facet_kws: dict[str, Any] | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Constructor parameters:

        :param x: Either the name of the x variable or an iterable with the x values.
        :type x: str | Iterable, Optional
        :param y: Either the name of the y variable or an iterable with the y values.
        :type y: str | Iterable, Optional
        :param hue: Either the name of the hue variable or an iterable with the hue values.
        :type hue: str | Iterable, Optional
        :param row: Either the name of the row category or an iterable with the row values.
        :type row: str | Iterable, Optional
        :param col: Either the name of the column category or an iterable with the column values.
        :type col: str | Iterable, Optional
        :param palette: The colour palette to be used.
        :type palette: str | Colormap, Optional
        :param kind: The type of relational plot, either scatter or line. Defaults to scatter.
        :type kind: str, Optional
        :param legend: The legend style. With "brief", numeric hue and size variables are represented by a sample of
            evenly spaced values. With "full", every group gets a legend entry. With "auto", seaborn chooses between
            brief and full depending on the number of levels. With False, no legend is drawn. Defaults to auto.
        :type legend: str | bool, Optional
        :param plot_kws: Additional keyword arguments for the underlying `seaborn.relplot
            <https://seaborn.pydata.org/generated/seaborn.relplot.html#seaborn.relplot>`_. None is replaced by an
            empty dictionary.
        :type plot_kws: dict[str, Any], Optional
        :param facet_kws: Keyword arguments for the underlying
            `seaborn.FacetGrid <https://seaborn.pydata.org/generated/seaborn.FacetGrid.html#seaborn-facetgrid>`_
        :type facet_kws: dict[str, Any], Optional
        """
        super().__init__(*args, **kwargs)
        # variables mapped on the axes, on the colour and on the facets
        self.x, self.y, self.hue = x, y, hue
        self.row, self.col = row, col
        # appearance of the plot
        self.palette = palette
        self.kind = kind
        self.legend = legend
        # pass-through options
        self.plot_kws = {} if plot_kws is None else plot_kws
        self.facet_kws = facet_kws

    def plot(self) -> None:
        """Draw the relational plot with ``seaborn.relplot`` and store the result in :attr:`facet_grid`."""
        named_options: dict[str, Any] = {
            'data': self.data_frame,
            'x': self.x,
            'y': self.y,
            'hue': self.hue,
            'row': self.row,
            'col': self.col,
            'palette': self.palette,
            'kind': self.kind,
            'legend': self.legend,
            'facet_kws': self.facet_kws,
        }
        # a key of plot_kws clashing with a named option raises a TypeError, exactly as with explicit keywords
        self.facet_grid = sns.relplot(**named_options, **self.plot_kws)
@class_depends_on_optional('seaborn;pandas')
class DisPlot(SNSFigurePlotter):
    """
    Mixin generating a figure-level distribution plot.

    It is the MAFw counterpart of the `seaborn displot
    <https://seaborn.pydata.org/generated/seaborn.displot.html#seaborn.displot>`_, producing one of these plots:

        * **histplot**: a plain `histogram
          plot <https://seaborn.pydata.org/generated/seaborn.histplot.html#seaborn.histplot>`_

        * **kdeplot**: a `kernel density <https://seaborn.pydata.org/generated/seaborn.kdeplot.html#seaborn.kdeplot>`_
          estimate plot

        * **ecdfplot**: an `empirical cumulative distribution functions
          <https://seaborn.pydata.org/generated/seaborn.ecdfplot.html#seaborn.ecdfplot>`_ plot

        * **rugplot**: the `marginal distributions
          <https://seaborn.pydata.org/generated/seaborn.rugplot.html#seaborn.rugplot>`_ drawn as ticks.
    """

    def __init__(
        self,
        x: ColumnName | Iterable[float | complex | int] | None = None,
        y: ColumnName | Iterable[float | complex | int] | None = None,
        hue: ColumnName | Iterable[float | complex | int] | None = None,
        row: ColumnName | Iterable[float | complex | int] | None = None,
        col: ColumnName | Iterable[float | complex | int] | None = None,
        palette: _Palette | matplotlib.colors.Colormap | None = None,
        kind: typing.Literal['hist', 'kde', 'ecdf'] = 'hist',
        legend: bool = True,
        rug: bool = False,
        rug_kws: dict[str, Any] | None = None,
        plot_kws: Mapping[str, Any] | None = None,
        facet_kws: dict[str, Any] | None = None,
        *args: Any,
        **kwargs: Any,
    ):
        """
        Constructor parameters:

        :param x: Either the name of the x variable or an iterable with the x values.
        :type x: str | Iterable, Optional
        :param y: Either the name of the y variable or an iterable with the y values.
        :type y: str | Iterable, Optional
        :param hue: Either the name of the hue variable or an iterable with the hue values.
        :type hue: str | Iterable, Optional
        :param row: Either the name of the row category or an iterable with the row values.
        :type row: str | Iterable, Optional
        :param col: Either the name of the column category or an iterable with the column values.
        :type col: str | Iterable, Optional
        :param palette: The colour palette to be used.
        :type palette: str | Colormap, Optional
        :param kind: The type of distribution plot (hist, kde or ecdf). Defaults to hist.
        :type kind: str, Optional
        :param legend: Set it to False to suppress the legend of the semantic variables. Defaults to True.
        :type legend: bool, Optional
        :param rug: Set it to True to show each observation with marginal ticks. Defaults to False.
        :type rug: bool, Optional
        :param rug_kws: Parameters controlling the appearance of the rug plot.
        :type rug_kws: Mapping[str, Any], Optional
        :param plot_kws: Parameters for the underlying plotting object. None is replaced by an empty dictionary.
        :type plot_kws: Mapping[str, Any], Optional
        :param facet_kws: Parameters for the facet grid object.
        :type facet_kws: Mapping[str, Any], Optional
        """
        super().__init__(*args, **kwargs)
        # variables mapped on the axes, on the colour and on the facets
        self.x, self.y, self.hue = x, y, hue
        self.row, self.col = row, col
        # appearance of the plot
        self.palette = palette
        self.kind = kind
        self.legend = legend
        # rug configuration
        self.rug = rug
        self.rug_kws = rug_kws
        # pass-through options
        self.plot_kws = {} if plot_kws is None else plot_kws
        self.facet_kws = facet_kws

    def plot(self) -> None:
        """Draw the distribution plot with ``seaborn.displot`` and store the result in :attr:`facet_grid`."""
        named_options: dict[str, Any] = {
            'data': self.data_frame,
            'x': self.x,
            'y': self.y,
            'hue': self.hue,
            'row': self.row,
            'col': self.col,
            'palette': self.palette,
            'kind': self.kind,
            'legend': self.legend,
            'rug': self.rug,
            'rug_kws': self.rug_kws,
            'facet_kws': self.facet_kws,
        }
        # a key of plot_kws clashing with a named option raises a TypeError, exactly as with explicit keywords
        self.facet_grid = sns.displot(**named_options, **self.plot_kws)
@class_depends_on_optional('seaborn;pandas')
class CatPlot(SNSFigurePlotter):
    """
    Mixin generating a figure-level categorical plot.

    The figure is produced by the seaborn catplot function documented `here
    <https://seaborn.pydata.org/generated/seaborn.catplot.html>`_.

    .. note::

        By default one of the variables (typically x) is treated as categorical, meaning that even if this variable
        is numeric, its value will not be considered. To use the actual value of this categorical variable, set
        native_scale = True.
    """

    def __init__(
        self,
        x: ColumnName | Iterable[float | complex | int] | None = None,
        y: ColumnName | Iterable[float | complex | int] | None = None,
        hue: ColumnName | Iterable[float | complex | int] | None = None,
        row: ColumnName | Iterable[float | complex | int] | None = None,
        col: ColumnName | Iterable[float | complex | int] | None = None,
        palette: _Palette | None = None,
        kind: typing.Literal['strip', 'swarm', 'box', 'violin', 'boxen', 'point', 'bar', 'count'] = 'strip',
        legend: typing.Literal['auto', 'brief', 'full'] | bool = 'auto',
        native_scale: bool = False,
        plot_kws: Mapping[str, Any] | None = None,
        facet_kws: dict[str, Any] | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Constructor parameters:

        :param x: Either the name of the x variable or an iterable with the x values.
        :type x: str | Iterable, Optional
        :param y: Either the name of the y variable or an iterable with the y values.
        :type y: str | Iterable, Optional
        :param hue: Either the name of the hue variable or an iterable with the hue values.
        :type hue: str | Iterable, Optional
        :param row: Either the name of the row category or an iterable with the row values.
        :type row: str | Iterable, Optional
        :param col: Either the name of the column category or an iterable with the column values.
        :type col: str | Iterable, Optional
        :param palette: The colour palette to be used.
        :type palette: str, Optional
        :param kind: The type of categorical plot (strip, swarm, box, violin, boxen, point, bar or count).
            Defaults to strip.
        :type kind: str, Optional
        :param legend: The legend style. With "brief", numeric hue and size variables are represented by a sample of
            evenly spaced values. With "full", every group gets a legend entry. With "auto", seaborn chooses between
            brief and full depending on the number of levels. With False, no legend is drawn. Defaults to auto.
        :type legend: str | bool, Optional
        :param native_scale: When True, numeric or datetime values on the categorical axis keep their original
            scaling instead of being converted to fixed indices. Defaults to False.
        :type native_scale: bool, Optional
        :param plot_kws: Additional keyword arguments for the underlying `seaborn.catplot
            <https://seaborn.pydata.org/generated/seaborn.catplot.html#seaborn.catplot>`_. None is replaced by an
            empty dictionary.
        :type plot_kws: dict[str, Any], Optional
        :param facet_kws: Keyword arguments for the underlying
            `seaborn.FacetGrid <https://seaborn.pydata.org/generated/seaborn.FacetGrid.html#seaborn-facetgrid>`_
        :type facet_kws: dict[str, Any], Optional
        """
        super().__init__(*args, **kwargs)
        # variables mapped on the axes, on the colour and on the facets
        self.x, self.y, self.hue = x, y, hue
        self.row, self.col = row, col
        # appearance of the plot
        self.palette = palette
        self.kind = kind
        self.legend = legend
        self.native_scale = native_scale
        # pass-through options
        self.plot_kws = {} if plot_kws is None else plot_kws
        self.facet_kws = facet_kws

    def plot(self) -> None:
        """Draw the categorical plot with ``seaborn.catplot`` and store the result in :attr:`facet_grid`."""
        named_options: dict[str, Any] = {
            'data': self.data_frame,
            'x': self.x,
            'y': self.y,
            'hue': self.hue,
            'row': self.row,
            'col': self.col,
            'palette': self.palette,
            'kind': self.kind,
            'legend': self.legend,
            'native_scale': self.native_scale,
            'facet_kws': self.facet_kws,
        }
        # a key of plot_kws clashing with a named option raises a TypeError, exactly as with explicit keywords
        self.facet_grid = sns.catplot(**named_options, **self.plot_kws)
@class_depends_on_optional('seaborn;pandas')
class LMPlot(SNSFigurePlotter):
    """
    The linear regression model plot mixin.

    This mixin will produce a figure level regression model as described `here <https://seaborn.pydata.org/generated/seaborn.lmplot.html#seaborn.lmplot>`__
    """

    def __init__(
        self,
        x: str | None = None,
        y: str | None = None,
        hue: str | None = None,
        row: str | None = None,
        col: str | None = None,
        palette: _Palette | None = None,
        legend: bool = True,
        scatter_kws: dict[str, Any] | None = None,
        line_kws: dict[str, Any] | None = None,
        facet_kws: dict[str, Any] | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Constructor parameters:

        :param x: The name of the x variable.
        :type x: str, Optional
        :param y: The name of the y variable.
        :type y: str, Optional
        :param hue: The name of the hue variable.
        :type hue: str, Optional
        :param row: The name of the row category.
        :type row: str, Optional
        :param col: The name of the column category.
        :type col: str, Optional
        :param palette: The colour palette to be used.
        :type palette: str, Optional
        :param legend: If True and there is a hue variable, add a legend.
        :type legend: bool, Optional
        :param scatter_kws: A dictionary like list of keywords passed to the underlying `scatter
            <https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.scatter.html#matplotlib.pyplot.scatter>`_.
        :type scatter_kws: dict[str, Any], Optional
        :param line_kws: A dictionary like list of keywords passed to the underlying `plot
            <https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.plot.html#matplotlib.pyplot.plot>`_.
        :type line_kws: dict[str, Any], Optional
        :param facet_kws: A dictionary like list of keywords passed to the underlying
            `seaborn.FacetGrid <https://seaborn.pydata.org/generated/seaborn.FacetGrid.html#seaborn-facetgrid>`_
        :type facet_kws: dict[str, Any], Optional
        :param kwargs: Any other keyword argument. They are forwarded to the parent constructor and also stored to
            be passed to ``seaborn.lmplot`` when plotting.
        """
        super().__init__(*args, **kwargs)
        self.x = x
        self.y = y
        self.hue = hue
        self.row = row
        self.col = col
        self.palette = palette
        self.legend = legend
        self.scatter_kws = scatter_kws
        self.line_kws = line_kws
        self.facet_kws = facet_kws
        # NOTE(review): the very same kwargs were already handed to the parent constructor above and are passed
        # again to seaborn.lmplot in plot() - confirm that sharing them between the two is intended.
        self.other_kws = kwargs

    def plot(self) -> None:
        """Implements the plot method for a figure-level regression model."""
        self.facet_grid = sns.lmplot(
            data=self.data_frame,
            x=self.x,
            y=self.y,
            hue=self.hue,
            row=self.row,
            col=self.col,
            palette=self.palette,
            legend=self.legend,
            scatter_kws=self.scatter_kws,
            line_kws=self.line_kws,
            facet_kws=self.facet_kws,
            **self.other_kws,
        )
@processor_depends_on_optional(module_name='pandas;seaborn', warn=True, raise_ex=False)
class SNSPlotter(GenericPlotter):
    """
    The Generic Plotter processor.

    This is a subclass of a Processor with advanced functionality to fetch data in the form of a dataframe and to
    produce plots.

    The key difference with respect to a normal processor is its :meth:`.process` method that has been already
    implemented as follows:

    .. literalinclude:: ../../../src/mafw/processor_library/sns_plotter.py
        :pyobject: SNSPlotter.process
        :dedent:

    This actually means that when you are subclassing a SNSPlotter you do not have to implement the process method
    as you would do for a normal Processor, but you will have to implement the following methods:

        * :meth:`~.in_loop_customization`.

            The processor execution workflow (LoopType) can be any of the available, so
            actually the process method might be invoked only once, or multiple times inside a loop structure
            (for or while).
            If the execution is cyclic, then you may want to have the possibility to do some customisation for each
            iteration, for example, changing the plot title, or modifying the data selection, or the filename where the
            plots will be saved.

            You can use this method also in case of a single loop processor, in this case you will not have access to
            the loop parameters.

        * :meth:`~.get_data_frame`.

            This method has the task to get the data to be plotted in the form of a pandas DataFrame. The processor has
            the :attr:`~.data_frame` attribute where the data will be stored to make them accessible from all other
            methods.

        * :meth:`~.GenericPlotter.patch_data_frame`.

            A convenient method to apply data frame manipulation to the data just retrieved.

        * :meth:`~.GenericPlotter.plot`.

            This method is where the actual plotting occurs. Use the :attr:`~.data_frame` to plot the quantities
            you want.

        * :meth:`~.customize_plot`.

            This method can be optionally used to customize the appearance of the facet grid produced by the
            :meth:`~.GenericPlotter.plot` method. It is particularly useful when the user is mixing this class with one of the
            :class:`~.FigurePlotter` mixin, thus not having direct access to the plot method.

        * :meth:`~.save`.

            This method is where the produced plot is saved in a file. Remember to append the output file name to the
            :attr:`list of produced outputs <.output_filename_list>` so that the :meth:`~.GenericPlotter._update_plotter_db` method
            will automatically store this file in the database during the :meth:`~.GenericPlotter.finish` execution.

        * :meth:`~.update_db`.

            If the user wants to update a specific table in the database, they can use this method.

            It is worth reminding that all plotters are saving all generated files in the standard table PlotterOutput.
            This is automatically done by the :meth:`~.GenericPlotter._update_plotter_db` method that is called in the
            :meth:`~.GenericPlotter.finish` method.

    You do not need to overload the :meth:`~.slice_data_frame` nor the :meth:`~.group_and_aggregate_data_frame`
    methods, but you can simply use them by setting the :attr:`~.slicing_dict` and the :attr:`~.grouping_columns`
    and the :attr:`~.aggregation_functions`.

    The processor comes with two processors parameters that can be used by user-defined subclasses:

        1. The output_folder that is the path where the output file will be saved
        2. The force_replot flag to be used whether the user wants the plot to be regenerated even if the output
           file already exists.

    The default value of these parameters can be changed using the :attr:`.Processor.new_defaults` dictionary as
    shown in :ref:`this example <parameter_inheritance>`.
    """

    def __init__(
        self,
        slicing_dict: MutableMapping[str, Any] | None = None,
        grouping_columns: Iterable[str] | None = None,
        aggregation_functions: Iterable[str | Callable[[Any], Any]] | None = None,
        matplotlib_backend: str = 'agg',
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Constructor parameters:

        :param slicing_dict: A dictionary with key, value pairs to slice the input data frame before the plotting
            occurs.
        :type slicing_dict: dict[str, Any], Optional
        :param grouping_columns: A list of columns for the groupby operation on the data frame.
        :type grouping_columns: list[str], Optional
        :param aggregation_functions: A list of functions for the aggregation on the grouped data frame.
        :type aggregation_functions: list[str | Callable[[Any], Any]], Optional
        :param matplotlib_backend: The name of the matplotlib backend to be used. It is stored in lower case.
            Defaults to 'agg'
        :type matplotlib_backend: str, Optional
        :param output_folder: The path where the output file will be saved
        :type output_folder: Path, Optional
        :param force_replot: Whether to force re-plotting even if the output file already exists.
        :type force_replot: bool, Optional
        """
        super().__init__(*args, **kwargs)

        # attributes that can be set in the constructor

        self.slicing_dict: MutableMapping[str, Any] | None = slicing_dict
        """The dictionary for slicing the input data frame"""

        self.grouping_columns: Iterable[str] | None = grouping_columns
        """The list of columns for grouping the data frame"""

        self.aggregation_functions: Iterable[str | Callable[[Any], Any]] | None = aggregation_functions
        """The list of aggregation functions to be applied to the grouped dataframe"""

        self.matplotlib_backend: str = matplotlib_backend.lower()
        """The backend to be used for matplotlib."""

        # internal use attributes.

        self.data_frame: pd.DataFrame = pd.DataFrame()
        """The pandas DataFrame containing the data to be plotted."""

        self.output_filename_list: list[Path] = []
        """The list of produced filenames."""

        self.facet_grid: sns.FacetGrid | None = None
        """The reference to the facet grid."""

        # private attributes

        # be sure that some additional methods if implemented are calling the super.
        # TODO: check if this is really needed
        self._methods_to_be_checked_for_super.extend([('patch_data_frame', SNSPlotter)])  # type: ignore[list-item]

    def start(self) -> None:
        """
        Overload of the start method.

        The :class:`~.SNSPlotter` is overloading the :meth:`~.Processor.start` in order to allow the user to
        change the matplotlib backend.

        The user can select which backend to use either directly in the class constructor or by assigning it to the
        attribute :attr:`~.matplotlib_backend`. The comparison with the backend currently in use is case-insensitive.

        :raise ModuleNotFoundError: re-raised, after logging a critical message, when it is raised while switching
            the backend.
        """
        super().start()
        try:
            # The constructor stores the name in lower case, but the attribute can be reassigned afterwards with
            # any capitalisation: normalise both sides to avoid a needless backend switch.
            if plt.get_backend().lower() != self.matplotlib_backend.lower():
                plt.switch_backend(self.matplotlib_backend)
        except ModuleNotFoundError:
            log.critical('%s is not a valid plt backend', self.matplotlib_backend)
            raise

    def get_data_frame(self) -> None:
        """
        Specific implementation of the get data frame for the Seaborn plotter.

        It must be overloaded.

        The method is **NOT** returning the data_frame, but in your implementation you need to assign the data frame
        to the class :attr:`.data_frame` attribute.
        """
        pass

    def process(self) -> None:
        """
        Specific implementation of the process method for the Seaborn plotter.

        It is almost the same as the one of the :class:`~.GenericPlotter`, with the addition that all open pyplot
        figures are closed after the process is finished, provided that the data frame is not empty.

        This part cannot be moved upward to the :class:`~.GenericPlotter` because the user might have selected
        another plotting library different from :link:`matplotlib`.
        """
        super().process()
        if not self.is_data_frame_empty():
            plt.close('all')

    def group_and_aggregate_data_frame(self) -> None:
        """
        Performs groupby and aggregation of the data frame.

        If the user provided some :attr:`grouping columns <.grouping_columns>` and :attr:`aggregation functions
        <.aggregation_functions>` then the :func:`~.group_and_aggregate_data_frame` is invoked accordingly.

        The user can update the values of those attributes during each cycle iteration within the implementation of
        the :meth:`~.in_loop_customization`.

        .. seealso::
            This method is simply invoking the :func:`~.group_and_aggregate_data_frame` function from the :mod:`~.pandas_tools`.
        """
        if self.grouping_columns and self.aggregation_functions:
            self.data_frame = group_and_aggregate_data_frame(
                self.data_frame, self.grouping_columns, self.aggregation_functions
            )

    def is_data_frame_empty(self) -> bool:
        """
        Check if the data frame has no rows.

        :return: True if the index of the data frame has zero length, False otherwise.
        :rtype: bool
        """
        return len(self.data_frame.index) == 0

    def slice_data_frame(self) -> None:
        """
        Perform data frame slicing

        The user can set some slicing criteria in the :attr:`~.slicing_dict` to select some specific data subset. The
        values of the slicing dict can be changed during each iteration within the implementation of the
        :meth:`~.in_loop_customization`.

        .. seealso::
            This method is simply invoking the :func:`~.slice_data_frame` function from the :mod:`~.pandas_tools`.
        """
        if self.slicing_dict:
            self.data_frame = slice_data_frame(self.data_frame, self.slicing_dict)
832except ImportError:
833 msg = (
834 'Trying to use the seaborn Plotter implementation without having installed the required dependencies.\n'
835 'Consider installing mafw with the optional feature seaborn. For example:\n'
836 '\npip install mafw[seaborn]\n\n'
837 )
838 warnings.warn(MissingOptionalDependency(msg), stacklevel=2)
839 raise