Overhaul sync.py script

This commit is contained in:
ItsDrike 2022-11-20 02:11:57 +01:00
parent 843904b67d
commit 89a5b15bbd
No known key found for this signature in database
GPG key ID: B014E761034AF742

250
sync.py
View file

@ -1,13 +1,17 @@
#!/usr/bin/env python
from __future__ import annotations
import os
import argparse
import hashlib
import os
import shutil
import subprocess
import sys
from collections.abc import Iterable, Iterator
from collections.abc import Iterable, Iterator, Sequence
from enum import Enum, auto
from pathlib import Path
from typing import NamedTuple
from typing import NamedTuple, Optional
try:
import rich
@ -21,6 +25,49 @@ DOTHOMEDIR = Path("./home")
HOMEDIR = Path(f"~{os.environ.get('SUDO_USER', os.getlogin())}") # Make sure we use correct home even in sudo
def yes_no(prompt: str, default: Optional[bool] = None) -> bool:
"""Get a yes/no answer to given prompt by the user."""
if default is None:
prompt += " [y/n]: "
elif default is True:
prompt += "[Y/n]: "
elif default is False:
prompt += "[y/N]: "
else:
raise ValueError("Invalid default value")
while True:
inp = input(prompt).lower()
if inp in {"y", "yes"}:
return True
elif inp in {"n", "no"}:
return False
elif inp == "" and default is not None:
return default
def int_input(msg: str) -> int:
"""Get a valid integer input from the user."""
while True:
x = input(msg)
try:
return int(x)
except ValueError:
print("Invalid input, expected a number.")
def choice_input(msg: str, choices: Sequence[str]) -> str:
"""Get one of given choices based on a valid user input."""
print(f"{msg}")
for index, choice in enumerate(choices, start=1):
print(f"{index}: {choice}")
x = int_input("Enter your choice: ")
if x < 1 or x > len(choices):
print("Invalid choice! Try again")
return choice_input(msg, choices)
return choices[x - 1]
class DiffStatus(Enum):
NOT_FOUND = auto()
PERMISSION_ERROR = auto()
@ -125,7 +172,7 @@ def iter_diffs() -> Iterator[FileDiff]:
yield FileDiff(dot_file, sys_file, diff_status)
def print_status(diffs: Iterable[FileDiff]) -> None:
def print_report(diffs: Iterable[FileDiff]) -> None:
"""Pretty print the individual diff statuses."""
# Exhause the iterable, and ensure we work on a copy
diffs = list(diffs)
@ -168,6 +215,118 @@ def print_status(diffs: Iterable[FileDiff]) -> None:
rich.print(table)
class FixChoice(Enum):
OVERWRITE_SYSTEM = auto()
OVERWRITE_DOTFILE = auto()
SKIP = auto()
@classmethod
def pick(cls, file_path: Path, system_type: Optional[str], dotfile_type: str) -> FixChoice:
if system_type is None:
overwrite_system_prompt = f"Create non-existing {dotfile_type}"
overwrite_dotfile_prompt = f"Delete dotfile {dotfile_type}"
else:
overwrite_system_prompt = f"Overwrite the system {system_type} with the dotfile {dotfile_type}"
overwrite_dotfile_prompt = f"Overwrite the dotfile {dotfile_type} with the system {system_type}"
answer = choice_input(
f"How to fix {file_path}?",
[overwrite_system_prompt, overwrite_dotfile_prompt, "Skip this fix"],
)
if answer == overwrite_system_prompt:
return cls.OVERWRITE_SYSTEM
elif answer == overwrite_dotfile_prompt:
return cls.OVERWRITE_DOTFILE
elif answer == "Skip this fix":
return cls.SKIP
raise Exception("This can't happen (just here for typing.NoReturn)")
def apply_fix(diff: FileDiff) -> None:
if diff.status is DiffStatus.PERMISSION_ERROR:
print("Skipping fix: insufficient permissions")
elif diff.status is DiffStatus.UNEXPECTED_DIRECTORY:
_choice = FixChoice.pick(diff.sys_file, "directory", "file")
if _choice is FixChoice.SKIP:
return
elif _choice is FixChoice.OVERWRITE_SYSTEM:
shutil.rmtree(diff.sys_file)
shutil.copy(diff.dot_file, diff.sys_file, follow_symlinks=False)
elif _choice is FixChoice.OVERWRITE_DOTFILE:
diff.dot_file.unlink()
shutil.copytree(diff.sys_file, diff.dot_file, symlinks=True)
elif diff.status is DiffStatus.UNEXPECTED_SYMLINK:
_choice = FixChoice.pick(diff.sys_file, "symlink", "file")
if _choice is FixChoice.SKIP:
return
elif _choice is FixChoice.OVERWRITE_SYSTEM:
diff.sys_file.unlink()
shutil.copy(diff.dot_file, diff.sys_file, follow_symlinks=False)
elif _choice is FixChoice.OVERWRITE_DOTFILE:
diff.dot_file.unlink()
shutil.copy(diff.sys_file, diff.dot_file, follow_symlinks=False)
elif diff.status is DiffStatus.EXPECTED_SYMLINK:
_choice = FixChoice.pick(diff.sys_file, "file", "symlink")
if _choice is FixChoice.SKIP:
return
elif _choice is FixChoice.OVERWRITE_SYSTEM:
diff.sys_file.unlink()
shutil.copy(diff.dot_file, diff.sys_file, follow_symlinks=False)
elif _choice is FixChoice.OVERWRITE_DOTFILE:
diff.dot_file.unlink()
shutil.copy(diff.sys_file, diff.dot_file, follow_symlinks=False)
elif diff.status is DiffStatus.SYMLINK_DIFFERS:
_choice = FixChoice.pick(diff.sys_file, "symlink", "file")
if _choice is FixChoice.SKIP:
return
elif _choice is FixChoice.OVERWRITE_SYSTEM:
diff.sys_file.unlink()
shutil.copy(diff.dot_file, diff.sys_file, follow_symlinks=False)
elif _choice is FixChoice.OVERWRITE_DOTFILE:
diff.dot_file.unlink()
shutil.copy(diff.sys_file, diff.dot_file, follow_symlinks=False)
elif diff.status is DiffStatus.NOT_FOUND:
_choice = FixChoice.pick(diff.sys_file, None, "file")
if _choice is FixChoice.SKIP:
return
elif _choice is FixChoice.OVERWRITE_SYSTEM:
shutil.copy(diff.dot_file, diff.sys_file, follow_symlinks=False)
elif _choice is FixChoice.OVERWRITE_DOTFILE:
diff.dot_file.unlink()
elif diff.status is DiffStatus.CONTENT_DIFFERS:
_choice = FixChoice.pick(diff.sys_file, "file", "file")
if _choice is FixChoice.SKIP:
return
elif _choice is FixChoice.OVERWRITE_SYSTEM:
shutil.copy(diff.dot_file, diff.sys_file, follow_symlinks=False)
elif _choice is FixChoice.OVERWRITE_DOTFILE:
shutil.copy(diff.sys_file, diff.dot_file, follow_symlinks=False)
def show_diffs(diffs: Iterable[FileDiff], ask_show_diff: bool, apply_fix_prompt: bool) -> None:
for diff in diffs:
if diff.status is DiffStatus.MATCH:
continue
if diff.status is DiffStatus.CONTENT_DIFFERS:
if ask_show_diff is False or yes_no(f"Show diff for {diff.sys_file}?"):
subprocess.run(["git", "diff", str(diff.sys_file), str(diff.dot_file)])
else:
_str_status = diff.status.name.replace("_", " ")
print(f"Skipping {diff.sys_file} diff for status: {_str_status}")
if apply_fix_prompt:
apply_fix(diff)
print("---")
def exclude_fun(diff: FileDiff) -> bool:
EXCLUDE_RULES = [
lambda d: d.status is DiffStatus.MATCH,
@ -190,12 +349,87 @@ def exclude_fun(diff: FileDiff) -> bool:
return True
def main() -> None:
diffs = iter_diffs()
diffs = filter(exclude_fun, diffs)
print_status(diffs)
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
prog="sync.py",
description="Compare the differences between system file and dotfile file versions",
)
parser.add_argument(
"-E",
"--no-apply-excludes",
help="Don't apply exclude rules (some files are expected to be modified, use if you want to see these too)",
action="store_true",
)
parser.add_argument(
"-R",
"--no-show-report",
help="Don't show the full report of all files, and their diff status in the beginning",
action="store_true",
)
exc_grp = parser.add_mutually_exclusive_group()
exc_grp.add_argument(
"-d",
"--show-diffs",
help="Show file diffs comparing the versions (using git diff)",
action="store_true",
default=None,
)
exc_grp.add_argument(
"-D",
"--no-show-diffs",
help="Don't show file diffs comparing the versions",
dest="show_diffs",
action="store_false",
)
parser.add_argument(
"-A",
"--no-ask-each-diff",
help="Don't ask whether to show a diff or not for each modified file (can be annoying for many files)",
action="store_true",
)
parser.add_argument(
"-f",
"--apply-fixes",
help=(
"Asks whether to overwrite the system file with the dotfile version (or vice-versa)."
" This option can only be used with --show-diffs"
),
action="store_true",
default=None,
)
ns = parser.parse_args()
if ns.no_ask_each_diff and not ns.show_diffs:
parser.error("-A/--no-ask-each-diff only makes sense with -d/--show-diffs")
if ns.apply_fixes and not ns.show_diffs:
parser.error("-f/--apply-fixes only makes sense with -d/--show-diffs")
return ns
def main() -> None:
ns = get_args()
diffs = iter_diffs()
if not ns.no_apply_excludes:
diffs = filter(exclude_fun, diffs)
diffs = list(diffs)
if not ns.no_show_report:
print_report(diffs)
if ns.show_diffs is True or ns.show_diffs is None and yes_no("Show diffs for modified files?"):
if ns.apply_fixes is None:
apply_fixes = yes_no("Apply fixes?")
else:
apply_fixes = ns.apply_fixes
show_diffs(diffs, ask_show_diff=not ns.no_ask_each_diff, apply_fix_prompt=apply_fixes)
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\n\nStopped...", file=sys.stderr)