#!/usr/bin/env python3
import re
import sys
from argparse import ArgumentParser
from typing import Any, Sequence


class TrieNode:
    """One node of a character trie: child links plus a word-end flag.

    The node acts like a small mapping over its children: ``len()``, ``in``,
    iteration and ``node[char]`` indexing all operate on ``children``.
    Truthiness reflects whether the node has any children, so a node that
    only ends a word (``is_end``) but has no children is falsy.
    """

    # Slots drop the per-instance __dict__; a trie built from a large word
    # list contains one node per distinct prefix, so this saves memory.
    __slots__ = ("children", "is_end")

    def __init__(self):
        self.children = {}  # child nodes keyed by character
        self.is_end = False  # True when some inserted word ends here

    def __len__(self) -> int:
        return len(self.children)

    def __bool__(self) -> bool:
        return bool(self.children)

    def __contains__(self, key: object) -> bool:
        # Direct dict lookup instead of the default linear scan via __iter__.
        return key in self.children

    def __getitem__(self, key: str):
        return self.children[key]

    def __setitem__(self, key: str, value: Any):
        self.children[key] = value

    def __iter__(self):
        return iter(self.children)

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"children={sorted(self.children)!r}, is_end={self.is_end!r})"
        )


class Trie:
    """Prefix tree of words that can be rendered as one anchored regex.

    Words are added with ``insert``; ``to_regex`` walks the tree and emits a
    ``^...$`` pattern that shares common prefixes and merges sibling
    characters leading to identical suffix patterns into character classes.
    """

    def __init__(self):
        self.root = TrieNode()

    def insert(self, word: str) -> None:
        """Add *word*, creating a child node for each character not yet present."""
        node = self.root
        for char in word:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
        node.is_end = True  # the last visited node terminates the word

    def to_regex(self) -> str:
        """Return a ``^...$`` regex matching the inserted words.

        An empty trie (nothing inserted) yields ``^$``.
        """
        return f"^{self._dfs(self.root)}$"

    @staticmethod
    def _escape(s: str, char_class: bool = False) -> str:
        """Backslash-escape regex metacharacters in *s*.

        Args:
            s: Literal text to escape.
            char_class: If true, escape for use inside ``[...]``, where only
                backslash, ``[``, ``]``, ``^`` and ``-`` are special.
                Otherwise escape the characters that are special outside a
                character class.

        Returns:
            The escaped text.
        """
        if char_class:
            # '^' negates a class when it comes first and '-' forms a range
            # between two characters, so both need escaping inside [...].
            return re.sub(r"([\\\[\]^-])", r"\\\1", s)

        # Grouping, alternation and anchors ("()|^$") must be escaped too;
        # otherwise a word such as "a|b" would change the regex structure.
        return re.sub(r"([.^$*+?{}()|\\\[\]])", r"\\\1", s)

    @staticmethod
    def _as_char_class(s: Sequence[str]) -> str:
        """Render the single characters in *s* as one regex atom.

        Consecutive code points are merged: runs of three or more become a
        ``x-z`` range, and runs of two are written out (``xy``). A lone
        character is returned escaped without brackets. Anything else is
        wrapped in ``[...]`` using class-safe escaping.

        Raises:
            ValueError: If *s* is empty.
        """
        if len(s) == 0:
            raise ValueError("empty sequence")

        # Build [start, end] runs of consecutive code points.
        ranges = []
        for char in sorted(s):
            if ranges and ord(char) == ord(ranges[-1][1]) + 1:
                ranges[-1][1] = char
            else:
                ranges.append([char, char])

        if len(ranges) == 1 and ranges[0][0] == ranges[0][1]:
            # A single character needs no brackets; escape it for use
            # outside a class.
            return Trie._escape(ranges[0][0])

        parts = []
        for start, end in ranges:
            first = Trie._escape(start, True)
            last = Trie._escape(end, True)
            if start == end:
                parts.append(first)
            elif ord(end) - ord(start) == 1:
                parts.append(first + last)
            else:
                parts.append(first + "-" + last)
        return "[" + "".join(parts) + "]"

    @staticmethod
    def _dfs(node: "TrieNode") -> str:
        """Return the regex for the suffixes reachable below *node*.

        A node with no children yields ``""``. Otherwise children are grouped
        by the pattern of their own subtree, so siblings with identical suffix
        patterns collapse into one character class. The groups are joined
        with ``|``. If *node* itself ends a word, the whole suffix pattern is
        made optional with ``?``.

        NOTE(review): recursion depth grows with word length, so very long
        input lines may hit Python's recursion limit.
        """
        if not node:
            return ""

        # Children without descendants (subpattern "") are plain characters.
        # The rest are keyed by their subpattern so that equal suffixes merge.
        single_chars = []
        multi_chars = {}
        for char in sorted(node):
            subpattern = Trie._dfs(node[char])
            if subpattern == "":
                single_chars.append(char)
            else:
                multi_chars.setdefault(subpattern, []).append(char)

        if not multi_chars:
            # Only leaf children: _as_char_class yields a single atom, so a
            # trailing "?" applies to it directly.
            pattern = Trie._as_char_class(single_chars)
            brackets = False
        elif not single_chars and len(multi_chars) == 1:
            # One shared suffix gives a plain concatenation. Grouping is only
            # required when "?" is appended below. The len(chars) > 1 case is
            # also grouped; that is redundant but keeps the output unchanged.
            subpattern, chars = next(iter(multi_chars.items()))
            pattern = Trie._as_char_class(chars) + subpattern
            brackets = node.is_end or len(chars) > 1
        else:
            # Several alternatives: group them so the caller can prepend a
            # prefix or append "?" without splitting the alternation.
            alternatives = [
                Trie._as_char_class(chars) + subpattern
                for subpattern, chars in multi_chars.items()
            ]
            if single_chars:
                alternatives.append(Trie._as_char_class(single_chars))
            pattern = "|".join(alternatives)
            brackets = True

        if brackets:
            pattern = f"(?:{pattern})"

        if node.is_end:
            pattern += "?"

        return pattern


def iter_words(lines):
    """Yield the non-empty entries of *lines* with trailing newlines stripped.

    Args:
        lines: Any iterable of strings, e.g. an open text file or ``sys.stdin``.

    Yields:
        Each line with trailing newline characters removed, skipping lines
        that are empty after stripping.
    """
    for line in lines:
        word = line.rstrip("\n")
        if word:
            yield word


def main():
    """Build a trie from input words and print the equivalent regex.

    Words are read one per line from each file named on the command line. A
    name of ``-``, or no names at all, reads standard input. Empty lines are
    ignored.
    """
    parser = ArgumentParser()
    parser.add_argument("file", nargs="*", help="file to read words from")
    args = parser.parse_args()

    trie = Trie()
    # With no file arguments, read standard input exactly as "-" would.
    for name in args.file or ["-"]:
        if name == "-":
            for word in iter_words(sys.stdin):
                trie.insert(word)
        else:
            with open(name) as f:
                for word in iter_words(f):
                    trie.insert(word)

    print(trie.to_regex())


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

# Embed on website
#
# To embed this program on your website, copy the following code and paste it into your website's HTML: