Coverage for src / mafw / processor_library / sns_plotter.py: 99%

242 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-09 09:08 +0000

1# Copyright 2025 European Union 

2# Author: Bulgheroni Antonio (antonio.bulgheroni@ec.europa.eu) 

3# SPDX-License-Identifier: EUPL-1.2 

4""" 

5Module implements a Seaborn plotter processor with a mixin structure to generate seaborn plots. 

6 

7This module implements the :mod:`.abstract_plotter` functionalities using :link:`seaborn` and :link:`pandas`. 

8 

9These two packages are not installed in the default installation of MAFw, unless the user decided to include the 

10optional feature `seaborn`. 

11 

12Along with the :class:`SNSPlotter`, it includes a set of standard data retrievers specific to pandas data frames. 

13""" 

14 

15import logging 

16import re 

17import typing 

18import warnings 

19from collections.abc import Callable, Iterable, Mapping, MutableMapping, Sequence 

20from pathlib import Path 

21from typing import Any, TypeAlias 

22 

23import peewee 

24 

25from mafw.decorators import class_depends_on_optional, processor_depends_on_optional 

26from mafw.mafw_errors import MissingOptionalDependency, PlotterMixinNotInitialized 

27from mafw.processor_library.abstract_plotter import DataRetriever, FigurePlotter, GenericPlotter 

28 

29log = logging.getLogger(__name__) 

30 

31try: 

32 import matplotlib.colors 

33 import matplotlib.pyplot as plt 

34 import pandas as pd 

35 import seaborn as sns 

36 from matplotlib.typing import ColorType 

37 from seaborn._core.typing import ColumnName 

38 

39 from mafw.tools.pandas_tools import group_and_aggregate_data_frame, slice_data_frame 

40 

41 # noinspection PyProtectedMember 

42 _Palette: TypeAlias = str | Sequence[ColorType] | Mapping[Any, ColorType] 

43 

@class_depends_on_optional('pandas')
class PdDataRetriever(DataRetriever):
    """
    Base data retriever for pandas data frames.

    It does not fetch anything by itself: the concrete retrievers override :meth:`get_data_frame` to populate
    :attr:`data_frame` and :meth:`_attributes_valid` to validate their own configuration.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Annotation only (no assignment): the attribute is populated by get_data_frame.
        self.data_frame: pd.DataFrame

    def get_data_frame(self) -> None:
        """Placeholder that does nothing; concrete retrievers assign :attr:`data_frame` here."""

    def patch_data_frame(self) -> None:
        """Delegate the data frame patching to the next class in the MRO."""
        super().patch_data_frame()  # type: ignore[safe-super]

    def _attributes_valid(self) -> bool:
        """The base retriever has nothing to validate, hence it is always valid."""
        return True

58 

@class_depends_on_optional('pandas')
class FromDatasetDataRetriever(PdDataRetriever):
    """
    A data retriever to get a dataframe from a seaborn dataset.

    The dataset is identified by :attr:`dataset_name`, which must be one of the names returned by
    :func:`seaborn.get_dataset_names`.
    """

    def __init__(self, dataset_name: str | None = None, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # An empty name means "not configured" and makes the retriever invalid.
        self.dataset_name = '' if dataset_name is None else dataset_name

    def _attributes_valid(self) -> bool:
        """Checks if the attributes of the mixin are all valid"""
        # Short-circuit: the list of available datasets is only requested for a non-empty name.
        return self.dataset_name != '' and self.dataset_name in sns.get_dataset_names()

    def get_data_frame(self) -> None:
        """
        Gets the data frame from the standard seaborn datasets.

        :raise PlotterMixinNotInitialized: if the dataset name is empty or not among the seaborn dataset names.
        """
        if self._attributes_valid():
            self.data_frame = sns.load_dataset(self.dataset_name)
        else:
            raise PlotterMixinNotInitialized()

81 

@class_depends_on_optional('pandas')
class SQLPdDataRetriever(PdDataRetriever):
    """
    A specialized data retriever to get a data frame from a database table.

    The idea is to implement an interface to the pandas ``read_sql``. The user has to provide the :attr:`table name
    <.table_name>`, the :attr:`list of required columns <.required_columns>` and an optional :attr:`where clause
    <.where_clause>`.

    .. warning::

        The table name, the column names and the where clause are inserted verbatim into the ``SELECT`` statement,
        because SQL does not allow identifiers or expressions to be passed as bound parameters. They must therefore
        come from trusted configuration, never from untrusted input.
    """

    database: peewee.Database
    """The database instance. It comes from the main class"""

    def __init__(
        self,
        table_name: str | None = None,
        required_cols: Iterable[str] | str | None = None,
        where_clause: str | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Constructor parameters:

        :param table_name: The name of the table from where to get the data
        :type table_name: str, Optional
        :param required_cols: A list of columns to be selected from the table and transferred as column in the
            dataframe. A single string is interpreted as a single column.
        :type required_cols: Iterable[str] | str | None, Optional
        :param where_clause: The where clause used in the select SQL statement, with or without the leading
            ``WHERE`` keyword. If None is provided, then all rows will be selected.
        :type where_clause: str, Optional
        """
        super().__init__(*args, **kwargs)
        self.table_name: str = '' if table_name is None else table_name
        """The table from where the data should be taken."""

        self.required_columns: Iterable[str]
        """
        The iterable of columns.

        Those are the column names to be selected from the :attr:`~.table_name` and included in the dataframe.
        """
        if required_cols is None:
            self.required_columns = []
        elif isinstance(required_cols, str):
            self.required_columns = [required_cols]
        else:
            # Materialise the iterable: get_data_frame can be invoked at every loop iteration and a one-shot
            # iterator (e.g. a generator) would be exhausted after the first query.
            self.required_columns = list(required_cols)

        self.where_clause: str = '1' if where_clause is None else where_clause
        """The where clause of the SQL statement"""

    def get_data_frame(self) -> None:
        """
        Retrieve the dataframe from a database table.

        The :attr:`.where_clause` is inserted in the statement as an SQL expression; an optional leading ``WHERE``
        keyword (case-insensitive) is removed. An empty or blank where clause selects all rows.

        :raise PlotterMixinNotInitialized: If some of the required attributes are missing.
        """
        if not self._attributes_valid():
            raise PlotterMixinNotInitialized

        if isinstance(self.required_columns, str):
            self.required_columns = [self.required_columns]

        # Remove only a leading WHERE keyword. The word boundary keeps identifiers such as ``whereabouts`` intact,
        # and count=1 avoids touching later occurrences (str.replace would also mangle e.g. ``nowhere``).
        where_clause = re.sub(r'^\s*where\b', '', self.where_clause or '', count=1, flags=re.IGNORECASE).strip()
        if not where_clause:
            # Blank clause: select every row instead of emitting an invalid ``WHERE`` with no condition.
            where_clause = '1'

        # The where clause is an SQL expression, not a value: binding it as a parameter (``WHERE ?``) makes the
        # database evaluate a string literal such as 'value > 5' instead of the condition. It is therefore
        # concatenated like the table and column names (sqlite does not allow parametrised identifiers either, see
        # https://www.sqlite.org/cintro.html#binding_parameters_and_reusing_prepared_statements).
        # TODO: these elements are not validated against SQL injection; they must come from trusted configuration.
        sql = f'SELECT {", ".join(self.required_columns)} FROM {self.table_name} WHERE {where_clause}'

        self.data_frame = pd.read_sql(sql, con=self.database.connection())  # type: ignore[no-untyped-call]

    def _attributes_valid(self) -> bool:
        """Check if all required parameters are provided and valid."""
        if self.table_name == '':
            return False
        if not self.required_columns:
            return False
        return True

183 

@class_depends_on_optional('pandas')
class HDFPdDataRetriever(DataRetriever):
    """
    Retrieve a data frame from a HDF file.

    This data retriever gets a dataframe from a HDF file, given the filename and the object key.
    """

    def __init__(
        self, hdf_filename: str | Path | None = None, key: str | None = None, *args: Any, **kwargs: Any
    ) -> None:
        """
        Constructor parameters:

        :param hdf_filename: The filename of the HDF file
        :type hdf_filename: str | Path, Optional
        :param key: The key of the HDF store with the dataframe
        :type key: str, Optional
        """
        super().__init__(*args, **kwargs)
        # Annotation only: the attribute is populated by get_data_frame.
        self.data_frame: pd.DataFrame
        # An empty Path() / empty key mean "not configured" (see _attributes_valid).
        self.hdf_filename: Path = Path() if hdf_filename is None else Path(hdf_filename)
        self.key: str = '' if key is None else key

    def get_data_frame(self) -> None:
        """
        Retrieve the dataframe from a HDF file

        :raise PlotterMixinNotInitialized: if some of the required attributes are not initialised or invalid.
        """
        if not self._attributes_valid():
            raise PlotterMixinNotInitialized

        self.data_frame = typing.cast(pd.DataFrame, pd.read_hdf(self.hdf_filename, self.key))

    def patch_data_frame(self) -> None:
        """Delegate the data frame patching to the next class in the MRO."""
        super().patch_data_frame()  # type: ignore[safe-super]

    def _attributes_valid(self) -> bool:
        """
        Check that both the HDF filename and the key are provided.

        :return: False if the filename is not set, if it does not point to an existing file (a warning is logged)
            or if the key is empty; True otherwise.
        """
        if self.hdf_filename == Path():
            return False
        if not self.hdf_filename.is_file():
            # hdf is not a file or it does not exist. Lazy %-args: the message is only built if emitted.
            log.warning('%s is not a valid HDF file', self.hdf_filename)
            return False
        return self.key != ''

242 

@class_depends_on_optional('seaborn;pandas')
class SNSFigurePlotter(FigurePlotter):
    """
    Base mixin class to generate a seaborn Figure level plot.

    Concrete mixins override :meth:`plot` and store the grid returned by seaborn in :attr:`facet_grid`.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Annotations only (nothing is assigned here): the data frame and the facet grid instances are shared
        # with the main class.
        self.data_frame: pd.DataFrame
        self.facet_grid: sns.FacetGrid
        super().__init__(*args, **kwargs)

    def plot(self) -> None:
        """Do nothing; the concrete figure-level mixins provide the actual plotting."""

    def _attributes_valid(self) -> bool:
        """No attribute needs validation at this level, so the mixin is always valid."""
        return True

259 

@class_depends_on_optional('seaborn;pandas')
class RelPlot(SNSFigurePlotter):
    """
    The relational plot mixin.

    It draws a figure-level scatter or line plot by delegating to :func:`seaborn.relplot`, whose full documentation
    can be read at `this link <https://seaborn.pydata.org/generated/seaborn.relplot.html>`_.
    """

    def __init__(
        self,
        x: ColumnName | Iterable[float | complex | int] | None = None,
        y: ColumnName | Iterable[float | complex | int] | None = None,
        hue: ColumnName | Iterable[float | complex | int] | None = None,
        row: ColumnName | Iterable[float | complex | int] | None = None,
        col: ColumnName | Iterable[float | complex | int] | None = None,
        palette: _Palette | matplotlib.colors.Colormap | None = None,
        kind: typing.Literal['scatter', 'line'] = 'scatter',
        legend: typing.Literal['auto', 'brief', 'full'] | bool = 'auto',
        plot_kws: Mapping[str, Any] | None = None,
        facet_kws: dict[str, Any] | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Constructor parameters:

        :param x: Name of the column (or iterable of values) for the x axis.
        :type x: str | Iterable, Optional
        :param y: Name of the column (or iterable of values) for the y axis.
        :type y: str | Iterable, Optional
        :param hue: Name of the column (or iterable of values) used to colour the data.
        :type hue: str | Iterable, Optional
        :param row: Name of the column (or iterable of values) used to split the data in facet rows.
        :type row: str | Iterable, Optional
        :param col: Name of the column (or iterable of values) used to split the data in facet columns.
        :type col: str | Iterable, Optional
        :param palette: The colour palette to be used.
        :type palette: str | Colormap, Optional
        :param kind: Either a scatter or a line relational plot. Defaults to scatter.
        :type kind: str, Optional
        :param legend: How to draw the legend. With "brief", numeric hue and size variables are represented by a
            sample of evenly spaced values; with "full", every group gets a legend entry; with "auto", seaborn
            picks between brief and full depending on the number of levels; with False, no legend is drawn.
            Defaults to auto.
        :type legend: str | bool, Optional
        :param plot_kws: Extra keyword arguments unpacked into `seaborn.relplot
            <https://seaborn.pydata.org/generated/seaborn.relplot.html#seaborn.relplot>`_.
        :type plot_kws: dict[str, Any], Optional
        :param facet_kws: Keyword arguments handed to the underlying
            `seaborn.FacetGrid <https://seaborn.pydata.org/generated/seaborn.FacetGrid.html#seaborn-facetgrid>`_
        :type facet_kws: dict[str, Any], Optional
        """
        super().__init__(*args, **kwargs)
        # Semantic variables.
        self.x, self.y, self.hue = x, y, hue
        # Faceting variables.
        self.row, self.col = row, col
        # Appearance.
        self.palette = palette
        self.kind = kind
        self.legend = legend
        # Pass-through keyword dictionaries.
        self.plot_kws = {} if plot_kws is None else plot_kws
        self.facet_kws = facet_kws

    def plot(self) -> None:
        """Draw the figure-level relational plot and keep the returned grid in :attr:`facet_grid`."""
        relplot_options = dict(
            x=self.x,
            y=self.y,
            hue=self.hue,
            row=self.row,
            col=self.col,
            palette=self.palette,
            kind=self.kind,
            legend=self.legend,
            facet_kws=self.facet_kws,
        )
        self.facet_grid = sns.relplot(data=self.data_frame, **relplot_options, **self.plot_kws)

341 

@class_depends_on_optional('seaborn;pandas')
class DisPlot(SNSFigurePlotter):
    """
    The distribution plot mixin.

    This mixin is the MAFw implementation of the `seaborn displot
    <https://seaborn.pydata.org/generated/seaborn.displot.html#seaborn.displot>`_ and, depending on :attr:`kind`,
    will produce one of the following figure level plots:

    * **hist**: a simple `histogram
      plot <https://seaborn.pydata.org/generated/seaborn.histplot.html#seaborn.histplot>`_

    * **kde**: a `kernel density <https://seaborn.pydata.org/generated/seaborn.kdeplot.html#seaborn.kdeplot>`_
      estimate plot

    * **ecdf**: an `empirical cumulative distribution function
      <https://seaborn.pydata.org/generated/seaborn.ecdfplot.html#seaborn.ecdfplot>`_ plot

    The `marginal distributions <https://seaborn.pydata.org/generated/seaborn.rugplot.html#seaborn.rugplot>`_
    shown as ticks (rug plot) are not a separate kind: they are added to any of the above by setting ``rug=True``.
    """

    def __init__(
        self,
        x: ColumnName | Iterable[float | complex | int] | None = None,
        y: ColumnName | Iterable[float | complex | int] | None = None,
        hue: ColumnName | Iterable[float | complex | int] | None = None,
        row: ColumnName | Iterable[float | complex | int] | None = None,
        col: ColumnName | Iterable[float | complex | int] | None = None,
        palette: _Palette | matplotlib.colors.Colormap | None = None,
        kind: typing.Literal['hist', 'kde', 'ecdf'] = 'hist',
        legend: bool = True,
        rug: bool = False,
        rug_kws: dict[str, Any] | None = None,
        plot_kws: Mapping[str, Any] | None = None,
        facet_kws: dict[str, Any] | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Constructor parameters:

        :param x: The name of the x variable or an iterable containing the x values.
        :type x: str | Iterable, Optional
        :param y: The name of the y variable or an iterable containing the y values.
        :type y: str | Iterable, Optional
        :param hue: The name of the hue variable or an iterable containing the hue values.
        :type hue: str | Iterable, Optional
        :param row: The name of the row category or an iterable containing the row values.
        :type row: str | Iterable, Optional
        :param col: The name of the column category or an iterable containing the column values.
        :type col: str | Iterable, Optional
        :param palette: The colour palette to be used.
        :type palette: str | Colormap, Optional
        :param kind: The type of distribution plot (hist, kde or ecdf). Defaults to hist.
        :type kind: str, Optional
        :param legend: If false, suppress the legend for the semantic variables. Defaults to True.
        :type legend: bool, Optional
        :param rug: If true, show each observation with marginal ticks (a rug plot) on top of the selected kind.
            Defaults to False.
        :type rug: bool, Optional
        :param rug_kws: Parameters to control the appearance of the rug plot.
        :type rug_kws: dict[str, Any], Optional
        :param plot_kws: Parameters unpacked into the underlying seaborn displot call.
        :type plot_kws: Mapping[str, Any], Optional
        :param facet_kws: Parameters passed to the facet grid object.
        :type facet_kws: dict[str, Any], Optional
        """
        super().__init__(*args, **kwargs)
        self.x = x
        self.y = y
        self.hue = hue
        self.row = row
        self.col = col
        self.palette = palette
        self.kind = kind
        self.legend = legend
        self.rug = rug
        self.rug_kws = rug_kws
        self.plot_kws = plot_kws if plot_kws is not None else {}
        self.facet_kws = facet_kws

    def plot(self) -> None:
        """Implements the plot method for a figure-level distribution graph and stores the grid in facet_grid."""
        self.facet_grid = sns.displot(
            data=self.data_frame,
            x=self.x,
            y=self.y,
            hue=self.hue,
            row=self.row,
            col=self.col,
            palette=self.palette,
            kind=self.kind,
            legend=self.legend,
            rug=self.rug,
            rug_kws=self.rug_kws,
            facet_kws=self.facet_kws,
            **self.plot_kws,
        )

439 

@class_depends_on_optional('seaborn;pandas')
class CatPlot(SNSFigurePlotter):
    """
    The categorical plot mixin.

    This mixin will produce a figure level categorical plot as described `here
    <https://seaborn.pydata.org/generated/seaborn.catplot.html>`_.

    .. note::

        By default this function treats one of the variables (typically x) as categorical, this means that even if
        this variable is numeric, its value will not be considered. If you want to use the actual value of this
        categorical variable, set native_scale = True.
    """

    def __init__(
        self,
        x: ColumnName | Iterable[float | complex | int] | None = None,
        y: ColumnName | Iterable[float | complex | int] | None = None,
        hue: ColumnName | Iterable[float | complex | int] | None = None,
        row: ColumnName | Iterable[float | complex | int] | None = None,
        col: ColumnName | Iterable[float | complex | int] | None = None,
        palette: _Palette | None = None,
        kind: typing.Literal['strip', 'swarm', 'box', 'violin', 'boxen', 'point', 'bar', 'count'] = 'strip',
        legend: typing.Literal['auto', 'brief', 'full'] | bool = 'auto',
        native_scale: bool = False,
        plot_kws: Mapping[str, Any] | None = None,
        facet_kws: dict[str, Any] | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Constructor parameters:

        :param x: The name of the x variable or an iterable containing the x values.
        :type x: str | Iterable, Optional
        :param y: The name of the y variable or an iterable containing the y values.
        :type y: str | Iterable, Optional
        :param hue: The name of the hue variable or an iterable containing the hue values.
        :type hue: str | Iterable, Optional
        :param row: The name of the row category or an iterable containing the row values.
        :type row: str | Iterable, Optional
        :param col: The name of the column category or an iterable containing the column values.
        :type col: str | Iterable, Optional
        :param palette: The colour palette to be used.
        :type palette: str | Sequence | Mapping, Optional
        :param kind: The type of categorical plot (strip, swarm, box, violin, boxen, point, bar or count).
            Defaults to strip.
        :type kind: str, Optional
        :param legend: How to draw the legend. If “brief”, numeric hue and size variables will be represented with a
            sample of evenly spaced values. If “full”, every group will get an entry in the legend.
            If “auto”, choose between brief or full representation based on number of levels.
            If False, no legend data is added and no legend is drawn. Defaults to auto.
        :type legend: str | bool, Optional
        :param native_scale: When True, numeric or datetime values on the categorical axis will maintain their original
            scaling rather than being converted to fixed indices. Defaults to False.
        :type native_scale: bool, Optional
        :param plot_kws: A dictionary of keywords passed to the underlying `seaborn.catplot
            <https://seaborn.pydata.org/generated/seaborn.catplot.html#seaborn.catplot>`_.
        :type plot_kws: dict[str, Any], Optional
        :param facet_kws: A dictionary of keywords passed to the underlying
            `seaborn.FacetGrid <https://seaborn.pydata.org/generated/seaborn.FacetGrid.html#seaborn-facetgrid>`_
        :type facet_kws: dict[str, Any], Optional
        """
        super().__init__(*args, **kwargs)
        self.x = x
        self.y = y
        self.hue = hue
        self.row = row
        self.col = col
        self.palette = palette
        self.kind = kind
        self.legend = legend
        self.native_scale = native_scale
        self.plot_kws = plot_kws if plot_kws is not None else {}
        self.facet_kws = facet_kws

    def plot(self) -> None:
        """Implements the plot method of a figure-level categorical graph and stores the grid in facet_grid."""
        self.facet_grid = sns.catplot(
            data=self.data_frame,
            x=self.x,
            y=self.y,
            hue=self.hue,
            row=self.row,
            col=self.col,
            palette=self.palette,
            kind=self.kind,
            legend=self.legend,
            native_scale=self.native_scale,
            facet_kws=self.facet_kws,
            **self.plot_kws,
        )

532 

@class_depends_on_optional('seaborn;pandas')
class LMPlot(SNSFigurePlotter):
    """
    The linear regression model plot mixin.

    This mixin will produce a figure level regression model as described `here
    <https://seaborn.pydata.org/generated/seaborn.lmplot.html#seaborn.lmplot>`__
    """

    def __init__(
        self,
        x: str | None = None,
        y: str | None = None,
        hue: str | None = None,
        row: str | None = None,
        col: str | None = None,
        palette: _Palette | None = None,
        legend: bool = True,
        scatter_kws: dict[str, Any] | None = None,
        line_kws: dict[str, Any] | None = None,
        facet_kws: dict[str, Any] | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Constructor parameters:

        :param x: The name of the x variable (a column of the data frame).
        :type x: str, Optional
        :param y: The name of the y variable (a column of the data frame).
        :type y: str, Optional
        :param hue: The name of the hue variable (a column of the data frame).
        :type hue: str, Optional
        :param row: The name of the row category (a column of the data frame).
        :type row: str, Optional
        :param col: The name of the column category (a column of the data frame).
        :type col: str, Optional
        :param palette: The colour palette to be used.
        :type palette: str | Sequence | Mapping, Optional
        :param legend: If True and there is a hue variable, add a legend.
        :type legend: bool, Optional
        :param scatter_kws: A dictionary of keywords passed to the underlying `scatter
            <https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.scatter.html#matplotlib.pyplot.scatter>`_.
        :type scatter_kws: dict[str, Any], Optional
        :param line_kws: A dictionary of keywords passed to the underlying `plot
            <https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.plot.html#matplotlib.pyplot.plot>`_.
        :type line_kws: dict[str, Any], Optional
        :param facet_kws: A dictionary of keywords passed to the underlying
            `seaborn.FacetGrid <https://seaborn.pydata.org/generated/seaborn.FacetGrid.html#seaborn-facetgrid>`_
        :type facet_kws: dict[str, Any], Optional

        Any other keyword argument is passed up the constructor chain and, if :func:`seaborn.lmplot` accepts it,
        also to the plotting function (see :meth:`plot`).
        """
        super().__init__(*args, **kwargs)
        self.x = x
        self.y = y
        self.hue = hue
        self.row = row
        self.col = col
        self.palette = palette
        self.legend = legend
        self.scatter_kws = scatter_kws
        self.line_kws = line_kws
        self.facet_kws = facet_kws
        self.other_kws = kwargs

    def plot(self) -> None:
        """
        Implements the plot method for a figure-level regression model and stores the grid in facet_grid.

        The keyword arguments collected in :attr:`other_kws` are the same ones passed up the constructor chain, so
        they may contain arguments meant for other classes. Only those accepted by the signature of
        :func:`seaborn.lmplot` are forwarded; the others are skipped (and logged at debug level) instead of making
        ``lmplot`` fail with a ``TypeError``.
        """
        import inspect

        lmplot_parameters = inspect.signature(sns.lmplot).parameters
        # If lmplot ever accepts **kwargs, forward everything as before.
        accepts_any_keyword = any(
            parameter.kind is inspect.Parameter.VAR_KEYWORD for parameter in lmplot_parameters.values()
        )
        forwarded_kws: dict[str, Any] = {}
        for key, value in self.other_kws.items():
            if accepts_any_keyword or key in lmplot_parameters:
                forwarded_kws[key] = value
            else:
                log.debug('Keyword argument %s is not accepted by seaborn.lmplot and it is not forwarded', key)

        self.facet_grid = sns.lmplot(
            data=self.data_frame,
            x=self.x,
            y=self.y,
            hue=self.hue,
            row=self.row,
            col=self.col,
            palette=self.palette,
            legend=self.legend,
            scatter_kws=self.scatter_kws,
            line_kws=self.line_kws,
            facet_kws=self.facet_kws,
            **forwarded_kws,
        )

613 

@processor_depends_on_optional(module_name='pandas;seaborn', warn=True, raise_ex=False)
class SNSPlotter(GenericPlotter):
    """
    The Seaborn Plotter processor.

    This is a subclass of a Processor with advanced functionality to fetch data in the form of a dataframe and to
    produce plots.

    The key difference with respect to a normal processor is its :meth:`.process` method that has been already
    implemented as follows:

    .. literalinclude:: ../../../src/mafw/processor_library/sns_plotter.py
        :pyobject: SNSPlotter.process
        :dedent:

    This actually means that when you are subclassing a SNSPlotter you do not have to implement the process method
    as you would do for a normal Processor, but you will have to implement the following methods:

    * :meth:`~.in_loop_customization`.

      The processor execution workflow (LoopType) can be any of the available, so
      actually the process method might be invoked only once, or multiple times inside a loop structure
      (for or while).
      If the execution is cyclic, then you may want to have the possibility to do some customisation for each
      iteration, for example, changing the plot title, or modifying the data selection, or the filename where the
      plots will be saved.

      You can use this method also in case of a single loop processor, in this case you will not have access to
      the loop parameters.

    * :meth:`~.get_data_frame`.

      This method has the task to get the data to be plotted in the form of a pandas DataFrame. The processor has
      the :attr:`~.data_frame` attribute where the data will be stored to make them accessible from all other
      methods.

    * :meth:`~.GenericPlotter.patch_data_frame`.

      A convenient method to apply data frame manipulation to the data just retrieved.

    * :meth:`~.GenericPlotter.plot`.

      This method is where the actual plotting occurs. Use the :attr:`~.data_frame` to plot the quantities
      you want.

    * :meth:`~.customize_plot`.

      This method can be optionally used to customize the appearance of the facet grid produced by the
      :meth:`~.GenericPlotter.plot` method. It is particularly useful when the user is mixing this class with one
      of the :class:`~.FigurePlotter` mixins, thus not having direct access to the plot method.

    * :meth:`~.save`.

      This method is where the produced plot is saved in a file. Remember to append the output file name to the
      :attr:`list of produced outputs <.output_filename_list>` so that the
      :meth:`~.GenericPlotter._update_plotter_db` method will automatically store this file in the database during
      the :meth:`~.GenericPlotter.finish` execution.

    * :meth:`~.update_db`.

      If the user wants to update a specific table in the database, they can use this method.

      It is worth reminding that all plotters are saving all generated files in the standard table PlotterOutput.
      This is automatically done by the :meth:`~.GenericPlotter._update_plotter_db` method that is called in the
      :meth:`~.GenericPlotter.finish` method.

    You do not need to overload the :meth:`~.slice_data_frame` nor the :meth:`~.group_and_aggregate_data_frame`
    methods, but you can simply use them by setting the :attr:`~.slicing_dict` and the :attr:`~.grouping_columns`
    and the :attr:`~.aggregation_functions`.

    The processor comes with two processor parameters that can be used by user-defined subclasses:

    1. The output_folder that is the path where the output file will be saved
    2. The force_replot flag to be used whether the user wants the plot to be regenerated even if the output
       file already exists.

    The default value of these parameters can be changed using the :attr:`.Processor.new_defaults` dictionary as
    shown in :ref:`this example <parameter_inheritance>`.
    """

    def __init__(
        self,
        slicing_dict: MutableMapping[str, Any] | None = None,
        grouping_columns: Iterable[str] | None = None,
        aggregation_functions: Iterable[str | Callable[[Any], Any]] | None = None,
        matplotlib_backend: str = 'agg',
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Constructor parameters:

        :param slicing_dict: A dictionary with key, value pairs to slice the input data frame before the plotting
            occurs.
        :type slicing_dict: dict[str, Any], Optional
        :param grouping_columns: A list of columns for the groupby operation on the data frame.
        :type grouping_columns: list[str], Optional
        :param aggregation_functions: A list of functions for the aggregation on the grouped data frame.
        :type aggregation_functions: list[str | Callable[[Any], Any]], Optional
        :param matplotlib_backend: The name of the matplotlib backend to be used (case-insensitive). Defaults to 'agg'
        :type matplotlib_backend: str, Optional
        :param output_folder: The path where the output file will be saved
        :type output_folder: Path, Optional
        :param force_replot: Whether to force re-plotting even if the output file already exists.
        :type force_replot: bool, Optional
        """
        super().__init__(*args, **kwargs)

        # attributes that can be set in the constructor

        self.slicing_dict: MutableMapping[str, Any] | None = slicing_dict
        """The dictionary for slicing the input data frame"""

        self.grouping_columns: Iterable[str] | None = grouping_columns
        """The list of columns for grouping the data frame"""

        self.aggregation_functions: Iterable[str | Callable[[Any], Any]] | None = aggregation_functions
        """The list of aggregation functions to be applied to the grouped dataframe"""

        self.matplotlib_backend: str = matplotlib_backend.lower()
        """The backend to be used for matplotlib."""

        # internal use attributes.

        self.data_frame: pd.DataFrame = pd.DataFrame()
        """The pandas DataFrame containing the data to be plotted."""

        self.output_filename_list: list[Path] = []
        """The list of produced filenames."""

        self.facet_grid: sns.FacetGrid | None = None
        """The reference to the facet grid."""

        # private attributes

        # be sure that some additional methods if implemented are calling the super.
        # TODO: check if this is really needed
        self._methods_to_be_checked_for_super.extend([('patch_data_frame', SNSPlotter)])  # type: ignore[list-item]

    def start(self) -> None:
        """
        Overload of the start method.

        The :class:`~.SNSPlotter` is overloading the :meth:`~.Processor.start` in order to allow the user to
        change the matplotlib backend.

        The user can select which backend to use either directly in the class constructor or assign it to the class
        attribute :attr:`~.matplotlib_backend`.

        :raise ModuleNotFoundError: If the requested backend cannot be loaded.
        """
        super().start()
        try:
            # Switch only when needed; the comparison is case-insensitive on the current backend name.
            if plt.get_backend().lower() != self.matplotlib_backend:
                plt.switch_backend(self.matplotlib_backend)
        except ModuleNotFoundError:
            # Lazy %-args: the message is only built if the record is emitted.
            log.critical('%s is not a valid plt backend', self.matplotlib_backend)
            raise

    def get_data_frame(self) -> None:
        """
        Specific implementation of the get data frame for the Seaborn plotter.

        It must be overloaded.

        The method is **NOT** returning the data_frame, but in your implementation you need to assign the data frame
        to the class :attr:`.data_frame` attribute.
        """
        pass

    def process(self) -> None:
        """
        Specific implementation of the process method for the Seaborn plotter.

        It is almost the same as the GenericProcessor, with the addition that all open pyplot figures are closed
        after the process is finished.

        This part cannot be moved upward to the :class:`~.GenericPlotter` because the user might have selected
        another plotting library different from :link:`matplotlib`.
        """
        super().process()
        if not self.is_data_frame_empty():
            plt.close('all')

    def group_and_aggregate_data_frame(self) -> None:
        """
        Performs groupby and aggregation of the data frame.

        If the user provided some :attr:`grouping columns <.grouping_columns>` and :attr:`aggregation functions
        <.aggregation_functions>` then the :func:`~.group_and_aggregate_data_frame` is invoked accordingly.

        The user can update the values of those attributes during each cycle iteration within the implementation of
        the :meth:`~.in_loop_customization`.

        .. seealso::
            This method is simply invoking the :func:`~.group_and_aggregate_data_frame` function from the
            :mod:`~.pandas_tools`.
        """
        if self.grouping_columns and self.aggregation_functions:
            self.data_frame = group_and_aggregate_data_frame(
                self.data_frame, self.grouping_columns, self.aggregation_functions
            )

    def is_data_frame_empty(self) -> bool:
        """
        Check whether the data frame has no rows.

        Only the number of rows is considered: a data frame with rows but without columns is not considered empty
        (unlike ``DataFrame.empty``).

        :return: True if :attr:`.data_frame` has no rows, False otherwise.
        """
        return len(self.data_frame.index) == 0

    def slice_data_frame(self) -> None:
        """
        Perform data frame slicing

        The user can set some slicing criteria in the :attr:`~.slicing_dict` to select some specific data subset. The
        values of the slicing dict can be changed during each iteration within the implementation of the
        :meth:`~.in_loop_customization`.

        .. seealso::
            This method is simply invoking the :func:`~.slice_data_frame` function from the :mod:`~.pandas_tools`.
        """
        if self.slicing_dict:
            self.data_frame = slice_data_frame(self.data_frame, self.slicing_dict)

831 

832except ImportError: 

833 msg = ( 

834 'Trying to use the seaborn Plotter implementation without having installed the required dependencies.\n' 

835 'Consider installing mafw with the optional feature seaborn. For example:\n' 

836 '\npip install mafw[seaborn]\n\n' 

837 ) 

838 warnings.warn(MissingOptionalDependency(msg), stacklevel=2) 

839 raise