blockの巣

Rustのconst genericsの引数によって型の定義を変える(っぽいこと)

2023/05/11 10:41 公開
2023/07/07 01:15

コード修正

Generics Rust

若干釣りっぽいタイトルです

やりたいこと

Vector<T, const D: usize>のような要素型と次元数を受け取るVector型があったとして

// x, yを持つ
let v2 = Vector::<f32, 2> { x: 1.1, y: 2.2 };
// x, y, zを持つ
let v3 = Vector::<f32, 3> {
    x: 1.1,
    y: 2.2,
    z: 3.3,
};
// x, y, z, wを持つ
let v4 = Vector::<f32, 4> {
    x: 1.1,
    y: 2.2,
    z: 3.3,
    w: 3.3,
};
// 要素数5の配列を持つ
let v5 = Vector::<f32, 5> {
    elements: [1.1, 2.2, 3.3, 4.4, 5.5],
};

のようにVector::<T, 2>x, yを持つ、
Vector::<T, 3>x, y, zを持つ、
Vector::<T, 4>x, y, z, wを持つ、
D >= 5となるVector::<T, D>は配列elements: [T; D]を持つ、
というような実装にしたい。

実装

実装全文。
重要なのはVectorTypeHolderVectorTypeResolverです。

use std::marker::PhantomData;

pub struct Vector2<T> {
    pub x: T,
    pub y: T,
}
pub struct Vector3<T> {
    pub x: T,
    pub y: T,
    pub z: T,
    pub w: T,
}
pub struct Vector4<T> {
    pub x: T,
    pub y: T,
    pub w: T,
}
pub struct ArrayWrapper<T, const D: usize> {
    elements: [T; D],
}

pub trait VectorTypeHolder<T, const D: usize> {
    type Vector;
}
pub struct VectorTypeResolver<T, const D: usize> {
    _marker: PhantomData<fn() -> [T; D]>,
}

impl<T> VectorTypeHolder<T, 2> for VectorTypeResolver<T, 2> {
    type Vector = Vector2<T>;
}
impl<T> VectorTypeHolder<T, 3> for VectorTypeResolver<T, 3> {
    type Vector = Vector3<T>;
}
impl<T> VectorTypeHolder<T, 4> for VectorTypeResolver<T, 4> {
    type Vector = Vector4<T>;
}
impl<T> VectorTypeHolder<T, 5> for VectorTypeResolver<T, N> {
    type Vector = ArrayWrapper<T, 5>;
}

pub type Vector<T, const D: usize> = <VectorTypeResolver<T, D> as VectorTypeHolder<T, D>>::Vector;

解説

Vector2, Vector3, Vector4と配列をラップしたArrayWrapperはそれぞれ2~4次元の時に使用する型と5次元以上のときに使用する型です。 下記のVectorTypeHolderは要素型と次元数をgenerics引数として受けとりVectorという関連型を持つtraitです。

pub trait VectorTypeHolder<T, const D: usize> {
    type Vector;
}

ここからVectorTypeHolderと型を取り出すためにVectorTypeResolverという型を定義し、VectorTypeResolverに対してVectorTypeHolderを実装します。

pub struct VectorTypeResolver<T, const D: usize> {
    _marker: PhantomData<fn() -> [T; D]>,
}

VectorTypeResolverが持つ_markerはジェネリクス引数を消費するだけのものなので無視して大丈夫です。気になる方はPhantomDataを調べてください。
traitを実装するときconst genericsの値を指定して実装することができるため以下のようにDが2のときのみの実装を行うことができます。

impl<T> VectorTypeHolder<T, 2> for VectorTypeResolver<T, 2> {
    type Vector = Vector2<T>;
}

同様にVectorTypeHolder::VectorがDが3,4のときにVector3, Vector4を、Dが5のときはArrayWrapper<T, 5>となるように実装します。

impl<T> VectorTypeHolder<T, 2> for VectorTypeResolver<T, 2> {
    type Vector = Vector2<T>;
}
impl<T> VectorTypeHolder<T, 3> for VectorTypeResolver<T, 3> {
    type Vector = Vector3<T>;
}
impl<T> VectorTypeHolder<T, 4> for VectorTypeResolver<T, 4> {
    type Vector = Vector4<T>;
}
impl<T> VectorTypeHolder<T, 5> for VectorTypeResolver<T, 5> {
    type Vector = ArrayWrapper<T, 5>;
}

最後にVector<T, D>として型を取得できるように

pub type Vector<T, const D: usize> = <VectorTypeResolver<T, D> as VectorTypeHolder<T, D>>::Vector;

と型エイリアスを定義します。

ここまででで最初のやりたいことが実現できるようになりました。

Dが5より大きい場合

実はこの実装だとD > 5の場合にコンパイルエラーになります。
それを解消するには

impl<T> VectorTypeHolder<T, 6> for VectorTypeResolver<T, 6> {
    type Vector = ArrayWrapper<T, 6>;
}
impl<T> VectorTypeHolder<T, 7> for VectorTypeResolver<T, 7> {
    type Vector = ArrayWrapper<T, 7>;
}
impl<T> VectorTypeHolder<T, 8> for VectorTypeResolver<T, 8> {
    type Vector = ArrayWrapper<T, 8>;
}
//
// 以下D >= 9の実装が続く
//

というように使用する分VectorTypeHolderの実装をする必要があります。
seq_macroクレートのseqマクロを使用すると以下のようにまとめて実装できます。

// use seq_macro::seq;
seq!(N in 5..256 {
    impl<T> VectorTypeHolder<T, N> for VectorTypeResolver<T, N> {
        type Vector = ArrayWrapper<T, N>;
    }
});

参考

テンプレートの特殊化(Rust版)