【環境・ソート編】Rustで色々なアルゴリズムを書きつつPyO3を使ってPythonで動かす

Aug. 13, 2023, 12:08 a.m. edited Aug. 14, 2023, 1:36 a.m.

#PyO3  #Rust  #Python 

仕事内容が変わりアルゴリズムを本格的に知っておこうという気持ちになったので、Rust で色々書いていく企画。せっかくなので、PyO3(maturin)を使って、Rust で書いたアルゴリズムを Python で使えるようにしていく。

GitHub レポジトリはここ: algorithm_dict

今回実装するのはバブルソート・選択ソート・マージソート・クイックソートの4種類。それぞれ generics をサポートするように作っていく。

アルゴリズムの概要と実装

概要も何も、有名なアルゴリズムなので…。Wikipedia を見れば十分。実装は sort.rs にある。

ただ、マージソートやクイックソートではデメリットも存在する再帰関数を使わずに、キュー(VecDeque)を利用して実現している。

PyO3 および maturin

Rust で書いたプログラムを Python で呼び出すのに使うライブラリ・ツールとして PyO3+maturin はスタンダードになりつつある。以前使い方を記事(もう古い)に書いたのだが、もうその使い方も古くなってしまった。このようにバージョンが更新されるたびに色々変化が大きいので、使い方は各自ググるようにしてほしい。

参考

PyO3 で generics

まず、普通に generics バージョンのバブルソートを考えると、以下のようになる:

fn bubble_sort<T: Copy + PartialOrd>(vec: &mut [T]) {
    for i in 0..vec.len()-1 {
        for j in i..vec.len() {
            if vec[i] > vec[j] {
                let tmp = vec[i];
                vec[i] = vec[j];
                vec[j] = tmp;
            }
        }
    }
}

これは普通に

fn main() {
    let mut li = vec![10, -5, 3];
    bubble_sort(&mut li);
    println!("{:?}", li);

    let mut lf = vec![3.14, 2.718];
    bubble_sort(&mut lf);
    println!("{:?}", lf);
}

で確認できる。(ちなみに引数で Vec を使わない理由

一見すると、これを単に pyfunction としてやればいいだけに見えるが、 PyO3 では generics はサポートされていないのである(参考:想像以上に丁寧な teratail の回答)。

ではどうすれば良いかというと、その teratail での回答に近い、 StackOverflow に載っていたこの手法を用いる。つまり、いったん引数を PyObject として受け取り、それをキャストして generics 版の本体のバブルソートを呼び出すというものである:

fn _call_each_type_sort(fi: fn(&mut [i64]), ff: fn(&mut [f64]), fu8: fn(&mut [u8]), vec: PyObject) -> PyResult<PyObject> {
    Python::with_gil(|py| {
        if let Ok(mut vec) = vec.extract::<Vec<i64>>(py) {
            fi(&mut vec);
            return Ok(vec.to_object(py));
        }
        else if let Ok(mut vec) = vec.extract::<Vec<f64>>(py) {
            ff(&mut vec);
            return Ok(vec.to_object(py));
        }
        else if let Ok(vec) = vec.extract::<String>(py) {
            let mut vec = vec.into_bytes();
            fu8(&mut vec);
            return Ok(String::from_utf8(vec).unwrap().to_object(py));
        }
        Err(PyTypeError::new_err("Not supported"))
    })
}

#[pyfunction]
pub fn bubble_sort(vec: PyObject) -> PyResult<PyObject> {
    _call_each_type_sort(_bubble_sort::<i64>, _bubble_sort::<f64>, _bubble_sort::<u8>, vec)
}

fn _bubble_sort<T: Copy + PartialOrd>(vec: &mut [T]) {
    for i in 0..vec.len()-1 {
        for j in i..vec.len() {
            if vec[i] > vec[j] {
                let tmp = vec[i];
                vec[i] = vec[j];
                vec[j] = tmp;
            }
        }
    }
}

また、この際、他のソートでも同じようなことをするので、まとめて _call_each_type_sort で扱えるようにしている。(引数に generics の関数を渡すことさえできればもっと綺麗になるのに…

テストおよび展望

テストには pytest を利用している(初めて使ったけど pytest と打つだけで全部テストしてくれるの便利すぎでは??)。そのテストケースは次のようにしている(全体は test_sort.py):

def test_sort():
    l = [3, 10, -2, 60, 2]
    sorted_l = sorted(l)

    assert sorted_l == bubble_sort(l)
    assert sorted_l == selection_sort(l)
    assert sorted_l == merge_sort(l)
    assert sorted_l == quick_sort(l)

    l = [0.53, -0.83, 0.1234, 24, -56, 12, 1, 1]
    sorted_l = sorted(l)

    assert sorted_l == bubble_sort(l)
    assert sorted_l == selection_sort(l)
    assert sorted_l == merge_sort(l)
    assert sorted_l == quick_sort(l)

    l = 'hgrwoghoqgqpoh204hru'
    sorted_l = ''.join(sorted(l))

    assert sorted_l == bubble_sort(l)
    assert sorted_l == selection_sort(l)
    assert sorted_l == merge_sort(l)
    assert sorted_l == quick_sort(l)

    l = np.array([5, -1, 3, 0, 4, 3])
    sorted_l = sorted(l)

    assert sorted_l == bubble_sort(l)
    assert sorted_l == selection_sort(l)
    assert sorted_l == merge_sort(l)
    assert sorted_l == quick_sort(l)

このように、整数のリストや浮動小数のリストだけでなく、文字列や numpy まで動くようになっている。素晴らしい!

ソートの次はグラフ系をやろうかなと思っている。