Coverage for src / mafw / tools / pandas_tools.py: 100%

40 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-09 09:08 +0000

1# Copyright 2025 European Union 

2# Author: Bulgheroni Antonio (antonio.bulgheroni@ec.europa.eu) 

3# SPDX-License-Identifier: EUPL-1.2 

4""" 

5A collection of useful convenience functions for common pandas operations 

6""" 

7 

8import typing 

9import warnings 

10from collections.abc import Callable, Iterable, MutableMapping 

11from typing import Any 

12 

13from mafw.decorators import depends_on_optional 

14from mafw.mafw_errors import MissingOptionalDependency 

15 

try:
    import pandas as pd

    @depends_on_optional(module_name='pandas')
    def slice_data_frame(
        input_data_frame: pd.DataFrame, slicing_dict: MutableMapping[str, Any] | None = None, **kwargs: Any
    ) -> pd.DataFrame:
        """
        Slice a data frame according to `slicing_dict`.

        The input data frame is filtered with every item of the slicing criteria, keeping only the rows for which
        :python:`input_data_frame[key] == value` holds. Multiple criteria are combined with a logical AND.

        If no slicing criteria are given, the full input_data_frame is returned.

        Instead of the slicing_dict, the user can also provide key and value pairs as keyword arguments.

        :python:`slice_data_frame(data_frame, {'A':14})`

        is equivalent to

        :python:`slice_data_frame(data_frame, A=14)`.

        If the user provides a keyword argument that also exists in the slicing_dict, the keyword argument takes
        precedence. The slicing_dict passed by the caller is never modified.

        No checks on the column names are done: should a label be missing, the column look-up will raise a KeyError.
        An empty input_data_frame is returned as it is, without any column look-up.

        :param input_data_frame: The data frame to be sliced.
        :type input_data_frame: pd.DataFrame
        :param slicing_dict: A dictionary with columns and values for the slicing. Defaults to None
        :type slicing_dict: dict, Optional
        :param kwargs: Keyword arguments to be used instead of (or on top of) the slicing dictionary.
        :return: The sliced dataframe
        :rtype: pd.DataFrame
        :raises KeyError: if a slicing key is not a column of a non-empty input_data_frame.
        """
        # Merge into a fresh dict: updating slicing_dict in place would leak the keyword arguments
        # back into the caller's mapping.
        criteria: dict[str, Any] = dict(slicing_dict) if slicing_dict is not None else {}
        criteria.update(kwargs)

        if not criteria or len(input_data_frame) == 0:
            return input_data_frame

        sliced: pd.DataFrame = input_data_frame
        for key, value in criteria.items():
            sliced = sliced.loc[sliced[key] == value]

        return sliced

    def _flatten_column_label(label: Any, separator: str = '_') -> str:
        """
        Convert a (possibly multi-level) column label into a single string.

        Tuple labels, as produced by a list-of-functions aggregation (e.g. ``('value', 'mean')``), are joined with
        `separator` after dropping empty levels. The empty levels are the ones ``reset_index`` assigns to the grouping
        columns, so ``('key', '')`` becomes ``'key'``. Leading or trailing separators that are part of the original
        column name (e.g. ``'_id'``) are preserved. Non-tuple labels are simply converted to string.

        :param label: The column label to flatten.
        :param separator: The string used to join the levels. Defaults to '_'.
        :return: The flattened label.
        """
        if isinstance(label, tuple):
            parts = [str(level) for level in label]
            return separator.join(part for part in parts if part)
        return str(label)

    @depends_on_optional(module_name='pandas')
    def group_and_aggregate_data_frame(
        data_frame: pd.DataFrame,
        grouping_columns: Iterable[str],
        aggregation_functions: Iterable[str | Callable[[Any], Any]],
    ) -> pd.DataFrame:
        """
        Utility function to perform dataframe groupby and aggregation.

        This function is a simple wrapper to perform group by and aggregation operations on a dataframe. The user must
        provide a list of columns to perform the group by on and a list of functions for the aggregation of the other
        columns.

        The output dataframe will have the aggregated columns renamed as originalname_aggregationfunction, while the
        grouping columns keep their original names. A single aggregation function (a string or a callable) is
        accepted too and is treated as a one-element list, so the naming scheme is the same.

        If `grouping_columns` is empty, the input data frame is returned unchanged.

        .. note::
            Only numeric (boolean included) and datetime columns are aggregated. Columns of any other type (e.g.
            strings) that are not used for grouping are excluded from the aggregation. If no column is left to
            aggregate, the output contains the unique combinations of the grouping columns together with a
            ``count`` column holding the number of rows in each group.

        :param data_frame: The input data frame
        :type data_frame: pandas.DataFrame
        :param grouping_columns: The list of columns to group by on.
        :type grouping_columns: Iterable[str]
        :param aggregation_functions: The list of functions to be used for the aggregation of the not grouped columns.
        :type aggregation_functions: Iterable[str | Callable[[Any], Any]]
        :return: The aggregated dataframe after the groupby operation.
        :rtype: pandas.DataFrame
        :raises KeyError: if one of the grouping columns is not in data_frame.
        """
        # Materialise the grouping columns once. A generator would be always truthy and exhausted after its first
        # use. Recent pandas versions interpret a tuple `by` as a single key, not as a list of keys.
        grouping_columns_list = list(grouping_columns)
        if not grouping_columns_list:
            return data_frame

        # A bare string or callable would make pandas return flat column labels. Wrapping it keeps the
        # two-level labels and hence the <column>_<function> naming.
        functions: Any
        if isinstance(aggregation_functions, str) or callable(aggregation_functions):
            functions = [aggregation_functions]
        else:
            functions = aggregation_functions

        grouped_df = data_frame.groupby(grouping_columns_list)

        # Keep only the non-grouping columns that can be aggregated: numeric and datetime ones.
        aggregatable_columns = []
        for col in data_frame.columns:
            if col in grouping_columns_list:
                continue
            column = data_frame[col]
            if pd.api.types.is_numeric_dtype(column) or pd.api.types.is_datetime64_any_dtype(column):
                aggregatable_columns.append(col)

        if not aggregatable_columns:
            # Nothing to aggregate: return the unique combinations of the grouping columns with their row count.
            return typing.cast(pd.DataFrame, grouped_df.size().reset_index(name='count'))

        aggregated_df = typing.cast(
            pd.DataFrame, grouped_df[aggregatable_columns].agg(functions).reset_index()
        )
        aggregated_df.columns = [_flatten_column_label(col) for col in aggregated_df.columns]

        return aggregated_df

except ImportError:
    msg = (
        'Trying to use the pandas tools without having installed the required dependencies.\n'
        'Consider installing mafw with the optional feature seaborn, which also installs pandas. For example:\n'
        '\npip install mafw[seaborn]\n\n'
    )
    warnings.warn(MissingOptionalDependency(msg), stacklevel=2)
    raise