ふたを開けると力づく

この記事は, 室蘭工業大学 Advent Calendar 2021 18日目 の記事です.

はじめに

静的に型がチェックされない言語って辛いですよね.PythonJavascriptをはじめとした動的型付け言語は,実行時に諸々のエラーが現れるため非常に対処がしづらいです.そこで最近ではTypescriptやmypyといった静的に型チェックができるものが普及しつつあり,より安全なコードが書けるようになってきています.
しかし,いずれのケースでも結局は動的型付け言語を実行することになるため,動的型付け言語特有の特殊なコードはいくつもあります.その一つがオーバーロードです.
前回の記事ではTypescriptのオーバーロードに焦点をあてて特殊なコードを説明していました.本記事ではPythonオーバーロードに焦点をあて,どのようにオーバーロードを実現していくかについて記していきます.

singledispatch

実はPythonは標準ライブラリとしてオーバーロードを実現するための機能を提供しています.それがsingledispatchです.

singledispatchはPython3.4より導入された機能で,デコレータを活用することによってオーバーロードを実現します.以下に実行コードを示します.

from functools import singledispatch


@singledispatch
def f(a) -> None:
    """定義されていない型の場合に呼ばれる"""
    print("recieved any")


@f.register
def _f_int(n: int) -> None:
    """int のオーバーロード"""
    print(f"recieved int {n}")


@f.register
def _f_str(s: str) -> None:
    """str のオーバーロード"""
    print(f"recieved string {s}")


if __name__ == "__main__":
    f(10)
    f("str")
    f([])
recieved int 10
recieved string str
recieved any

オーバーロードを実現したい関数にsingledispatchをデコレートし,それぞれの実装に関数名.registerをデコレートすることでオーバーロード可能となります.Python3.6まではregisterは明示的に型を指定する必要があり,関数名.register(型名)とします.3.7以降でも明示的に指定することで好きな型でオーバーロードができます.デフォルトではアノテーションされた型が入ります.

さらに,メソッドに対するオーバーロードを実現するための機能としてsingledispatchmethodも存在します.singledispatchmethodはPython3.8により導入された機能で,singledispatchと同様に以下のように使用できます.

from functools import singledispatchmethod


class C:
    @singledispatchmethod
    def f(self, a) -> None:
        """定義されていない型の場合に呼ばれる"""
        print("recieved any")

    @f.register
    def _f_int(self, n: int) -> None:
        """int のオーバーロード"""
        print(f"recieved int {n}")

    @f.register
    def _f_str(self, s: str) -> None:
        """str のオーバーロード"""
        print(f"recieved string {s}")


if __name__ == "__main__":
    c: C = C()
    c.f(10)
    c.f("str")
    c.f([])
recieved int 10
recieved string str
recieved any

singledispatchの仕組み

Pythonは標準でオーバーロードを容易にできるためこれ以上言うことはないように思えますが,実は欠点もいくつか存在します.代表的なものでは以下があります.

これらに対処するための有効な手段はsingledispatchの拡張です.ただ,singledispatchについて知らなければ拡張は困難です.そこで本章では拡張するための1歩としてsingledispatchがどのようにしてオーバーロードを実現しているかの仕組みについて説明します.

singledispatchの構成

singledispatchは主に次の要素から構成されています.

  • registry
    • オーバーロードする関数を格納する変数
    • {type: func}という形式で保存している
  • register()
    • オーバーロードする関数を格納する関数
    • デコレートした関数をregistryに登録する
  • wrapper()
    • オーバーロードを担う関数
    • デコレートした関数が呼ばれたとき,registryから適切な関数を取り出して呼び出す

まず,オーバーロードしたい関数をデコレートし,その関数の代わりにwrapperが呼ばれるようにします.次に,registerによってオーバーロードの実態をregistryに登録します.そして,元の関数を呼び出された際にwrapperが代わりにregistryに登録された関数を呼び出すことでオーバーロードが実現されます.registryからの呼び出しはdispatch関数によって行います.
以下に簡単なイメージ図を示します.

f:id:chutmdo:20211219231556p:plain

singledispatchを拡張

さて,さっそくsingledispatchを拡張していきます.
まず,改良のベースとなるコードを作成します.singledispatchから使える機能を拝借し,以下をベースコードとします.なお,update_wrapperはなくても動作しますが,関数名が正しく取得できなくなるなど実用上で不便があるので書いておくと良いです.

from functools import singledispatch, update_wrapper

def mydispatch(func):
    dispatcher = singledispatch(func)

    def wrapper(*args, **kw):
        return dispatcher.dispatch(args[0].__class__)(*args, **kw)

    wrapper.register = dispatcher.register
    update_wrapper(wrapper, func)
    return wrapper

ベースコードができたので,先に述べた課題を解決する形で拡張例を記していきます.

任意の引数でオーバーロードする

