# Copyright 2025 European Union
# Author: Bulgheroni Antonio (antonio.bulgheroni@ec.europa.eu)
# SPDX-License-Identifier: EUPL-1.2
"""
Module provides a Trigger class and related tools to create triggers in the database via the ORM.
It supports SQLite, MySQL and PostgreSQL with dialect-specific SQL generation.
"""
from abc import ABC, abstractmethod
from enum import StrEnum
from typing import Any, Self, cast
import peewee
from peewee import Model
from mafw.db.db_types import PeeweeModelWithMeta
from mafw.mafw_errors import MissingSQLStatement, UnsupportedDatabaseError
from mafw.tools.regexp import normalize_sql_spaces
[docs]
def and_(*conditions: str) -> str:
"""
Concatenates conditions with logical AND.
:param conditions: The condition to join.
:type conditions: str
:return: The and-concatenated string of conditions
:rtype: str
"""
conditions_l = [f'({c})' for c in conditions]
return ' AND '.join(conditions_l)
[docs]
def or_(*conditions: str) -> str:
"""
Concatenates conditions with logical OR.
:param conditions: The condition to join.
:type conditions: str
:return: The or-concatenated string of conditions.
:rtype: str
"""
conditions_l = [f'({c})' for c in conditions]
return ' OR '.join(conditions_l)
[docs]
class TriggerWhen(StrEnum):
"""String enumerator for the trigger execution time (Before, After or Instead Of)"""
Before = 'BEFORE'
After = 'AFTER'
Instead = 'INSTEAD OF'
[docs]
class TriggerAction(StrEnum):
"""String enumerator for the trigger action (Delete, Insert, Update)"""
Delete = 'DELETE'
Insert = 'INSERT'
Update = 'UPDATE'
[docs]
class TriggerDialect(ABC):
"""Abstract base class for database-specific trigger SQL generation."""
[docs]
@abstractmethod
def create_trigger_sql(self, trigger: 'Trigger') -> str:
"""
Generate the SQL to create a trigger for a specific database dialect.
:param trigger: The trigger object
:return: SQL string to create the trigger
"""
pass # pragma: no cover
[docs]
@abstractmethod
def drop_trigger_sql(self, trigger_name: str, safe: bool = True) -> str:
"""
Generate the SQL to drop a trigger for a specific database dialect.
:param trigger_name: The name of the trigger to drop
:param safe: If True, add an IF EXISTS clause. Defaults to True.
:return: SQL string to drop the trigger
"""
pass # pragma: no cover
[docs]
@abstractmethod
def supports_trigger_type(self, when: TriggerWhen, action: TriggerAction, on_view: bool = False) -> bool:
"""
Check if the database supports the specified trigger type.
:param when: When the trigger should fire (BEFORE, AFTER, INSTEAD OF)
:param action: The action that triggers the trigger (INSERT, UPDATE, DELETE)
:param on_view: Whether the trigger is on a view
:return: True if supported, False otherwise
"""
pass # pragma: no cover
[docs]
@abstractmethod
def supports_safe_create(self) -> bool:
"""
Check if the database supports IF NOT EXISTS for triggers.
:return: True if supported, False otherwise
"""
pass # pragma: no cover
[docs]
@abstractmethod
def supports_update_of_columns(self) -> bool:
"""
Check if the database supports column-specific UPDATE triggers.
:return: True if supported, False otherwise
"""
pass # pragma: no cover
[docs]
@abstractmethod
def supports_when_clause(self) -> bool:
"""
Check if the database supports WHEN conditions.
:return: True if supported, False otherwise
"""
pass # pragma: no cover
[docs]
class SQLiteDialect(TriggerDialect):
"""SQLite-specific trigger SQL generation."""
[docs]
def create_trigger_sql(self, trigger: 'Trigger') -> str:
"""Generate SQLite trigger SQL."""
if_not_exists = 'IF NOT EXISTS' if trigger.safe else ''
of_columns = (
f'OF {", ".join(trigger.update_columns)}'
if trigger.trigger_action == TriggerAction.Update and trigger.update_columns
else ''
)
for_each_row = 'FOR EACH ROW' if trigger.for_each_row else ''
when_clause = f'WHEN {" AND ".join(trigger._when_list)}' if trigger._when_list else ''
sql_statements = '\n'.join(trigger._sql_list)
return normalize_sql_spaces(
f'CREATE TRIGGER {if_not_exists} {trigger.trigger_name}\n'
f'{trigger.trigger_when} {trigger.trigger_action} {of_columns} ON {trigger.target_table}\n'
f'{for_each_row} {when_clause}\n'
f'BEGIN\n'
f'{sql_statements}\n'
f'END;'
)
[docs]
def drop_trigger_sql(self, trigger_name: str, safe: bool = True) -> str:
"""Generate SQLite drop trigger SQL."""
return normalize_sql_spaces(f'DROP TRIGGER {"IF EXISTS" if safe else ""} {trigger_name}')
[docs]
def supports_trigger_type(self, when: TriggerWhen, action: TriggerAction, on_view: bool = False) -> bool:
"""SQLite supports all trigger types except INSTEAD OF on tables (only on views)."""
if when == TriggerWhen.Instead and not on_view:
return False
return True
[docs]
def supports_safe_create(self) -> bool:
"""SQLite supports IF NOT EXISTS for triggers."""
return True
[docs]
def supports_update_of_columns(self) -> bool:
"""SQLite supports column-specific UPDATE triggers."""
return True
[docs]
def supports_when_clause(self) -> bool:
"""SQLite supports WHEN conditions."""
return True
[docs]
class MySQLDialect(TriggerDialect):
"""MySQL-specific trigger SQL generation."""
[docs]
def create_trigger_sql(self, trigger: 'Trigger') -> str:
"""Generate MySQL trigger SQL."""
# MySQL doesn't support INSTEAD OF triggers
# MySQL doesn't support column-specific UPDATE triggers
# MySQL requires FOR EACH ROW
if_not_exists = 'IF NOT EXISTS' if trigger.safe else ''
# In MySQL, we need to convert WHEN conditions to IF/THEN/END IF blocks
sql_statements = []
# If there are conditional statements, wrap them in IF blocks
if trigger._when_list:
condition = ' AND '.join(trigger._when_list)
# Start the IF block
sql_statements.append(f'IF {condition} THEN')
# Add the SQL statements with indentation
for stmt in trigger._sql_list:
sql_statements.append(f' {stmt}')
# Close the IF block
sql_statements.append('END IF;')
else:
# No conditions, just add the SQL statements directly
sql_statements.extend(trigger._sql_list)
# Join all statements
trigger_body = '\n'.join(sql_statements)
# Construct the final SQL
sql = (
f'CREATE TRIGGER {if_not_exists} {trigger.trigger_name}\n'
f'{trigger.trigger_when} {trigger.trigger_action} ON {trigger.target_table}\n'
f'FOR EACH ROW\n'
f'BEGIN\n'
f'{trigger_body}\n'
f'END;'
)
return normalize_sql_spaces(sql)
[docs]
def drop_trigger_sql(self, trigger_name: str, safe: bool = True) -> str:
"""Generate MySQL drop trigger SQL."""
return normalize_sql_spaces(f'DROP TRIGGER {"IF EXISTS" if safe else ""} {trigger_name}')
[docs]
def supports_trigger_type(self, when: TriggerWhen, action: TriggerAction, on_view: bool = False) -> bool:
"""MySQL doesn't support INSTEAD OF triggers."""
return when != TriggerWhen.Instead
[docs]
def supports_safe_create(self) -> bool:
"""MySQL supports IF NOT EXISTS for triggers."""
return True
[docs]
def supports_update_of_columns(self) -> bool:
"""MySQL doesn't support column-specific UPDATE triggers."""
return False
[docs]
def supports_when_clause(self) -> bool:
"""MySQL supports conditions but through WHERE instead of WHEN."""
return True
[docs]
class PostgreSQLDialect(TriggerDialect):
"""PostgreSQL-specific trigger SQL generation."""
[docs]
def create_trigger_sql(self, trigger: 'Trigger') -> str:
"""Generate PostgreSQL trigger SQL."""
# PostgreSQL handles INSTEAD OF differently
# PostgreSQL uses functions for trigger bodies
function_name = f'fn_{trigger.trigger_name}'
# First create the function
function_sql = f'CREATE OR REPLACE FUNCTION {function_name}() RETURNS TRIGGER AS $$\nBEGIN\n'
# Add WHEN condition as IF statements if needed
if trigger._when_list:
when_condition = ' AND '.join(trigger._when_list)
function_sql += f' IF {when_condition} THEN\n'
# Indent SQL statements
sql_statements = '\n'.join([' ' + self._clean_sql(sql) for sql in trigger._sql_list])
function_sql += f'{sql_statements}\n END IF;\n'
else:
# Indent SQL statements
sql_statements = '\n'.join([' ' + self._clean_sql(sql) for sql in trigger._sql_list])
function_sql += f'{sql_statements}\n'
# For AFTER triggers, we need to return NULL or NEW
if trigger.trigger_when == TriggerWhen.After:
function_sql += ' RETURN NULL;\n'
# For BEFORE or INSTEAD OF triggers, we need to return NEW
else:
function_sql += ' RETURN NEW;\n'
function_sql += 'END;\n$$ LANGUAGE plpgsql;'
# Then create the trigger - PostgreSQL doesn't support IF NOT EXISTS for triggers before v14
# We'll handle this through a conditional drop
drop_if_exists = ''
if trigger.safe:
drop_if_exists = f'DROP TRIGGER IF EXISTS {trigger.trigger_name} ON {trigger.target_table} CASCADE;\n'
# PostgreSQL uses different syntax for INSTEAD OF (only allowed on views)
trigger_when = trigger.trigger_when
# Column-specific triggers in PostgreSQL
of_columns = (
f'OF {", ".join(trigger.update_columns)}'
if trigger.update_columns and trigger.trigger_action == TriggerAction.Update
else ''
)
for_each = 'FOR EACH ROW' if trigger.for_each_row else 'FOR EACH STATEMENT'
trigger_sql = (
f'{drop_if_exists}'
f'CREATE TRIGGER {trigger.trigger_name}\n'
f'{trigger_when} {trigger.trigger_action} {of_columns} ON {trigger.target_table}\n'
f'{for_each}\n'
f'EXECUTE FUNCTION {function_name}();'
)
sql = f'{function_sql}\n\n{trigger_sql}'
return normalize_sql_spaces(sql)
[docs]
def _clean_sql(self, sql: str) -> str:
"""
Remove RETURNING clauses from SQL statements for PostgreSQL trigger functions.
:param sql: The SQL statement
:return: SQL statement without RETURNING clause
"""
# Find the RETURNING clause position - case insensitive search
sql_upper = sql.upper()
returning_pos = sql_upper.find('RETURNING')
# If RETURNING exists, remove it and everything after it up to the semicolon
if returning_pos != -1:
semicolon_pos = sql.find(';', returning_pos)
if semicolon_pos != -1:
return sql[:returning_pos] + ';'
return sql[:returning_pos]
return sql
[docs]
def drop_trigger_sql(self, trigger_name: str, safe: bool = True) -> str:
"""Generate PostgreSQL drop trigger SQL."""
function_name = f'fn_{trigger_name}'
return normalize_sql_spaces(
f'DROP TRIGGER {"IF EXISTS" if safe else ""} {trigger_name} ON table_name;\n'
f'DROP FUNCTION {"IF EXISTS" if safe else ""} {function_name}();'
)
[docs]
def supports_trigger_type(self, when: TriggerWhen, action: TriggerAction, on_view: bool = False) -> bool:
"""PostgreSQL supports INSTEAD OF only on views."""
if when == TriggerWhen.Instead and not on_view:
return False
return True
[docs]
def supports_safe_create(self) -> bool:
"""PostgreSQL doesn't support IF NOT EXISTS for triggers before v14, but we implement safety differently."""
return True # We report True but handle it with DROP IF EXISTS
[docs]
def supports_update_of_columns(self) -> bool:
"""PostgreSQL supports column-specific UPDATE triggers."""
return True
[docs]
def supports_when_clause(self) -> bool:
"""PostgreSQL supports WHEN conditions."""
return True
[docs]
class Trigger:
"""Trigger template wrapper for use with peewee ORM."""
# noinspection PyProtectedMember
def __init__(
self,
trigger_name: str,
trigger_type: tuple[TriggerWhen, TriggerAction],
source_table: type[Model] | Model | str,
safe: bool = False,
for_each_row: bool = False,
update_columns: list[str] | None = None,
on_view: bool = False, # Added parameter to indicate if the target is a view
):
"""
Constructor parameters:
:param trigger_name: The name of this trigger. It needs to be unique!
:type trigger_name: str
:param trigger_type: A tuple with :class:`TriggerWhen` and :class:`TriggerAction` to specify on which action
the trigger should be invoked and if before, after or instead of.
:type trigger_type: tuple[TriggerWhen, TriggerAction]
:param source_table: The table originating the trigger. It can be a model class, instance, or also the name of
the table.
:type source_table: type[Model] | Model | str
:param safe: A boolean flag to define if in the trigger creation statement a 'IF NOT EXISTS' clause should be
included. Defaults to False
:type safe: bool, Optional
:param for_each_row: A boolean flag to repeat the script content for each modified row in the table.
Defaults to False.
:type for_each_row: bool, Optional
:param update_columns: A list of column names. When defining a trigger on a table update, it is possible to
restrict the firing of the trigger to the cases when a subset of all columns have been updated. An column
is updated also when the new value is equal to the old one. If you want to discriminate this case, use the
:meth:`add_when` method. Defaults to None.
:type update_columns: list[str], Optional
:param on_view: A boolean flag to indicate if the target is a view. This affects the support for INSTEAD OF.
Defaults to False.
:type on_view: bool, Optional
"""
self.trigger_name = trigger_name
self.trigger_type = trigger_type
self._trigger_when, self._trigger_op = self.trigger_type
self.update_columns = update_columns or []
self.on_view = on_view
if isinstance(source_table, type):
model_cls = cast(PeeweeModelWithMeta, source_table)
self.target_table = model_cls._meta.table_name
elif isinstance(source_table, Model):
model_instance = cast(PeeweeModelWithMeta, source_table)
self.target_table = model_instance._meta.table_name
else:
self.target_table = source_table
self.safe = safe
self.for_each_row = for_each_row
self._when_list: list[str] = []
self._sql_list: list[str] = []
self._database: peewee.Database | peewee.DatabaseProxy | None = None
self._dialect: TriggerDialect | None = None
@property
def trigger_action(self) -> TriggerAction:
return self._trigger_op
@trigger_action.setter
def trigger_action(self, action: TriggerAction) -> None:
self._trigger_op = action
@property
def trigger_when(self) -> TriggerWhen:
return self._trigger_when
@trigger_when.setter
def trigger_when(self, when: TriggerWhen) -> None:
self._trigger_when = when
def __setattr__(self, key: Any, value: Any) -> None:
if key == 'safe':
self.if_not_exists = 'IF NOT EXISTS' if value else ''
elif key == 'for_each_row':
self._for_each_row = 'FOR EACH ROW' if value else ''
else:
super().__setattr__(key, value)
def __getattr__(self, item: str) -> Any:
"""
Custom attribute getter for computed properties.
:param item: The attribute name to get
:return: The attribute value
:raises AttributeError: If the attribute doesn't exist
"""
if item == 'safe':
return hasattr(self, 'if_not_exists') and self.if_not_exists == 'IF NOT EXISTS'
elif item == 'for_each_row':
return hasattr(self, '_for_each_row') and self._for_each_row == 'FOR EACH ROW'
else:
# Raise AttributeError for non-existent attributes (standard Python behavior)
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")
[docs]
def add_sql(self, sql: str | peewee.Query) -> Self:
"""
Add an SQL statement to be executed by the trigger.
The ``sql`` can be either a string containing the sql statement, or it can be any other peewee Query.
For example:
.. code-block:: python
# assuming you have created a trigger ...
sql = AnotherTable.insert(
field1=some_value, field2=another_value
)
trigger.add_sql(sql)
In this way the SQL code is generated with parametric placeholder if needed.
:param sql: The SQL statement.
:type sql: str | peewee.Query
:return: self for easy chaining
:rtype: Trigger
"""
if not isinstance(sql, str):
sql = str(sql)
sql = sql.strip()
sql = chr(9) + sql
if not sql.endswith(';'):
sql += ';'
self._sql_list.append(sql)
return self
[docs]
def add_when(self, *conditions: str) -> Self:
"""
Add conditions to the `when` statements.
Conditions are logically ANDed.
To have mixed `OR` and `AND` logic, use the functions :func:`and_` and :func:`or_`.
:param conditions: Conditions to be added with logical AND
:type conditions: str
:return: self for easy chaining
:rtype: Trigger
"""
conditions_l = [f'({c.strip()})' for c in conditions]
self._when_list.append(f'({" AND ".join(conditions_l)})')
return self
[docs]
def set_database(self, database: peewee.Database | peewee.DatabaseProxy) -> Self:
"""
Set the database to use for this trigger.
:param database: The database instance
:return: self for easy chaining
"""
self._database = database
return self
[docs]
def _get_dialect(self) -> TriggerDialect:
"""
Get the appropriate dialect based on the database type.
:return: A dialect instance
"""
if self._dialect is not None:
return self._dialect
if self._database is None:
# Default to SQLite dialect
return SQLiteDialect()
db = self._database
if isinstance(db, peewee.DatabaseProxy):
db = db.obj # Get the actual database from the proxy
if isinstance(db, peewee.SqliteDatabase):
self._dialect = SQLiteDialect()
elif isinstance(db, peewee.MySQLDatabase):
self._dialect = MySQLDialect()
elif isinstance(db, peewee.PostgresqlDatabase):
self._dialect = PostgreSQLDialect()
else:
raise UnsupportedDatabaseError(f'Unsupported database type: {type(db)}')
return self._dialect
[docs]
def create(self) -> str:
"""
Generates the SQL create statement.
:return: The trigger creation statement.
:raise MissingSQLStatement: if no SQL statements are provided.
:raise UnsupportedDatabaseError: if the trigger type is not supported by the database.
"""
if len(self._sql_list) == 0:
raise MissingSQLStatement('No SQL statements provided')
dialect = self._get_dialect()
# Check if the trigger type is supported
if not dialect.supports_trigger_type(self.trigger_when, self.trigger_action, self.on_view):
raise UnsupportedDatabaseError(
f'Trigger type {self.trigger_when} {self.trigger_action} is not supported by the database'
)
# Check if safe create is supported
if self.safe and not dialect.supports_safe_create():
# We can either ignore and continue without safe, or raise an error
# For now, we'll just ignore and continue
self.safe = False
# Check if update columns are supported
if self.update_columns and not dialect.supports_update_of_columns():
# We can either ignore and continue without column-specific updates, or raise an error
# For now, we'll ignore and continue
self.update_columns = []
# Generate the SQL
return dialect.create_trigger_sql(self)
[docs]
def drop(self, safe: bool = True) -> str:
"""
Generates the SQL drop statement.
:param safe: If True, add an IF EXIST. Defaults to True.
:type safe: bool, Optional
:return: The drop statement
:rtype: str
"""
dialect = self._get_dialect()
return dialect.drop_trigger_sql(self.trigger_name, safe)