Switch to a filter visitor
This commit is contained in:
parent
26dc998f58
commit
25c98b6e42
4 changed files with 181 additions and 82 deletions
|
@ -35,9 +35,9 @@ class ParentAwareNode(Node):
|
|||
return aware_node
|
||||
|
||||
|
||||
def walk_nodes(node: ParentAwareNode) -> Iterator[ParentAwareNode]:
|
||||
def walk_nodes(node: Node) -> Iterator[Node]:
|
||||
"""Yields every node in the tree, including the root one (pre-order DFS)."""
|
||||
stack: list[ParentAwareNode] = [node]
|
||||
stack: list[Node] = [node]
|
||||
while len(stack) > 0:
|
||||
cur_node = stack.pop()
|
||||
yield cur_node
|
||||
|
|
|
@ -61,22 +61,6 @@ class SimpleSelector:
|
|||
|
||||
return cls(tag, cls_selectors, id_selectors)
|
||||
|
||||
@staticmethod
|
||||
def _match_attrs(attr: Literal["class", "id"], sel: set[str], node: ParentAwareNode) -> bool:
|
||||
"""Check that a given node contains all selected attributes of given kind."""
|
||||
attr_str = node.attributes.get(attr)
|
||||
attrs: set[str] = set(attr_str.split(" ")) if attr_str else set()
|
||||
return sel.issubset(attrs)
|
||||
|
||||
def match_node(self, node: ParentAwareNode) -> bool:
    """Tell whether ``node`` satisfies this selector's tag, classes and ids.

    A falsy ``self.tag`` places no constraint on the node's tag. The class
    and id checks are delegated to ``self._match_attrs``.
    """
    tag_ok = not self.tag or node.tag == self.tag
    return (
        tag_ok
        and self._match_attrs("class", set(self.classes), node)
        and self._match_attrs("id", set(self.ids), node)
    )
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
classes = ".".join(cls_name for cls_name in self.classes)
|
||||
|
@ -159,18 +143,6 @@ class DescendantSelector:
|
|||
child: SimpleSelector | ConcretePseudoClassSelector | DescendantSelector | SiblingSelector
|
||||
direct: bool # descendant (" ") vs direct child (">")
|
||||
|
||||
def match_node(self, node: ParentAwareNode) -> bool:
    """Tell whether ``node`` matches ``self.child`` under a matching ancestor.

    The node itself must match ``self.child`` and must have a parent. For a
    direct ("\u003e") selector only the immediate parent is tested against
    ``self.parent``; otherwise every node produced by ``walk_parents(node)``
    is tried until one matches.
    """
    if not self.child.match_node(node) or not node.parent:
        return False

    if self.direct:
        return self.parent.match_node(node.parent)

    for ancestor in walk_parents(node):
        if self.parent.match_node(ancestor):
            return True
    return False
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
symbol = " > " if self.direct else " "
|
||||
|
@ -183,29 +155,6 @@ class SiblingSelector:
|
|||
selector: SimpleSelector | ConcretePseudoClassSelector | DescendantSelector | SiblingSelector
|
||||
is_adjacent: bool # adjacent sibling ("+"), subsequent sibling ("~")
|
||||
|
||||
def match_node(self, node: ParentAwareNode) -> bool:
    """Tell whether ``node`` matches ``self.selector`` after a matching sibling.

    The node must match ``self.selector`` and have a parent. For an adjacent
    ("+") selector only the sibling immediately before the node is tested
    against ``self.sibling_selector``; for a subsequent ("~") selector any
    earlier sibling may match.

    Raises:
        AssertionError: If the node is not found among its parent's children.
    """
    if not self.selector.match_node(node):
        return False

    if not node.parent:
        return False

    siblings = node.parent.children

    # Locate the node by identity, not equality: a structurally equal earlier
    # sibling would otherwise be mistaken for the node and yield a wrong index.
    for i, sibling in enumerate(siblings):
        if sibling is node:
            child_index = i
            break
    else:  # nobreak
        raise AssertionError("Parent node doesn't contain it's child")  # pragma: no cover

    if child_index == 0:
        return False  # no previous siblings

    if self.is_adjacent:
        return self.sibling_selector.match_node(siblings[child_index - 1])

    return any(self.sibling_selector.match_node(sibling) for sibling in siblings[:child_index])
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
symbol = " + " if self.is_adjacent else " ~ "
|
||||
|
@ -219,9 +168,6 @@ type NonMultiSelector = SimpleSelector | ConcretePseudoClassSelector | Descendan
|
|||
class MultiSelector:
|
||||
selectors: list[NonMultiSelector]
|
||||
|
||||
def match_node(self, node: ParentAwareNode) -> bool:
    """Tell whether at least one of ``self.selectors`` matches ``node``."""
    for alternative in self.selectors:
        if alternative.match_node(node):
            return True
    return False
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
return ", ".join(str(sel) for sel in self.selectors)
|
||||
|
@ -248,12 +194,6 @@ class NotPseudoClassSelector:
|
|||
not_selector = parse_tokens(TokenStream(sel.argument))
|
||||
return cls(sel.selector, not_selector)
|
||||
|
||||
def match_node(self, node: ParentAwareNode) -> bool:
    """Tell whether ``node`` passes the base selector and fails the negated one.

    When ``self.selector`` is set, the node must match it first. The node is
    then accepted only if ``self.not_selector`` does not match it.
    """
    if self.selector:
        if not self.selector.match_node(node):
            return False

    excluded = self.not_selector.match_node(node)
    return not excluded
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
sel = str(self.selector) if self.selector else ""
|
||||
|
@ -295,24 +235,6 @@ class NthChildPseudoClassSelector:
|
|||
|
||||
return cls(sel.selector, n)
|
||||
|
||||
def match_node(self, node: ParentAwareNode) -> bool:
    """Tell whether ``node`` is the n-th child of its parent.

    A node without a parent never matches. When ``self.selector`` is set,
    the node must additionally match it.
    """
    if node.parent is None:
        return False

    # The N indicates n-th child, but python indexes start at 0, subtract 1.
    # Unless it's negative, in which case stay as-is (allowing n=-1)
    child_index = self.n - 1 if self.n > 0 else self.n

    try:
        nth_child = node.parent.children[child_index]
    except IndexError:
        return False

    # Identity check: a structurally equal sibling sitting at the requested
    # position must not make this node pass as the n-th child.
    if nth_child is not node:
        return False

    return not (self.selector and not self.selector.match_node(node))
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
if self.n == 1:
|
||||
|
|
|
@ -2,9 +2,10 @@ from __future__ import annotations
|
|||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .node_helpers import ParentAwareNode, walk_nodes
|
||||
from .node_helpers import ParentAwareNode
|
||||
from .parser import parse_tokens
|
||||
from .tokenizer import TokenStream, tokenize_selector
|
||||
from .visitor import find_all
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .node import Node
|
||||
|
@ -16,4 +17,4 @@ def query_selector_all(node: Node, selector: str) -> list[Node]:
|
|||
toks = TokenStream(tokenize_selector(selector))
|
||||
sel = parse_tokens(toks)
|
||||
|
||||
return [n.original for n in walk_nodes(root) if sel.match_node(n)]
|
||||
return [node.original for node in find_all(sel, root)]
|
||||
|
|
176
src/visitor.py
Normal file
176
src/visitor.py
Normal file
|
@ -0,0 +1,176 @@
|
|||
from collections.abc import Iterable, Iterator
|
||||
from typing import Literal
|
||||
|
||||
from src.node import Node
|
||||
from src.node_helpers import ParentAwareNode, walk_nodes
|
||||
from src.parser import (
|
||||
AnySelector,
|
||||
DescendantSelector,
|
||||
MultiSelector,
|
||||
NotPseudoClassSelector,
|
||||
NthChildPseudoClassSelector,
|
||||
SiblingSelector,
|
||||
SimpleSelector,
|
||||
)
|
||||
|
||||
|
||||
def _match_attrs(attr: Literal["class", "id"], sel: set[str], node: Node) -> bool:
    """Check that a given node contains all selected attributes of given kind.

    Args:
        attr: Name of the attribute to inspect ("class" or "id").
        sel: Tokens that must all be present in the attribute value.
        node: Node whose ``attributes`` mapping is read.

    Returns:
        True when every token in ``sel`` appears in the attribute value.
        A missing or empty attribute is treated as having no tokens.
    """
    attr_str = node.attributes.get(attr)
    # Split on any run of whitespace (not just a single space) so that tabs,
    # newlines and repeated spaces between tokens don't hide a token or
    # produce empty ones.
    attrs: set[str] = set(attr_str.split()) if attr_str else set()
    return sel.issubset(attrs)
|
||||
|
||||
|
||||
def _simple_selector_match(sel: SimpleSelector, node: Node) -> bool:
    """Tell whether ``node`` satisfies the tag, class and id parts of ``sel``.

    A falsy ``sel.tag`` places no constraint on the node's tag.
    """
    tag_matches = not sel.tag or node.tag == sel.tag
    return (
        tag_matches
        and _match_attrs("class", set(sel.classes), node)
        and _match_attrs("id", set(sel.ids), node)
    )
|
||||
|
||||
|
||||
def visit_simple_selector(
    selector: SimpleSelector,
    context: Iterable[ParentAwareNode],
    recurse: bool = True,
) -> Iterator[ParentAwareNode]:
    """Yield the nodes reachable from ``context`` that match ``selector``.

    With ``recurse`` set, each context node is expanded through
    ``walk_nodes`` (the node itself plus everything below it); otherwise only
    the context nodes themselves are tested.
    """
    for start in context:
        # walk_nodes also yields `start`, so the context node is always tested
        candidates = walk_nodes(start) if recurse else [start]
        yield from (
            candidate
            for candidate in candidates
            if _simple_selector_match(selector, candidate)
        )
|
||||
|
||||
|
||||
def visit_descendant_selector(
    selector: DescendantSelector,
    context: Iterable[ParentAwareNode],
    recurse: bool = True,
) -> Iterator[ParentAwareNode]:
    """Yield nodes matching ``selector.child`` below nodes matching ``selector.parent``.

    Parents are found by visiting ``selector.parent`` over ``context``. For a
    direct selector only their children are candidates; otherwise every node
    in each child's subtree is. Candidates are then filtered through
    ``selector.child`` without further recursion. A node reachable through
    several matched parents may be yielded more than once.
    """

    def candidates() -> Iterator[ParentAwareNode]:
        # get all applicable parents, then expand each into its candidates
        for parent in visit(selector.parent, context, recurse=recurse):
            for child in parent.children:
                if selector.direct:
                    yield child
                else:
                    yield from walk_nodes(child)

    yield from visit(selector.child, candidates(), recurse=False)
