Switch to a filter visitor

This commit is contained in:
ItsDrike 2025-07-27 15:36:10 +02:00
parent 26dc998f58
commit 25c98b6e42
Signed by: ItsDrike
GPG key ID: FA2745890B7048C0
4 changed files with 181 additions and 82 deletions

View file

@ -35,9 +35,9 @@ class ParentAwareNode(Node):
return aware_node
def walk_nodes(node: ParentAwareNode) -> Iterator[ParentAwareNode]:
def walk_nodes(node: Node) -> Iterator[Node]:
"""Yields every node in the tree, including the root one (pre-order DFS)."""
stack: list[ParentAwareNode] = [node]
stack: list[Node] = [node]
while len(stack) > 0:
cur_node = stack.pop()
yield cur_node

View file

@ -61,22 +61,6 @@ class SimpleSelector:
return cls(tag, cls_selectors, id_selectors)
@staticmethod
def _match_attrs(attr: Literal["class", "id"], sel: set[str], node: ParentAwareNode) -> bool:
"""Check that a given node contains all selected attributes of given kind."""
attr_str = node.attributes.get(attr)
attrs: set[str] = set(attr_str.split(" ")) if attr_str else set()
return sel.issubset(attrs)
def match_node(self, node: ParentAwareNode) -> bool:
if self.tag and node.tag != self.tag:
return False
if not self._match_attrs("class", set(self.classes), node):
return False
return self._match_attrs("id", set(self.ids), node)
@override
def __str__(self) -> str:
classes = ".".join(cls_name for cls_name in self.classes)
@ -159,18 +143,6 @@ class DescendantSelector:
child: SimpleSelector | ConcretePseudoClassSelector | DescendantSelector | SiblingSelector
direct: bool # descendant (" ") vs direct child (">")
def match_node(self, node: ParentAwareNode) -> bool:
if not self.child.match_node(node):
return False
if not node.parent:
return False
if self.direct:
return self.parent.match_node(node.parent)
return any(self.parent.match_node(parent) for parent in walk_parents(node))
@override
def __str__(self) -> str:
symbol = " > " if self.direct else " "
@ -183,29 +155,6 @@ class SiblingSelector:
selector: SimpleSelector | ConcretePseudoClassSelector | DescendantSelector | SiblingSelector
is_adjacent: bool # adjacent sibling ("+"), subsequent sibling ("~")
def match_node(self, node: ParentAwareNode) -> bool:
if not self.selector.match_node(node):
return False
if not node.parent:
return False
for i, sibling in enumerate(node.parent.children):
if sibling == node: # NOTE: is check might be safer here
child_index = i
break
else: # nobreak
raise AssertionError("Parent node doesn't contain it's child") # pragma: no cover
if child_index == 0:
return False # no previous siblings
if self.is_adjacent:
sibling = node.parent.children[child_index - 1]
return self.sibling_selector.match_node(sibling)
return any(self.sibling_selector.match_node(sibling) for sibling in node.parent.children[:child_index])
@override
def __str__(self) -> str:
symbol = " + " if self.is_adjacent else " ~ "
@ -219,9 +168,6 @@ type NonMultiSelector = SimpleSelector | ConcretePseudoClassSelector | Descendan
class MultiSelector:
selectors: list[NonMultiSelector]
def match_node(self, node: ParentAwareNode) -> bool:
return any(selector.match_node(node) for selector in self.selectors)
@override
def __str__(self) -> str:
return ", ".join(str(sel) for sel in self.selectors)
@ -248,12 +194,6 @@ class NotPseudoClassSelector:
not_selector = parse_tokens(TokenStream(sel.argument))
return cls(sel.selector, not_selector)
def match_node(self, node: ParentAwareNode) -> bool:
if self.selector and not self.selector.match_node(node):
return False
return not self.not_selector.match_node(node)
@override
def __str__(self) -> str:
sel = str(self.selector) if self.selector else ""
@ -295,24 +235,6 @@ class NthChildPseudoClassSelector:
return cls(sel.selector, n)
def match_node(self, node: ParentAwareNode) -> bool:
if node.parent is None:
return False
# The N indicates n-th child, but python indexes start at 0, subtract 1.
# Unless it's negative, in which case stay as-is (allowing n=-1)
child_index = self.n - 1 if self.n > 0 else self.n
try:
nth_child = node.parent.children[child_index]
except IndexError:
return False
if nth_child != node:
return False
return not (self.selector and not self.selector.match_node(node))
@override
def __str__(self) -> str:
if self.n == 1:

View file

@ -2,9 +2,10 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from .node_helpers import ParentAwareNode, walk_nodes
from .node_helpers import ParentAwareNode
from .parser import parse_tokens
from .tokenizer import TokenStream, tokenize_selector
from .visitor import find_all
if TYPE_CHECKING:
from .node import Node
@ -16,4 +17,4 @@ def query_selector_all(node: Node, selector: str) -> list[Node]:
toks = TokenStream(tokenize_selector(selector))
sel = parse_tokens(toks)
return [n.original for n in walk_nodes(root) if sel.match_node(n)]
return [node.original for node in find_all(sel, root)]

