diff --git a/src/node_helpers.py b/src/node_helpers.py
index 035e6b5..52bae34 100644
--- a/src/node_helpers.py
+++ b/src/node_helpers.py
@@ -35,9 +35,9 @@ class ParentAwareNode(Node):
         return aware_node
 
 
-def walk_nodes(node: ParentAwareNode) -> Iterator[ParentAwareNode]:
+def walk_nodes(node: Node) -> Iterator[Node]:
     """Yields every node in the tree, including the root one (pre-order DFS)."""
-    stack: list[ParentAwareNode] = [node]
+    stack: list[Node] = [node]
     while len(stack) > 0:
         cur_node = stack.pop()
         yield cur_node
diff --git a/src/parser.py b/src/parser.py
index be86700..d1982c4 100644
--- a/src/parser.py
+++ b/src/parser.py
@@ -61,22 +61,6 @@ class SimpleSelector:
 
         return cls(tag, cls_selectors, id_selectors)
 
-    @staticmethod
-    def _match_attrs(attr: Literal["class", "id"], sel: set[str], node: ParentAwareNode) -> bool:
-        """Check that a given node contains all selected attributes of given kind."""
-        attr_str = node.attributes.get(attr)
-        attrs: set[str] = set(attr_str.split(" ")) if attr_str else set()
-        return sel.issubset(attrs)
-
-    def match_node(self, node: ParentAwareNode) -> bool:
-        if self.tag and node.tag != self.tag:
-            return False
-
-        if not self._match_attrs("class", set(self.classes), node):
-            return False
-
-        return self._match_attrs("id", set(self.ids), node)
-
     @override
     def __str__(self) -> str:
         classes = ".".join(cls_name for cls_name in self.classes)
@@ -159,18 +143,6 @@ class DescendantSelector:
     child: SimpleSelector | ConcretePseudoClassSelector | DescendantSelector | SiblingSelector
     direct: bool  # descendant (" ") vs direct child (">")
 
-    def match_node(self, node: ParentAwareNode) -> bool:
-        if not self.child.match_node(node):
-            return False
-
-        if not node.parent:
-            return False
-
-        if self.direct:
-            return self.parent.match_node(node.parent)
-
-        return any(self.parent.match_node(parent) for parent in walk_parents(node))
-
     @override
     def __str__(self) -> str:
         symbol = " > " if self.direct else " "
@@ -183,29 +155,6 @@ class SiblingSelector:
     selector: SimpleSelector | ConcretePseudoClassSelector | DescendantSelector | SiblingSelector
     is_adjacent: bool  # adjacent sibling ("+"), subsequent sibling ("~")
 
-    def match_node(self, node: ParentAwareNode) -> bool:
-        if not self.selector.match_node(node):
-            return False
-
-        if not node.parent:
-            return False
-
-        for i, sibling in enumerate(node.parent.children):
-            if sibling == node:  # NOTE: is check might be safer here
-                child_index = i
-                break
-        else:  # nobreak
-            raise AssertionError("Parent node doesn't contain it's child")  # pragma: no cover
-
-        if child_index == 0:
-            return False  # no previous siblings
-
-        if self.is_adjacent:
-            sibling = node.parent.children[child_index - 1]
-            return self.sibling_selector.match_node(sibling)
-
-        return any(self.sibling_selector.match_node(sibling) for sibling in node.parent.children[:child_index])
-
     @override
     def __str__(self) -> str:
         symbol = " + " if self.is_adjacent else " ~ "
@@ -219,9 +168,6 @@ type NonMultiSelector = SimpleSelector | ConcretePseudoClassSelector | Descendan
 class MultiSelector:
     selectors: list[NonMultiSelector]
 
-    def match_node(self, node: ParentAwareNode) -> bool:
-        return any(selector.match_node(node) for selector in self.selectors)
-
     @override
     def __str__(self) -> str:
         return ", ".join(str(sel) for sel in self.selectors)
@@ -248,12 +194,6 @@ class NotPseudoClassSelector:
         not_selector = parse_tokens(TokenStream(sel.argument))
         return cls(sel.selector, not_selector)
 
-    def match_node(self, node: ParentAwareNode) -> bool:
-        if self.selector and not self.selector.match_node(node):
-            return False
-
-        return not self.not_selector.match_node(node)
-
     @override
     def __str__(self) -> str:
         sel = str(self.selector) if self.selector else ""
@@ -295,24 +235,6 @@ class NthChildPseudoClassSelector:
 
         return cls(sel.selector, n)
 
-    def match_node(self, node: ParentAwareNode) -> bool:
-        if node.parent is None:
-            return False
-
-        # The N indicates n-th child, but python indexes start at 0, subtract 1.
-        # Unless it's negative, in which case stay as-is (allowing n=-1)
-        child_index = self.n - 1 if self.n > 0 else self.n
-
-        try:
-            nth_child = node.parent.children[child_index]
-        except IndexError:
-            return False
-
-        if nth_child != node:
-            return False
-
-        return not (self.selector and not self.selector.match_node(node))
-
     @override
     def __str__(self) -> str:
         if self.n == 1:
diff --git a/src/qualifier.py b/src/qualifier.py
index 8465d07..1a41859 100644
--- a/src/qualifier.py
+++ b/src/qualifier.py
@@ -2,9 +2,10 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING
 
-from .node_helpers import ParentAwareNode, walk_nodes
+from .node_helpers import ParentAwareNode
 from .parser import parse_tokens
 from .tokenizer import TokenStream, tokenize_selector
+from .visitor import find_all
 
 if TYPE_CHECKING:
     from .node import Node
@@ -16,4 +17,4 @@ def query_selector_all(node: Node, selector: str) -> list[Node]:
     toks = TokenStream(tokenize_selector(selector))
     sel = parse_tokens(toks)
 
-    return [n.original for n in walk_nodes(root) if sel.match_node(n)]
+    return [node.original for node in find_all(sel, root)]
diff --git a/src/visitor.py b/src/visitor.py
new file mode 100644
index 0000000..dbd965d
--- /dev/null
+++ b/src/visitor.py
@@ -0,0 +1,224 @@
+from collections.abc import Iterable, Iterator
+from typing import Literal
+
+from .node import Node
+from .node_helpers import ParentAwareNode, walk_nodes
+from .parser import (
+    AnySelector,
+    DescendantSelector,
+    MultiSelector,
+    NotPseudoClassSelector,
+    NthChildPseudoClassSelector,
+    SiblingSelector,
+    SimpleSelector,
+)
+
+
+def _match_attrs(attr: Literal["class", "id"], sel: set[str], node: Node) -> bool:
+    """Check that a given node contains all selected attributes of given kind."""
+    attr_str = node.attributes.get(attr)
+    # split() without arguments tolerates repeated and non-space whitespace separators
+    attrs: set[str] = set(attr_str.split()) if attr_str else set()
+    return sel.issubset(attrs)
+
+
+def _simple_selector_match(sel: SimpleSelector, node: Node) -> bool:
+    """Check whether the node itself satisfies the selector's tag, class and id constraints."""
+    if sel.tag and node.tag != sel.tag:
+        return False
+
+    if not _match_attrs("class", set(sel.classes), node):
+        return False
+
+    return _match_attrs("id", set(sel.ids), node)
+
+
+def visit_simple_selector(
+    selector: SimpleSelector,
+    context: Iterable[ParentAwareNode],
+    recurse: bool = True,
+) -> Iterator[ParentAwareNode]:
+    """Yield the nodes matching a simple selector.
+
+    With ``recurse`` set, the whole subtree of every context node is searched;
+    otherwise only the context nodes themselves are tested.
+    """
+    for node in context:
+        if recurse:
+            nodes = walk_nodes(node)  # will include node too
+        else:
+            nodes = [node]
+
+        for subnode in nodes:
+            if _simple_selector_match(selector, subnode):
+                yield subnode
+
+
+def visit_descendant_selector(
+    selector: DescendantSelector,
+    context: Iterable[ParentAwareNode],
+    recurse: bool = True,
+) -> Iterator[ParentAwareNode]:
+    """Yield nodes matching ``selector.child`` located below matches of ``selector.parent``.
+
+    Only direct children are considered when ``selector.direct`` is set, otherwise
+    all descendants. A node may be yielded more than once.
+    """
+    context = visit(selector.parent, context, recurse=recurse)  # get all applicable parents
+
+    descendants = (child_node for node in context for child_node in node.children)
+
+    if not selector.direct:
+        descendants = (subnode for node in descendants for subnode in walk_nodes(node))
+
+    yield from visit(selector.child, descendants, recurse=False)
+
+
+def visit_sibling_selector(
+    selector: SiblingSelector,
+    context: Iterable[ParentAwareNode],
+    recurse: bool = True,
+) -> Iterator[ParentAwareNode]:
+    """Yield nodes matching ``selector.selector`` that follow a matching sibling.
+
+    The sibling tested against ``selector.sibling_selector`` is the immediately
+    preceding one when ``selector.is_adjacent`` is set, otherwise any preceding one.
+    """
+    # NOTE(review): with recurse=False a descendant-type sibling_selector is evaluated
+    # with the siblings as search roots, so it yields nodes below a sibling instead of
+    # testing the sibling itself -- confirm compound sibling selectors behave as intended.
+    for node in visit(selector.selector, context, recurse=recurse):
+        if node.parent is None:  # no siblings if no parent
+            continue
+
+        for i, sibling in enumerate(node.parent.children):
+            # Compare by identity (as find_all does via id()): == could conflate
+            # distinct siblings that merely compare equal.
+            if sibling is node:
+                child_index = i
+                break
+        else:  # nobreak
+            raise AssertionError("Parent node doesn't contain it's child")  # pragma: no cover
+
+        if child_index == 0:
+            continue  # no previous siblings
+
+        if selector.is_adjacent:
+            sibling = node.parent.children[child_index - 1]
+            siblings = [sibling]
+        else:
+            siblings = node.parent.children[:child_index]
+
+        # Check that a match exists rather than relying on the matched node's truthiness
+        if any(True for _ in visit(selector.sibling_selector, siblings, recurse=False)):
+            yield node
+
+
+def visit_multi_selector(
+    selector: MultiSelector,
+    context: Iterable[ParentAwareNode],
+    recurse: bool = True,
+) -> Iterator[ParentAwareNode]:
+    """Yield nodes matching any of the selectors in the list (may contain duplicates)."""
+    # The context is iterated once per sub-selector; materialize it so that a one-shot
+    # iterator (e.g. a generator) is not exhausted after the first sub-selector.
+    context = list(context)
+    for sel in selector.selectors:
+        yield from visit(sel, context, recurse=recurse)
+
+
+def visit_not_pseudo_class_selector(
+    selector: NotPseudoClassSelector,
+    context: Iterable[ParentAwareNode],
+    recurse: bool = True,
+) -> Iterator[ParentAwareNode]:
+    """Yield candidate nodes that ``selector.not_selector`` does not match.
+
+    Candidates are the matches of ``selector.selector`` when it is set, otherwise
+    the context nodes (including their subtrees when ``recurse`` is set).
+    """
+    if selector.selector:
+        context = visit(selector.selector, context, recurse=recurse)
+    elif recurse:
+        context = (subnode for node in context for subnode in walk_nodes(node))
+
+    for node in context:
+        # Check that a match exists rather than relying on the matched node's truthiness
+        if any(True for _ in visit(selector.not_selector, [node], recurse=False)):
+            continue  # exclude this node (not condition matched)
+        yield node
+
+
+def visit_nth_child_pseudo_class_selector(
+    selector: NthChildPseudoClassSelector,
+    context: Iterable[ParentAwareNode],
+    recurse: bool = True,
+) -> Iterator[ParentAwareNode]:
+    """Yield candidate nodes that are the n-th child of their parent.
+
+    Candidates are the matches of ``selector.selector`` when it is set, otherwise
+    the context nodes (including their subtrees when ``recurse`` is set).
+    """
+    if selector.selector:
+        context = visit(selector.selector, context, recurse=recurse)
+    elif recurse:
+        context = (subnode for node in context for subnode in walk_nodes(node))
+
+    for node in context:
+        if node.parent is None:
+            continue
+
+        # The N indicates n-th child, but python indexes start at 0, subtract 1.
+        # Unless it's negative, in which case stay as-is (allowing n=-1)
+        child_index = selector.n - 1 if selector.n > 0 else selector.n
+
+        try:
+            nth_child = node.parent.children[child_index]
+        except IndexError:
+            continue
+
+        # Compare by identity (as find_all does via id()): != could conflate
+        # distinct siblings that merely compare equal.
+        if nth_child is not node:
+            continue
+
+        yield node
+
+
+def visit(
+    selector: AnySelector,
+    context: Iterable[ParentAwareNode],
+    recurse: bool = True,
+) -> Iterator[ParentAwareNode]:
+    """Dispatch to the visitor function for the concrete selector type.
+
+    Nothing is yielded for a selector that is none of the handled types.
+    """
+    match selector:
+        case MultiSelector():
+            yield from visit_multi_selector(selector, context, recurse=recurse)
+        case DescendantSelector():
+            yield from visit_descendant_selector(selector, context, recurse=recurse)
+        case SiblingSelector():
+            yield from visit_sibling_selector(selector, context, recurse=recurse)
+        case NotPseudoClassSelector():
+            yield from visit_not_pseudo_class_selector(selector, context, recurse=recurse)
+        case NthChildPseudoClassSelector():
+            yield from visit_nth_child_pseudo_class_selector(selector, context, recurse=recurse)
+        case SimpleSelector():
+            yield from visit_simple_selector(selector, context, recurse=recurse)
+
+
+def find_all(selector: AnySelector, root: ParentAwareNode) -> Iterator[ParentAwareNode]:
+    """Yield every node in the tree under ``root`` matching the selector, each one once.
+
+    The visitors may yield a node repeatedly; duplicates are dropped by object identity.
+    """
+    known: set[int] = set()
+    for node in visit(selector, [root]):
+        id_ = id(node)
+        if id_ in known:
+            continue
+
+        known.add(id_)
+        yield node