|
||||
|
||||
|
||||
def visit_sibling_selector(
    selector: SiblingSelector,
    context: Iterable[ParentAwareNode],
    recurse: bool = True,
) -> Iterator[ParentAwareNode]:
    """Yield nodes matching ``selector.selector`` that follow a matching sibling.

    For an adjacent ("+") selector only the sibling immediately before the
    node is tested against ``selector.sibling_selector``; for a subsequent
    ("~") selector any earlier sibling may match. Nodes without a parent or
    without earlier siblings are skipped.

    Raises:
        AssertionError: If a node is not found among its parent's children.
    """
    for node in visit(selector.selector, context, recurse=recurse):
        if node.parent is None:  # no siblings if no parent
            continue

        siblings = node.parent.children

        # Locate the node by identity, not equality: a structurally equal
        # earlier sibling would otherwise be mistaken for the node itself and
        # yield a wrong index.
        for i, sibling in enumerate(siblings):
            if sibling is node:
                child_index = i
                break
        else:  # nobreak
            raise AssertionError("Parent node doesn't contain it's child")  # pragma: no cover

        if child_index == 0:
            continue  # no previous siblings

        if selector.is_adjacent:
            preceding = [siblings[child_index - 1]]
        else:
            preceding = siblings[:child_index]

        if any(visit(selector.sibling_selector, preceding, recurse=False)):
            yield node