176
src/visitor.py Normal file
View file

@ -0,0 +1,176 @@
from collections.abc import Iterable, Iterator
from typing import Literal
from src.node import Node
from src.node_helpers import ParentAwareNode, walk_nodes
from src.parser import (
AnySelector,
DescendantSelector,
MultiSelector,
NotPseudoClassSelector,
NthChildPseudoClassSelector,
SiblingSelector,
SimpleSelector,
)
def _match_attrs(attr: Literal["class", "id"], sel: set[str], node: Node) -> bool:
"""Check that a given node contains all selected attributes of given kind."""
attr_str = node.attributes.get(attr)
attrs: set[str] = set(attr_str.split(" ")) if attr_str else set()
return sel.issubset(attrs)
def _simple_selector_match(sel: SimpleSelector, node: Node) -> bool:
if sel.tag and node.tag != sel.tag:
return False
if not _match_attrs("class", set(sel.classes), node):
return False
return _match_attrs("id", set(sel.ids), node)
def visit_simple_selector(
selector: SimpleSelector,
context: Iterable[ParentAwareNode],
recurse: bool = True,
) -> Iterator[ParentAwareNode]:
for node in context:
if recurse:
nodes = walk_nodes(node) # will include node too
else:
nodes = [node]
for subnode in nodes:
if _simple_selector_match(selector, subnode):
yield subnode
def visit_descendant_selector(
selector: DescendantSelector,
context: Iterable[ParentAwareNode],
recurse: bool = True,
) -> Iterator[ParentAwareNode]:
context = visit(selector.parent, context, recurse=recurse) # get all applicable parents
descendants = (child_node for node in context for child_node in node.children)
if not selector.direct:
descendants = (subnode for node in descendants for subnode in walk_nodes(node))
yield from visit(selector.child, descendants, recurse=False)
def visit_sibling_selector(
selector: SiblingSelector,
context: Iterable[ParentAwareNode],
recurse: bool = True,
) -> Iterator[ParentAwareNode]:
for node in visit(selector.selector, context, recurse=recurse):
if node.parent is None: # no siblings if no parent
continue
for i, sibling in enumerate(node.parent.children):
if sibling == node:
child_index = i
break
else: # nobreak
raise AssertionError("Parent node doesn't contain it's child") # pragma: no cover
if child_index == 0:
continue # no previous siblings
if selector.is_adjacent:
sibling = node.parent.children[child_index - 1]
siblings = [sibling]
else:
siblings = node.parent.children[:child_index]
if any(visit(selector.sibling_selector, siblings, recurse=False)):
yield node
def visit_multi_selector(
selector: MultiSelector,
context: Iterable[ParentAwareNode],
recurse: bool = True,
) -> Iterator[ParentAwareNode]:
for sel in selector.selectors:
yield from visit(sel, context, recurse=recurse)
def visit_not_pseudo_class_selector(
selector: NotPseudoClassSelector,
context: Iterable[ParentAwareNode],
recurse: bool = True,
) -> Iterator[ParentAwareNode]:
if selector.selector:
context = visit(selector.selector, context, recurse=recurse)
elif recurse:
context = (subnode for node in context for subnode in walk_nodes(node))
for node in context:
if any(visit(selector.not_selector, [node], recurse=False)):
continue # exclude this node (not condition matched)
yield node
def visit_nth_child_pseudo_class_selector(
selector: NthChildPseudoClassSelector,
context: Iterable[ParentAwareNode],
recurse: bool = True,
) -> Iterator[ParentAwareNode]:
if selector.selector:
context = visit(selector.selector, context, recurse=recurse)
elif recurse:
context = (subnode for node in context for subnode in walk_nodes(node))
for node in context:
if node.parent is None:
continue
# The N indicates n-th child, but python indexes start at 0, subtract 1.
# Unless it's negative, in which case stay as-is (allowing n=-1)
child_index = selector.n - 1 if selector.n > 0 else selector.n
try:
nth_child = node.parent.children[child_index]
except IndexError:
continue
if nth_child != node:
continue
yield node
def visit(
selector: AnySelector,
context: Iterable[ParentAwareNode],
recurse: bool = True,
) -> Iterator[ParentAwareNode]:
match selector:
case MultiSelector():
yield from visit_multi_selector(selector, context, recurse=recurse)
case DescendantSelector():
yield from visit_descendant_selector(selector, context, recurse=recurse)
case SiblingSelector():
yield from visit_sibling_selector(selector, context, recurse=recurse)
case NotPseudoClassSelector():
yield from visit_not_pseudo_class_selector(selector, context, recurse=recurse)
case NthChildPseudoClassSelector():
yield from visit_nth_child_pseudo_class_selector(selector, context, recurse=recurse)
case SimpleSelector():
yield from visit_simple_selector(selector, context, recurse=recurse)
def find_all(selector: AnySelector, root: ParentAwareNode) -> Iterator[ParentAwareNode]:
known: set[int] = set()
for node in visit(selector, [root]):
id_ = id(node)
if id_ in known:
continue
known.add(id(node))
yield node