ベースコードではargs[0]により1番目の引数をkeyとしています.ここを変えるだけで別の引数をkeyにできます.

def wrapper(*args, **kw):
    return dispatcher.dispatch(args[1].__class__)(*args, **kw)

特殊な型をオーバーロード:type

引数の型で直接registryへアクセスせず,型を加工してからアクセスさせることで特殊な型のオーバーロードも実現できます.
配列の次元でオーバーロードしたい場合,まずは以下のように値を特殊な型に変換する関数を記述します.

class Dim1:    ...

class Dim2:    ...

def get_type(val) -> type:
    def d(lst):
        return 0 if type(lst) != list else 1 + d(lst[0])

    dim: int = d(val)
    if dim == 1:
        return Dim1
    if dim == 2:
        return Dim2
    return val.__class__

次にこの関数を用いて分岐をするようにwrpper関数を書き換えます.

def wrapper(*args, **kw):
    t: type = get_type(args[0])
    return dispatcher.dispatch(t)(*args, **kw)

最後に関数の登録時に特殊な型を指定すればオーバーロードが実現できました.

@f.register
def f_int(n: int):
    print(f"recieved int {n}")


@f.register(Dim1)
def f_1dim(n: list[int]):
    print("recieved 1 dim")


@f.register(Dim2)
def f_2dim(lst: list[list[int]]):
    print("recieved 2 dim")


if __name__ == "__main__":
    f(10)
    f([10])
    f([[10]])
recieved int 10
recieved 1 dim
recieved 2 dim

特殊な型をオーバーロード:str

上の例ではオーバーロード用の特殊な型,Dim1,Dim2を定義していましたが,型が増えるたびにclassを追加するのは不便です.そこで,registryに登録する形式を変更することでより高い拡張性を実現します.
以下の例ではregistryを自前で定義し,{str: func}という形式で保存するようにしました.registryの自前定義に伴って,registrerも自分で定義します.

def mydispatch(func):
    registry: dict = {}

    def wrapper(*args, **kw):
        t: type = get_type_str(args[0])
        return registry[t](*args, **kw)

    def register(s, func=None):
        if func is None:
            return lambda f: register(s, f)
        registry.setdefault(s, func)
        return func

    wrapper.register = register
    update_wrapper(wrapper, func)
    return wrapper

引数情報を文字列に変換する関数,関数の登録をそれぞれ以下のように書けばオーバーロードができます.

def get_type_str(val) -> str:
    def d(lst):
        return 0 if type(lst) != list else 1 + d(lst[0])

    dim: int = d(val)
    if dim == 1:
        return "list[int]"
    if dim == 2:
        return "list[list[int]]"
    return val.__class__.__name__


def mydispatch(func):
    # 略

@mydispatch
def f():
    ...

@f.register("int")
def f_int(n: int):
    print(f"recieved int {n}")


@f.register("list[int]")
def f_1dim(n: list[int]):
    print("recieved 1 dim")


@f.register("list[list[int]]")
def f_2dim(lst: list[list[int]]):
    print("recieved 2 dim")


if __name__ == "__main__":
    f(10)
    f([10])
    f([[10]])
recieved int 10
recieved 1 dim
recieved 2 dim

型以外でオーバーロードする

registryは何らかの情報を関数と結びつけるものです.今までは暗黙の了解で型情報を用いていましたが,実際はなんでもよいです.
例えば,整数値が正か負か0かで分岐することも以下のように簡単にできます.

def mydispatch(func):
    registry: dict = {}

    def wrapper(*args, **kw):
        n: int = args[0]
        key: int = 0 if n == 0 else 1 if n > 0 else -1
        return registry[key](*args, **kw)

    def register(n: int, func=None):
        if func is None:
            return lambda f: register(n, f)
        registry.setdefault(n, func)
        return func

    wrapper.register = register
    update_wrapper(wrapper, func)
    return wrapper
@mydispatch
def f():
    ...


@f.register(-1)
def f_negative(n: int):
    print(f"n is negative: {n}")


@f.register(0)
def f_zero(n: int):
    print(f"n is zero: {n}")


@f.register(1)
def f_positive(n: int):
    print(f"n is positive: {n}")


if __name__ == "__main__":
    f(-10)
    f(0)
    f(5)
n is negative: -10
n is zero: 0
n is positive: 5

ここまで拡張できると実用的にも便利なことが多々ありそうですね.

おわりに

Typescriptと比べると簡単だし拡張楽だし最高です.また記事本文では触れていませんが,型アノテーションの情報を使ってオーバーロードもできるので見た目だけなら静的型付けとほぼ変わらなくなって素晴らしいです.
アノテーション情報の利用部分も拡張できればより良いものになると思いますが,その辺はまだ読んでいないので本記事では触れません.まだまだ拡張の余地を残すPython,調べててとても楽しいなぁ.