|
||||
|
||||
|
||||
def visit_multi_selector(
    selector: MultiSelector,
    context: Iterable[ParentAwareNode],
    recurse: bool = True,
) -> Iterator[ParentAwareNode]:
    """Yield the matches of every alternative in ``selector.selectors``.

    Each alternative is visited in turn over the same ``context``. A node
    matched by several alternatives is yielded once per alternative.
    """
    # A one-shot iterator (e.g. a generator) would be drained by the first
    # alternative, leaving nothing for the others. Snapshot it so that every
    # alternative sees the same nodes.
    if isinstance(context, Iterator):
        context = list(context)

    for sel in selector.selectors:
        yield from visit(sel, context, recurse=recurse)
|
||||
|
||||
|
||||
def visit_not_pseudo_class_selector(
    selector: NotPseudoClassSelector,
    context: Iterable[ParentAwareNode],
    recurse: bool = True,
) -> Iterator[ParentAwareNode]:
    """Yield candidate nodes for which ``selector.not_selector`` yields nothing.

    Candidates come from visiting ``selector.selector`` when it is set. With
    no base selector they are every node under ``context`` (when ``recurse``)
    or the context nodes themselves.
    """
    if selector.selector:
        candidates = visit(selector.selector, context, recurse=recurse)
    elif recurse:
        candidates = (below for start in context for below in walk_nodes(start))
    else:
        candidates = context

    # keep only nodes for which the negated selector yields no truthy match
    yield from (
        candidate
        for candidate in candidates
        if not any(visit(selector.not_selector, [candidate], recurse=False))
    )
|
||||
|
||||
|
||||
def visit_nth_child_pseudo_class_selector(
    selector: NthChildPseudoClassSelector,
    context: Iterable[ParentAwareNode],
    recurse: bool = True,
) -> Iterator[ParentAwareNode]:
    """Yield candidate nodes that are the n-th child of their parent.

    Candidates come from visiting ``selector.selector`` when it is set. With
    no base selector they are every node under ``context`` (when ``recurse``)
    or the context nodes themselves. Nodes without a parent never match.
    """
    if selector.selector:
        context = visit(selector.selector, context, recurse=recurse)
    elif recurse:
        context = (subnode for node in context for subnode in walk_nodes(node))

    # The N indicates n-th child, but python indexes start at 0, subtract 1.
    # Unless it's negative, in which case stay as-is (allowing n=-1)
    # Computed once: it doesn't depend on the node being tested.
    child_index = selector.n - 1 if selector.n > 0 else selector.n

    for node in context:
        if node.parent is None:
            continue

        try:
            nth_child = node.parent.children[child_index]
        except IndexError:
            continue

        # Identity check: a structurally equal sibling sitting at the
        # requested position must not make this node pass as the n-th child.
        if nth_child is node:
            yield node
|
||||
|
||||
|
||||
def visit(
    selector: AnySelector,
    context: Iterable[ParentAwareNode],
    recurse: bool = True,
) -> Iterator[ParentAwareNode]:
    """Dispatch ``selector`` to the visitor function for its concrete type.

    The first entry whose type matches handles the selector. A selector of
    none of the listed types yields nothing.
    """
    # Order mirrors the original dispatch order, so precedence is unchanged.
    handlers = (
        (MultiSelector, visit_multi_selector),
        (DescendantSelector, visit_descendant_selector),
        (SiblingSelector, visit_sibling_selector),
        (NotPseudoClassSelector, visit_not_pseudo_class_selector),
        (NthChildPseudoClassSelector, visit_nth_child_pseudo_class_selector),
        (SimpleSelector, visit_simple_selector),
    )
    for selector_type, handler in handlers:
        if isinstance(selector, selector_type):
            yield from handler(selector, context, recurse=recurse)
            return
|
||||
|
||||
|
||||
def find_all(selector: AnySelector, root: ParentAwareNode) -> Iterator[ParentAwareNode]:
    """Yield each node under ``root`` matched by ``selector``, once per object.

    The visitor may produce the same node object several times. Duplicates
    are dropped by object identity (``id``), keeping the first occurrence.
    """
    seen_ids: set[int] = set()
    for match in visit(selector, [root]):
        marker = id(match)
        if marker not in seen_ids:
            seen_ids.add(marker)
            yield match
|
Loading…
Add table
Add a link
Reference in a new issue