#!/usr/bin/env python3
import re
import sys
from argparse import ArgumentParser
from typing import Any, Sequence
class TrieNode:
    """One node of a character trie.

    A node exposes its children through the container protocol: ``len()``
    and iteration cover the child characters, ``node[ch]`` reads or writes
    a child, and the node is truthy only when it has at least one child.
    ``is_end`` records whether a word terminates here.
    """

    def __init__(self):
        # Maps each outgoing character to the child stored under it.
        self.children = {}
        # Set when an inserted word ends exactly at this node.
        self.is_end = False

    def __len__(self) -> int:
        """Return how many children this node has."""
        children = self.children
        return len(children)

    def __bool__(self) -> bool:
        """Return True when the node has any children (i.e. is not a leaf)."""
        return bool(self.children)

    def __getitem__(self, key: str):
        """Return the child reached via character *key* (KeyError if absent)."""
        child = self.children[key]
        return child

    def __setitem__(self, key: str, value: Any):
        """Store *value* as the child reached via character *key*."""
        self.children.__setitem__(key, value)

    def __iter__(self):
        """Iterate over the child characters in insertion order."""
        return iter(self.children.keys())
class Trie:
    """Prefix tree of words that compiles into one anchored regular expression.

    Words sharing a prefix share nodes; ``to_regex`` walks the tree and emits
    a pattern matching exactly the inserted words, folding sibling characters
    that have identical continuations into character classes.
    """

    # Metacharacters that must be escaped outside a bracket expression.
    _SPECIAL_RE = re.compile(r"([.^$*+?{}()|\\\[\]])")
    # Characters that are special inside a bracket expression "[...]".
    _CLASS_SPECIAL_RE = re.compile(r"([\\\[\]^-])")

    def __init__(self):
        self.root = TrieNode()

    def insert(self, word: str) -> None:
        """Add *word*, creating a node for every character not yet present."""
        node = self.root
        for char in word:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
        node.is_end = True

    def to_regex(self) -> str:
        """Return a pattern, anchored with ``^`` and ``$``, for the inserted words.

        With no words inserted the result is ``^$``.
        """
        return f"^{self._dfs(self.root)}$"

    @staticmethod
    def _escape(s: str, char_class: bool = False) -> str:
        """Backslash-escape the regex metacharacters in *s*.

        Outside a bracket expression every metacharacter is escaped
        (``.``, ``^``, ``$``, ``*``, ``+``, ``?``, braces, parentheses, ``|``,
        backslash and square brackets), so words containing e.g. ``|`` or
        ``(`` stay literal. With *char_class* set, only backslash, ``[``,
        ``]``, ``^`` and ``-`` are escaped: unescaped, those would end the
        class early, negate it, or form an unintended range.
        """
        pattern = Trie._CLASS_SPECIAL_RE if char_class else Trie._SPECIAL_RE
        return pattern.sub(r"\\\1", s)

    @staticmethod
    def _as_char_class(s: Sequence[str]) -> str:
        """Return a pattern matching any one character of *s*.

        A lone character is returned escaped and unbracketed. Otherwise the
        distinct characters are sorted, runs of consecutive code points are
        merged into ``x-y`` ranges (two adjacent characters are written out
        as-is), and the result is wrapped in ``[...]``.

        Args:
            s: Single-character strings.

        Raises:
            ValueError: If *s* is empty.
        """
        if not s:
            raise ValueError("empty sequence")
        chars = sorted(set(s))
        if len(chars) == 1:
            return Trie._escape(chars[0])

        ranges = [[chars[0], chars[0]]]
        for char in chars[1:]:
            if ord(char) == ord(ranges[-1][1]) + 1:
                ranges[-1][1] = char
            else:
                ranges.append([char, char])

        parts = []
        for start, end in ranges:
            first = Trie._escape(start, True)
            if start == end:
                parts.append(first)
            elif ord(end) - ord(start) == 1:
                # "ab" is no longer than "a-b", so skip the range syntax.
                parts.append(first + Trie._escape(end, True))
            else:
                parts.append(first + "-" + Trie._escape(end, True))
        return "[" + "".join(parts) + "]"

    @staticmethod
    def _dfs(node: "TrieNode") -> str:
        """Return the unanchored pattern for all suffixes below *node*.

        Returns "" for a node without children. The result never contains an
        unbracketed top-level ``|``, so callers may concatenate it directly.
        """
        if not node:
            return ""
        single_chars = []  # children that are leaves
        multi_chars = {}  # child subpattern -> characters leading to it
        for char in sorted(node):
            subpattern = Trie._dfs(node[char])
            if subpattern:
                multi_chars.setdefault(subpattern, []).append(char)
            else:
                single_chars.append(char)

        if not multi_chars:
            pattern = Trie._as_char_class(single_chars)
            brackets = False
        elif not single_chars and len(multi_chars) == 1:
            [(subpattern, chars)] = multi_chars.items()
            pattern = Trie._as_char_class(chars) + subpattern
            # Group before the "?" below so it applies to the whole suffix.
            brackets = node.is_end or len(chars) > 1
        else:
            alternatives = [
                Trie._as_char_class(chars) + subpattern
                for subpattern, chars in multi_chars.items()
            ]
            if single_chars:
                alternatives.append(Trie._as_char_class(single_chars))
            pattern = "|".join(alternatives)
            brackets = True

        if brackets:
            pattern = f"(?:{pattern})"
        if node.is_end:
            # A word ends here, so everything below is optional.
            pattern += "?"
        return pattern
def main():
    """Read words, one per line, and print a regex matching exactly those words.

    Each FILE argument is read in order, with ``-`` meaning standard input;
    with no arguments, standard input is read. Line terminators are stripped
    and blank lines are skipped.
    """
    parser = ArgumentParser()
    parser.add_argument("file", nargs="*", help="file to read words from")
    args = parser.parse_args()
    trie = Trie()

    def insert_lines(lines):
        # Strip "\r" too: on POSIX, sys.stdin does not translate CRLF, and a
        # stray "\r" would otherwise become part of every word and the regex.
        for line in lines:
            word = line.rstrip("\r\n")
            if word:
                trie.insert(word)

    for path in args.file or ["-"]:
        if path == "-":
            insert_lines(sys.stdin)
        else:
            with open(path) as f:
                insert_lines(f)
    print(trie.to_regex())
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
# To embed this program on your website, copy the following code and paste it into your website's HTML: