# Source code for mafw.steering_gui.models.steering_tree_model

#  Copyright 2026 European Union
#  Author: Bulgheroni Antonio (antonio.bulgheroni@ec.europa.eu)
#  SPDX-License-Identifier: EUPL-1.2
"""Composite Qt model that embeds the processor pipeline under the steering tree.

:Author: Bulgheroni Antonio
:Description: Presents the fixed steering sections plus the editable processors subtree.
"""

from __future__ import annotations

import json
from dataclasses import dataclass, field
from enum import IntEnum
from pathlib import Path
from typing import TYPE_CHECKING, Sequence, Union, no_type_check

from PySide6.QtCore import (
    QAbstractItemModel,
    QByteArray,
    QMimeData,
    QModelIndex,
    QObject,
    QPersistentModelIndex,
    Qt,
)
from PySide6.QtGui import QColor, QIcon

from mafw.steering.models import ProcessorSchemaStatus

from .pipeline import PipelineItem, ProcessorPipeline
from .processor_pipeline_model import PipelineRoles, ProcessorPipelineModel

if TYPE_CHECKING:
    from mafw.steering_gui.controllers.steering_controller import SteeringController, SteeringControllerError
else:
    SteeringController = object
    SteeringControllerError = Exception

# Alias for the two index flavours accepted by the model methods in this module.
ModelIndex = Union[QModelIndex, QPersistentModelIndex]

# Directory holding the icon files: the ``resources`` folder located next to
# the package directory that contains this module.
_RESOURCE_DIR = Path(__file__).resolve().parent.parent / 'resources'


def _load_icon(name: str) -> QIcon:
    """Build the icon stored as ``name`` inside the resources directory.

    :param name: File name of the icon, relative to ``_RESOURCE_DIR``.
    :return: The loaded icon, or a default-constructed ``QIcon`` when the
        loaded icon reports itself as null.
    """
    candidate = QIcon(str(_RESOURCE_DIR / name))
    if candidate.isNull():
        return QIcon()
    return candidate


# Icons created once at import time through ``_load_icon``; they are the values
# returned for ``DecorationRole`` by ``SteeringTreeModel.data``.
_PROCESSOR_ICON = _load_icon('basil--processor-outline.svg')
_GROUP_ICON = _load_icon('clarity--file-group-line.svg')
_UNKNOWN_ICON = _load_icon('mingcute--file-unknown-line.svg')


[docs] @dataclass class _SectionNode: """Internal tree node for the fixed steering sections.""" label: str key: str children: list['_SectionNode'] = field(default_factory=list)
class SteeringTreeRoles(IntEnum):
    """Item data roles specific to the steering tree model.

    All values are offsets from ``Qt.ItemDataRole.UserRole``.
    """

    # Role used to request the pipeline item behind an index.
    PIPELINE_ITEM = Qt.ItemDataRole.UserRole + 1
    # Role used to request the key of a section node.
    SECTION_KEY = Qt.ItemDataRole.UserRole + 2
    # Role used to ask whether a pipeline item is a group.
    IS_GROUP = Qt.ItemDataRole.UserRole + 3
