Pythonでa++(インクリメント)を実現する

March 20, 2022, 3:29 p.m. edited March 21, 2022, 5:22 a.m.

#Python 

Python では a++ ができないのはよく知られる事実。

だが、 a++ したい!!!

という邪悪な話。最終的には

# coding: mylang

a = 5
print(a++)
print(a)
5
6

とできるようになる。なお、今回の実験は Python 3.9.1 かつ venv 環境下でおこなっている。

まずは観察

a++ を普通にやるとどうなるか。

a = 5
print(a++)

と書いて実行。

  File "main.py", line 2
    a++
       ^
SyntaxError: invalid syntax

というように構文エラーとなる。うん、知ってた。

新しい構文の追加

a++ は普通の Python では構文エラーとなるので、つまり Python に新たな構文を追加すれば良いということになる。そんな方法はあるのだろうかとググったら Can you add new statements to Python's syntax? の回答にこんなものが。これなら C にまで降りずとも Python で完結できそうだ。

つまり、ファイル冒頭にマジックコメントとして # coding: mylang として書いておき、この mylang として従来の utf8 Python に拡張した自作構文を追加するという戦略である(回答の Another fairly neat (albeit hacky) solution の方)。

ということで書いてみるのだが、様々な問題とぶつかることになった。

そもそも mylang が見つからない

普通は新たな codecs を追加する場合は、そのファイルを import して codecs.register() が通ることで認識されるのだが、マジックコメントの場合は import の前なので認識できないのである。回答では .pythonrcsite.py に書くようにとのことだったが、前者の .pythonrc は環境変数 PYTHONSTARTUP に指定する必要があり、というかそもそも対話型での実行でしか動かない。後者の site.pyPython の実行時に自動でインポートされるものであることを利用してここに書き込んでしまうということだが、さすがにそこまで環境を汚したくはない。

ここで、後者をもう少し追求すると、 usercustomize.py (もしくは sitecustomize.py) に行き着く。これは site.py が実行してくれるスクリプトであり、ユーザが置いておくものらしい。その置き場所は基本的には Python ライブラリフォルダ内であり、やはり環境を汚しかねない1。が、さらに参照場所のパスを追加することができ、 PYTHONPATH に追加すれば良い。また、 ENABLE_USER_SITEFalse の環境だと sitecustomize.py しか見てくれない。

したがって、 mylang の定義を sitecustomize.py に書いて、これをプロジェクトのフォルダに置いたうえで対象の main.pyPYTHONPATH=. python main.py で実行すれば良いということになる。

StreamReader が呼ばれない

回答のコードを参考に色々試していたのだが、まったく動く気配がない。というか StreamReader が呼ばれない。なぜだろうと思ってよく見たらこの回答は 2008 年のものだった。それはもう色々仕様が変わっていてもおかしくない。

そこで、もう少し調べると gist のコードが見つかった。 IncrementalDecoder を継承してコードを差し込む形である。もっとも、このコードでもまだ微妙に動かなかったので、 encodings.utf_8.IncrementalDecoder が継承している codecs.BufferedIncrementalDecoderdecode を参考に修正し、そのうえでインクリメントのコードを入れた。

なお、 untokenize はトークンタイプおよびトークン文字列のイテレータさえ返せば問題なく動く

どうやって実装しよう

a++ を置き換えてインクリメントとするシンタックスシュガーを実装しなければならない。単純に

a; a += 1

としてしまうと 2 つの文になってしまうので print 内部などで使えなくなってしまう。そこで、 Python 3.8 で導入されたセイウチ演算子を使う。どのように実装するか悩んだが、素晴らしい回答を見つけたので、

(a:=a+1)-1

とする。

sitecustomize.py の内容

上記の文献に加えて日本語資料も見つつ、以下の sitecustomize.py ができた。

import tokenize
from tokenize import TokenInfo
from token import tok_name
import codecs, encodings
from io import StringIO
from encodings import utf_8
from typing import NamedTuple

UTF8 = encodings.search_function('utf8')

class Token(NamedTuple):
    type: int
    name: str

def inject(a):
    l = []
    var = None
    one_plus = False
    for type, name, _, _, _ in a:
        #print(tok_name[type], name)
        l.append(Token(type, name))
        #print(l)
        if l[0].type == tokenize.NAME:
            if len(l) > 1:
                if l[1].name != '+' and l[1].name != '-':
                    for t, n in l:
                        yield t, n
                    l = []
                    continue
            if len(l) < 3:
                continue
            if l[1].name == '+' and l[2].name == '+':
                yield tokenize.OP, '('
                yield tokenize.OP, '('
                yield tokenize.NAME, l[0].name
                yield tokenize.OP, ':='
                yield tokenize.NAME, l[0].name
                yield tokenize.OP, '+'
                yield tokenize.NUMBER, '1'
                yield tokenize.OP, ')'
                yield tokenize.OP, '-'
                yield tokenize.NUMBER, '1'
                yield tokenize.OP, ')'
                l = []
                continue
            if l[1].name == '-' and l[2].name == '-':
                yield tokenize.OP, '('
                yield tokenize.OP, '('
                yield tokenize.NAME, l[0].name
                yield tokenize.OP, ':='
                yield tokenize.NAME, l[0].name
                yield tokenize.OP, '-'
                yield tokenize.NUMBER, '1'
                yield tokenize.OP, ')'
                yield tokenize.OP, '+'
                yield tokenize.NUMBER, '1'
                yield tokenize.OP, ')'
                l = []
                continue
            for t, n in l:
                yield t, n
            l = []
            continue
        elif l[0].name == '+':
            if len(l) > 1:
                if l[1].name != '+' and l[1].name != '-':
                    for t, n in l:
                        yield t, n
                    l = []
                    continue
            if len(l) < 3:
                continue
            if l[1].name == '+' and l[2].type == tokenize.NAME:
                yield tokenize.OP, '('
                yield tokenize.NAME, l[2].name
                yield tokenize.OP, ':='
                yield tokenize.NAME, l[2].name
                yield tokenize.OP, '+'
                yield tokenize.NUMBER, '1'
                yield tokenize.OP, ')'
                l = []
                continue
            for t, n in l:
                yield t, n
            l = []
            continue
        elif l[0].name == '-':
            if len(l) > 1:
                if l[1].name != '+' and l[1].name != '-':
                    for t, n in l:
                        yield t, n
                    l = []
                    continue
            if len(l) < 3:
                continue
            if l[1].name == '-' and l[2].type == tokenize.NAME:
                yield tokenize.OP, '('
                yield tokenize.NAME, l[2].name
                yield tokenize.OP, ':='
                yield tokenize.NAME, l[2].name
                yield tokenize.OP, '-'
                yield tokenize.NUMBER, '1'
                yield tokenize.OP, ')'
                l = []
                continue
            for t, n in l:
                yield t, n
            l = []
            continue

        for t, n in l:
            yield t, n
        l = []

    for t, n in l:
        yield t, n
    l = []

def transform(stream):
    #print("call transform")
    a = tokenize.generate_tokens(StringIO(stream).readline)
    return tokenize.untokenize(inject(a))

def decode(input, errors='strict'):
    #print('call decode')
    if isinstance(input, memoryview):
        input = input.tobytes().decode('utf-8')
    return UTF8.decode(transform(StringIO(input)), errors)

class IncrementalDecoder(utf_8.IncrementalDecoder):
    def decode(self, input, final=False):
        #print('decode in incrementaldecoder')
        return transform(super().decode(input, final))

class StreamReader(utf_8.StreamReader):
    def __init__(self, *args, **kwargs):
        codecs.StreamReader.__init__(self, *args, **kwargs)
        self.stream = StringIO(transform(self.stream))

def search_function(s):
    #print('call search_function')
    if s != 'mylang':
        return None

    #print('mylang is used')
    return codecs.CodecInfo(
        name='mylang',
        encode=UTF8.encode,
        decode=decode,
        incrementalencoder=UTF8.incrementalencoder,
        incrementaldecoder=IncrementalDecoder,
        streamreader=StreamReader,
        streamwriter=UTF8.streamwriter)

codecs.register(search_function)

これをプロジェクトフォルダ(適当に ~/hoge/)に置いたうえで ~/hoge/main.py

# coding: mylang

a = 5
print(a++)
print(a)

b = ++a
print(b--)
print(--b)
print(b)

と書いて ~/hoge 内で PYTHONPATH=. python main.py と実行すると

5
6
7
5
5

と期待していた出力が得られる。

今後の展望

これせっかくなのでライブラリにしたいなぁ・・・。


  1. python -m site で確認できる。この sys.path に入っていれば良いみたい