from typing import Dict, Tuple, List

# Tunable parameters.
# The lowest ROOT_LEN decimal digits form the "root" block, which is simulated step by step;
# the digits above it are grouped into "cells" of CELL_LEN digits each.
# NOTE: _root_steps wraps an underflowing root back by a single 10**ROOT_LEN, so one step's
# subtraction (digit sum of the whole number) must stay below 10**ROOT_LEN.
ROOT_LEN: int = 5
CELL_LEN: int = 3

# helpers
_pow10_cache: List[int] = []  # _pow10_cache[i] == 10**i for every cached i
def pow10(k: int) -> int:
    """Return 10**k, extending the shared power-of-ten cache up to index k if needed."""
    cache = _pow10_cache
    for exponent in range(len(cache), k + 1):
        cache.append(10 ** exponent)
    return cache[k]

def digit_sum(x: int) -> int:
    """Return the sum of the base-10 digits of the non-negative integer x (0 for x == 0)."""
    total = 0
    while x:
        x, last_digit = divmod(x, 10)
        total += last_digit
    return total

# ---- root-level simulation and memoization ----
# Current contents of the lowest ROOT_LEN digits (the "root" block). Between root passes it
# is kept in the range 0 <= root_value < 10**ROOT_LEN.
root_value: int = 0
# Cache used by _root_steps:
#   (prev_digit_sum, starting root_value) -> (steps taken, root_value after the pass)
_root_memo: Dict[Tuple[int, int], Tuple[int, int]] = dict()

def _root_steps(prev_digit_sum: int) -> int:
    """
    Run one pass of the root block until it borrows from the cells or the number hits 0.

    Repeatedly applies
        root_value := root_value - (digit_sum(root_value) + prev_digit_sum)
    where prev_digit_sum is the digit sum of all digits above the root block. The pass
    stops when either:
      * root_value becomes negative: the number borrows one unit from the cells above,
        which is modelled by adding 10**ROOT_LEN back to root_value; or
      * root_value reaches exactly 0 while prev_digit_sum == 0 (the whole number is 0).

    Args:
        prev_digit_sum: digit sum of the digits above the root block.

    Returns:
        The number of subtraction steps performed. The global root_value is updated in
        place, and the (steps, resulting root_value) pair is memoized in _root_memo.
    """
    global root_value

    # Already at the terminal state (the whole number is 0): no step is needed. Without
    # this guard the loop below would count a no-op 0 -> 0 step and report 1.
    if root_value == 0 and prev_digit_sum == 0:
        return 0

    key = (prev_digit_sum, root_value)
    cached = _root_memo.get(key)
    if cached is not None:
        steps, root_value = cached
        return steps

    limit = pow10(ROOT_LEN)
    steps = 0
    while root_value >= 0:
        # One step of x := x - digit_sum(x), seen from the root block.
        root_value -= digit_sum(root_value) + prev_digit_sum
        steps += 1
        if root_value == 0 and prev_digit_sum == 0:
            # The whole number reached 0; no borrow happens.
            _root_memo[key] = (steps, root_value)
            return steps

    # Negative root: borrow one unit from the cells above, i.e. add back one root block.
    # NOTE: a single add-back only restores the range if the subtraction was < 10**ROOT_LEN.
    root_value += limit
    _root_memo[key] = (steps, root_value)
    return steps

# ---- general cell-level recursion + memoization per cell-level ----
# _cell_memos[depth - 1] caches full sweeps in which the lowest `depth` cells (those just
# above the root) each run over their whole range 10**CELL_LEN - 1 .. 0.
#   key:   (prev_digit_sum, initial_root_value)
#   value: (total_steps_for_the_sweep, root_value_after_processing)
# The tables are indexed by depth rather than by cell position because the memos persist
# across steps_from() calls, and inputs with different numbers of cells would otherwise
# share a table for different depths.
_cell_memos: List[Dict[Tuple[int,int], Tuple[int,int]]] = []

def _process_cells_from_prefix(cells_start: List[int]) -> int:
    """
    Count the steps to reach 0 for a number that has cells above the root block.

    cells_start holds the starting cell values, most significant first. Each time the root
    borrows, the value formed by the cells drops by one, so that value runs through every
    number from its start down to 0. This function enumerates those values cell by cell:
      - the top cell runs from its starting value down to 0;
      - while every more-significant cell still holds its starting value, the next cell
        starts from its own starting value;
      - once some cell has dropped below its start, every lower cell sweeps its full range.
    Each visited value gets one _root_steps pass, using the digit sum of the cells.

    Expects the global root_value to hold the starting root block (set by steps_from).

    Returns:
        The total number of steps over all root passes.
    """
    max_cell_value = pow10(CELL_LEN) - 1
    k = len(cells_start)  # number of cells above the root

    # One memo table per possible depth 1..k (table index = depth - 1).
    while len(_cell_memos) < k:
        _cell_memos.append({})

    def recurse(pos: int, prev_dsum: int, following_prefix: bool) -> int:
        """
        pos: index into cells_start (0 = most significant cell).
        prev_dsum: digit sum contributed by the cells more significant than pos.
        following_prefix: True while every more-significant cell holds its starting value.
        """
        global root_value

        # Base case: no cells left; run one root pass with the accumulated digit sum.
        if pos == k:
            return _root_steps(prev_dsum)

        # Only full sweeps (following_prefix False) repeat, so only they are memoized.
        # Their result depends on the number of cells still below, i.e. depth k - pos.
        memo = _cell_memos[k - pos - 1]
        memo_key = (prev_dsum, root_value)
        if not following_prefix:
            cached = memo.get(memo_key)
            if cached is not None:
                steps_here, root_value = cached
                return steps_here

        # The starting value for this cell's loop: its own start while following the
        # prefix, otherwise the full range.
        start_val = cells_start[pos] if following_prefix else max_cell_value

        total = 0
        for d in range(start_val, -1, -1):
            next_following = following_prefix and d == start_val
            total += recurse(pos + 1, prev_dsum + digit_sum(d), next_following)

        if not following_prefix:
            memo[memo_key] = (total, root_value)
        return total

    return recurse(0, 0, True)

# ---- public API ----

def steps_from(m: int) -> int:
    """
    Count how many times x := x - digit_sum(x) must be applied, starting at x = m,
    before x becomes 0.

    The lowest ROOT_LEN digits of m become the global root_value. The remaining high-order
    digits are cut into CELL_LEN-digit cells, most significant first, and passed to
    _process_cells_from_prefix. Memo tables are deliberately kept between calls; only
    root_value is re-initialised here.

    NOTE(review): correctness presumably requires each step's digit sum to stay below
    10**ROOT_LEN, and the cell recursion grows one frame per cell, so "any non-negative
    integer" holds only for inputs of moderate length.

    Raises:
        ValueError: if m is negative.
    """
    if m < 0:
        raise ValueError("m must be non-negative")

    global root_value
    root_value = 0
    if m == 0:
        return 0  # already zero: no iterations needed

    high, root_value = divmod(m, pow10(ROOT_LEN))

    # Peel CELL_LEN-digit cells off the high part (least significant first), then flip the
    # list so the most significant cell comes first. The top cell may have fewer digits.
    cells_start: List[int] = []
    while high > 0:
        high, cell = divmod(high, pow10(CELL_LEN))
        cells_start.append(cell)
    cells_start.reverse()

    if cells_start:
        return _process_cells_from_prefix(cells_start)
    # The number fits entirely in the root block: simulate it straight down to 0.
    return _root_steps(0)

# -----------------------
# Example usage:
#   print(steps_from(20))        # -> 3  (20 -> 18 -> 9 -> 0)
#   print(steps_from(2))         # -> 1  (2 -> 0)
#   print(steps_from(10**k))     # presumably matches the reference sequence a(k) this code
#                                # was derived from (not verified here)
# -----------------------
if __name__ == "__main__":
    # Guarded so importing this module does not trigger the long-running computation.
    print(steps_from(672597526590221696))

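# NOTE(review): the following lines are page residue from an online code runner, not
# Python; they are commented out so the file parses. The embed snippet they announce
# was never included.
# Embed on website
#
# To embed this program on your website, copy the following code and paste it into your website's HTML: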