Coverage for src / mafw / tools / pandas_tools.py: 100%
40 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-09 09:08 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-09 09:08 +0000
1# Copyright 2025 European Union
2# Author: Bulgheroni Antonio (antonio.bulgheroni@ec.europa.eu)
3# SPDX-License-Identifier: EUPL-1.2
4"""
5A collection of useful convenience functions for common pandas operations
6"""
8import typing
9import warnings
10from collections.abc import Callable, Iterable, MutableMapping
11from typing import Any
13from mafw.decorators import depends_on_optional
14from mafw.mafw_errors import MissingOptionalDependency
16try:
17 import pandas as pd
19 @depends_on_optional(module_name='pandas')
20 def slice_data_frame(
21 input_data_frame: pd.DataFrame, slicing_dict: MutableMapping[str, Any] | None = None, **kwargs: Any
22 ) -> pd.DataFrame:
23 """
24 Slice a data frame according to `slicing_dict`.
26 The input data frame will be sliced using the items of the `slicing_dict` applying the loc operator in this way:
27 :python:`sliced = input_data_frame[(input_data_frame[key]==value)]`.
29 If the slicing_dict is empty, then the full input_data_frame is returned.
31 Instead of the slicing_dict, the user can also provide key and value pairs as keyword arguments.
33 :python:`slice_data_frame(data_frame, {'A':14})`
35 is equivalent to
37 :python:`slice_data_frame(data_frame, A=14)`.
39 If the user provides a keyword argument that also exists in the slicing_dict, then the keyword argument will update
40 the slicing_dict.
42 No checks on the column name is done, should a label be missing, the loc method will raise a KeyError.
44 :param input_data_frame: The data frame to be sliced.
45 :type input_data_frame: pd.DataFrame
46 :param slicing_dict: A dictionary with columns and values for the slicing. Defaults to None
47 :type slicing_dict: dict, Optional
48 :param kwargs: Keyword arguments to be used instead of the slicing dictionary.
49 :return: The sliced dataframe
50 :rtype: pd.DataFrame
51 """
52 if slicing_dict is None:
53 slicing_dict = {}
55 slicing_dict.update(kwargs)
57 if not slicing_dict or len(input_data_frame) == 0:
58 return input_data_frame
60 sliced: pd.DataFrame = input_data_frame
61 for key, value in slicing_dict.items():
62 sliced = sliced.loc[(sliced[key] == value)]
64 return sliced
66 @depends_on_optional(module_name='pandas')
67 def group_and_aggregate_data_frame(
68 data_frame: pd.DataFrame,
69 grouping_columns: Iterable[str],
70 aggregation_functions: Iterable[str | Callable[[Any], Any]],
71 ) -> pd.DataFrame:
72 """
73 Utility function to perform dataframe groupby and aggregation.
75 This function is a simple wrapper to perform group by and aggregation operations on a dataframe. The user must
76 provide a list of columns to perform the group by on and a list of functions for the aggregation of the other
77 columns.
79 The output dataframe will have the aggregated columns renamed as originalname_aggregationfunction.
81 .. note::
82 Only numeric columns (and columns that can be aggregated) will be included in the aggregation.
83 String columns that are not used for grouping will be automatically excluded from aggregation.
85 :param data_frame: The input data frame
86 :type data_frame: pandas.DataFrame
87 :param grouping_columns: The list of columns to group by on.
88 :type grouping_columns: Iterable[str]
89 :param aggregation_functions: The list of functions to be used for the aggregation of the not grouped columns.
90 :type aggregation_functions: Iterable[str | Callable[[Any], Any]
91 :return: The aggregated dataframe after the groupby operation.
92 :rtype: pandas.DataFrame
93 """
94 # typing of this function is a nightmare.
95 # I have not understood anything about these errors
96 if grouping_columns:
97 grouped_df = data_frame.groupby(grouping_columns) # type: ignore
99 # Get columns that are not used for grouping
100 grouping_columns_list = list(grouping_columns)
101 non_grouping_columns = [col for col in data_frame.columns if col not in grouping_columns_list]
103 # Filter to only numeric/aggregatable columns
104 # We'll try to aggregate only numeric columns and datetime columns
105 aggregatable_columns = []
106 for col in non_grouping_columns:
107 if pd.api.types.is_numeric_dtype(data_frame[col]) or pd.api.types.is_datetime64_any_dtype(
108 data_frame[col]
109 ):
110 aggregatable_columns.append(col)
112 # If we have aggregatable columns, perform aggregation on them
113 if aggregatable_columns:
114 aggregated_df = typing.cast(
115 pd.DataFrame, grouped_df[aggregatable_columns].agg(aggregation_functions).reset_index()
116 )
117 chain = '_'
118 aggregated_df.columns = [chain.join(col).strip(chain) for col in aggregated_df.columns.values] # type: ignore
119 else:
120 # If no aggregatable columns, just return the grouped columns with their unique combinations
121 aggregated_df = typing.cast(pd.DataFrame, grouped_df.size().reset_index(name='count'))
123 else:
124 aggregated_df = data_frame
126 return aggregated_df
except ImportError:
    # The guarded pandas import failed: emit a warning with installation hints, then re-raise the original
    # ImportError.
    # NOTE(review): the text below talks about the seaborn Plotter and the `seaborn` extra, while the import
    # guarded here is pandas -- presumably the message was copied from the plotter module. Confirm which
    # optional feature provides pandas before rewording it.
    msg = (
        'Trying to use the seaborn Plotter implementation without having installed the required dependencies.\n'
        'Consider installing mafw with the optional feature seaborn. For example:\n'
        '\npip install mafw[seaborn]\n\n'
    )
    # stacklevel=2 makes the warning point one frame above this module-level code.
    warnings.warn(MissingOptionalDependency(msg), stacklevel=2)
    raise