class SteeringTreeModel(QAbstractItemModel):
    """Tree model exposing the steering sections and embedded processor pipeline.

    The tree has a single top-level row (``Steering File``) whose children are
    the fixed sections ``Globals``, ``Database``, ``UI`` and ``Processors``.
    The items of the processor pipeline are shown below the ``Processors``
    section and can be rearranged by drag and drop.
    """

    _mime_type = 'application/x-mafw-pipeline-path'

    def __init__(
        self,
        controller: SteeringController,
        pipeline: ProcessorPipeline | None = None,
        parent: QObject | None = None,
    ) -> None:
        """Create the model.

        :param controller: Controller used to move pipeline entries and to rebuild the pipeline.
        :param pipeline: Pipeline to display; a new ``ProcessorPipeline`` is created when ``None``.
        :param parent: Optional Qt parent object.
        """
        super().__init__(parent)
        self._controller = controller
        # Test against None explicitly: relying on truthiness would silently swap a
        # caller-provided pipeline for a new one if that pipeline evaluates as false.
        initial_pipeline = pipeline if pipeline is not None else ProcessorPipeline()
        self._pipeline_model = ProcessorPipelineModel(initial_pipeline, self)
        self._root = _SectionNode(
            label='Steering File',
            key='root',
            children=[
                _SectionNode(label='Globals', key='globals'),
                _SectionNode(label='Database', key='database'),
                _SectionNode(label='UI', key='ui'),
                _SectionNode(label='Processors', key='processors'),
            ],
        )
        # to avoid reconnecting multiple times signals to the wrapper widget
        self._expand_signals_connected = False

    @property
    def expand_signals_connected(self) -> bool:
        """Flag recording whether the expand signals were already connected."""
        return self._expand_signals_connected

    @expand_signals_connected.setter
    def expand_signals_connected(self, value: bool) -> None:
        self._expand_signals_connected = value

    def set_pipeline(self, pipeline: ProcessorPipeline) -> None:
        """Replace the pipeline subtree and reset the model."""
        self.beginResetModel()
        self._pipeline_model.set_pipeline(pipeline)
        self.endResetModel()

    def pipeline_item_from_index(self, index: ModelIndex) -> PipelineItem | None:
        """Return the pipeline item stored at the given index, if any."""
        if not index.isValid():
            return None
        pointer = index.internalPointer()
        return pointer if isinstance(pointer, PipelineItem) else None

    def section_key_from_index(self, index: ModelIndex) -> str | None:
        """Return the section key stored at the given index, if any.

        Pipeline items report ``'processors'``.
        """
        if not index.isValid():
            return None
        pointer = index.internalPointer()
        if isinstance(pointer, _SectionNode):
            return pointer.key
        return 'processors' if isinstance(pointer, PipelineItem) else None

    def index(self, row: int, column: int, parent: ModelIndex = QModelIndex()) -> QModelIndex:
        """Return the index of the child at ``row`` under ``parent``.

        An invalid index is returned for any column other than 0 and for rows
        outside the range of the parent's children.
        """
        if column != 0 or row < 0:
            return QModelIndex()
        if not parent.isValid():
            # The only top-level row is the root section node.
            if row != 0:
                return QModelIndex()
            return self.createIndex(0, 0, self._root)

        parent_pointer = parent.internalPointer()
        if isinstance(parent_pointer, _SectionNode):
            if parent_pointer.key == 'processors':
                # The children of the processors section come from the pipeline model.
                items = self._pipeline_model.root_items()
                if row >= len(items):
                    return QModelIndex()
                return self.createIndex(row, 0, items[row])
            if row >= len(parent_pointer.children):
                return QModelIndex()
            return self.createIndex(row, 0, parent_pointer.children[row])
        if isinstance(parent_pointer, PipelineItem):
            if row >= len(parent_pointer.children):
                return QModelIndex()
            return self.createIndex(row, 0, parent_pointer.children[row])
        return QModelIndex()

    @no_type_check
    def parent(self, index: ModelIndex | None = None) -> QModelIndex | QObject | None:
        """Return the parent index of ``index``.

        When called without an argument the call is forwarded to the base
        class ``parent()``.
        """
        if index is None:
            return super().parent()
        if not index.isValid():
            return QModelIndex()

        pointer = index.internalPointer()
        if pointer is self._root:
            return QModelIndex()
        if isinstance(pointer, _SectionNode):
            return self.createIndex(0, 0, self._root)
        if isinstance(pointer, PipelineItem):
            if pointer.parent is None:
                # Top-level pipeline items hang below the processors section.
                processors_node = next(child for child in self._root.children if child.key == 'processors')
                row = self._root.children.index(processors_node)
                return self.createIndex(row, 0, processors_node)
            parent_item = pointer.parent
            if parent_item.parent is None:
                row = self._pipeline_model.root_items().index(parent_item)
            else:
                row = parent_item.parent.children.index(parent_item)
            return self.createIndex(row, 0, parent_item)
        return QModelIndex()

    def rowCount(self, parent: ModelIndex = QModelIndex()) -> int:
        """Return the number of children below ``parent``."""
        if not parent.isValid():
            return 1
        pointer = parent.internalPointer()
        if pointer is self._root:
            return len(self._root.children)
        if isinstance(pointer, _SectionNode):
            if pointer.key == 'processors':
                return len(self._pipeline_model.root_items())
            return len(pointer.children)
        if isinstance(pointer, PipelineItem):
            return len(pointer.children)
        return 0

    def columnCount(self, parent: ModelIndex = QModelIndex()) -> int:
        """Return the number of columns, which is always 1."""
        return 1

    def data(self, index: ModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> object:
        """Return the data stored under ``role`` for the item at ``index``.

        Section nodes provide their label (display role) and their key
        (``UserRole`` and ``SteeringTreeRoles.SECTION_KEY``). Pipeline items
        provide their name, icon, foreground colour and the custom roles
        handled below. ``None`` is returned in every other case.
        """
        if not index.isValid():
            return None
        pointer = index.internalPointer()

        if isinstance(pointer, _SectionNode):
            if role == Qt.ItemDataRole.DisplayRole:
                return pointer.label
            if role in (Qt.ItemDataRole.UserRole, SteeringTreeRoles.SECTION_KEY):
                return pointer.key
            return None

        if isinstance(pointer, PipelineItem):
            name = pointer.name()
            # Only non-group items can be flagged as unknown processors.
            is_unknown = (
                not pointer.is_group()
                and getattr(pointer.config, 'processor_status', None) == ProcessorSchemaStatus.UNKNOWN
            )
            if role == Qt.ItemDataRole.DisplayRole:
                return f'Group: {name}' if pointer.is_group() else name
            if role == Qt.ItemDataRole.DecorationRole:
                if is_unknown:
                    return _UNKNOWN_ICON
                return _GROUP_ICON if pointer.is_group() else _PROCESSOR_ICON
            if role == Qt.ItemDataRole.ForegroundRole:
                if is_unknown:
                    return QColor(Qt.GlobalColor.red)
                return None
            if role == SteeringTreeRoles.PIPELINE_ITEM:
                return pointer
            if role == SteeringTreeRoles.IS_GROUP:
                return pointer.is_group()
            if role == Qt.ItemDataRole.UserRole:
                return 'processors'
            if role == PipelineRoles.NAME:
                return pointer.name()
            return None

        return None

    def headerData(
        self,
        section: int,
        orientation: Qt.Orientation,
        role: int = Qt.ItemDataRole.DisplayRole,
    ) -> object:
        """Return ``'Steering File'`` as display text of the first horizontal header section."""
        if orientation == Qt.Orientation.Horizontal and role == Qt.ItemDataRole.DisplayRole and section == 0:
            return 'Steering File'
        return None

    def flags(self, index: ModelIndex) -> Qt.ItemFlag:
        """Return the item flags of ``index``.

        Pipeline items can be dragged; drops are accepted by the processors
        section and by pipeline groups.
        """
        if not index.isValid():
            return Qt.ItemFlag.ItemIsEnabled
        pointer = index.internalPointer()
        if isinstance(pointer, _SectionNode):
            flags = Qt.ItemFlag.ItemIsEnabled | Qt.ItemFlag.ItemIsSelectable
            if pointer.key == 'processors':
                flags |= Qt.ItemFlag.ItemIsDropEnabled
            return flags
        if isinstance(pointer, PipelineItem):
            flags = Qt.ItemFlag.ItemIsEnabled | Qt.ItemFlag.ItemIsSelectable | Qt.ItemFlag.ItemIsDragEnabled
            if pointer.is_group():
                flags |= Qt.ItemFlag.ItemIsDropEnabled
            return flags
        return Qt.ItemFlag.ItemIsEnabled

    def mimeTypes(self) -> list[str]:
        """Return the single MIME type used for drag and drop."""
        return [self._mime_type]

    def mimeData(self, indexes: Sequence[ModelIndex]) -> QMimeData:
        """Encode the pipeline items behind ``indexes`` for a drag operation.

        The payload is a UTF-8 encoded JSON list with one path (list of row
        numbers starting at the pipeline root) per distinct pipeline item.
        Indexes that do not point to a pipeline item are ignored; when nothing
        can be encoded the returned object carries no data.
        """
        data = QMimeData()
        paths: list[list[int]] = []
        seen_ids: set[int] = set()
        for index in indexes:
            item = self.pipeline_item_from_index(index)
            if item is None:
                continue
            item_id = id(item)
            if item_id in seen_ids:
                continue
            seen_ids.add(item_id)
            path = self._path_for_item(item)
            if path is not None:
                paths.append(path)
        if paths:
            payload = json.dumps(paths).encode('utf-8')
            data.setData(self._mime_type, QByteArray(payload))
        return data

    def dropMimeData(
        self,
        data: QMimeData,
        action: Qt.DropAction,
        row: int,
        column: int,
        parent: ModelIndex,
    ) -> bool:
        """Move the dragged pipeline entries to the drop location.

        Only move actions carrying this model's MIME type are accepted, and
        only when dropped on the processors section or on a pipeline group.

        :param data: MIME data as produced by :meth:`mimeData`.
        :param action: Requested drop action.
        :param row: Target row below ``parent``; a negative value appends.
        :param column: Target column; only ``-1`` and ``0`` are accepted.
        :param parent: Index of the item receiving the drop.
        :return: ``True`` when every entry was moved and the model was rebuilt, ``False`` otherwise.
        """
        if action != Qt.DropAction.MoveAction:
            return False
        if not data.hasFormat(self._mime_type):
            return False
        if column not in (-1, 0):
            return False
        if not parent.isValid():
            return False

        target_parent = parent.internalPointer()
        if isinstance(target_parent, _SectionNode):
            if target_parent.key != 'processors':
                return False
            target_group = None
        elif isinstance(target_parent, PipelineItem):
            if not target_parent.is_group():
                return False
            target_group = target_parent.name()
        else:
            return False

        raw = data.data(self._mime_type)
        try:
            payload = bytes(raw.data())
        except TypeError:
            return False
        try:
            paths = json.loads(payload.decode('utf-8'))
        except (json.JSONDecodeError, UnicodeDecodeError):
            return False
        if not isinstance(paths, list) or not paths:
            return False

        # Resolve every path before moving anything: the paths describe the
        # tree as it is now and would be invalidated by the first move.
        seen_ids: set[int] = set()
        entries: list[tuple[str, bool, str | None]] = []
        for path in paths:
            item = self._item_from_path(path)
            if item is None:
                continue
            item_id = id(item)
            if item_id in seen_ids:
                continue
            seen_ids.add(item_id)
            owner = item.parent
            source_group = owner.name() if owner is not None and owner.is_group() else None
            entries.append((item.name(), item.is_group(), source_group))
        if not entries:
            return False

        inserted = 0
        try:
            # The append position is obtained from the controller, hence it is
            # computed inside the guarded section as well.
            target_position = row if row >= 0 else self._target_append_index(target_group)
            for entry_name, entry_is_group, source_group in entries:
                self._controller.move_pipeline_entry(
                    entry_name=entry_name,
                    entry_is_group=entry_is_group,
                    source_group=source_group,
                    target_group=target_group,
                    position=target_position + inserted,
                )
                inserted += 1
        except SteeringControllerError:
            # NOTE: outside type checking this module binds SteeringControllerError
            # to Exception, so any exception raised above ends up here.
            if inserted:
                # Some entries were already moved: rebuild the tree so that it
                # does not keep showing the state preceding the drop.
                self.set_pipeline(self._controller.build_pipeline())
            return False

        self.set_pipeline(self._controller.build_pipeline())
        return True

    def supportedDropActions(self) -> Qt.DropAction:
        """Return the supported drop actions (move only)."""
        return Qt.DropAction.MoveAction

    def _path_for_item(self, item: PipelineItem) -> list[int] | None:
        """Return the rows leading from the pipeline root to ``item``.

        ``None`` is returned when the item, or one of its ancestors, is not
        found among the children of its parent.
        """
        indices: list[int] = []
        current = item
        while current is not None:
            parent = current.parent
            if parent is None:
                try:
                    indices.append(self._pipeline_model.root_items().index(current))
                except ValueError:
                    return None
                break
            try:
                indices.append(parent.children.index(current))
            except ValueError:
                return None
            current = parent
        # Rows were collected from the item upwards.
        return list(reversed(indices))

    def _item_from_path(self, path: list[int]) -> PipelineItem | None:
        """Return the pipeline item reached by following ``path`` from the pipeline root.

        ``None`` is returned for an empty path, for a path that is not a
        list/tuple of integers (the value may come from a decoded drag-and-drop
        payload) and for rows that are out of range.
        """
        if not isinstance(path, (list, tuple)) or not path:
            return None
        items = self._pipeline_model.root_items()
        item: PipelineItem | None = None
        for idx in path:
            # bool is a subclass of int but is never a legitimate row number.
            if isinstance(idx, bool) or not isinstance(idx, int):
                return None
            if idx < 0 or idx >= len(items):
                return None
            item = items[idx]
            items = item.children
        return item

    def _target_append_index(self, target_group: str | None) -> int:
        """Return the position that appends an entry to the given target.

        :param target_group: Name of the destination group, or ``None`` for the pipeline root.
        """
        if target_group is None:
            return len(self._pipeline_model.root_items())
        group = self._controller.get_group_snapshot(target_group)
        return len(group.processors)