blockの巣

const genericsの引数によってベクトル型の定義を変えたもので行列型を定義したかった

2024/02/21 15:25 公開
Rust

nightlyの話です。
タイトルだけだと何がしたいのかわからないですね。

以前書いた Rustのconst genericsの引数によって型の定義を変える(っぽいこと) に関連するような話です。


やりたいことは

struct Matrix<T, const ROW: usize, const COL: usize> {
    elements: [Vector<T, COL>; ROW],
}

という定義で行列型を定義する。ただし、Vector型は引数の定数によって要素を以下のように変えたい。

// 1次元ベクトル
pub struct Vector1<T> {
    pub x: T,
}
// 2次元ベクトル
pub struct Vector2<T> {
    pub x: T,
    pub y: T,
}
// 3次元ベクトル
pub struct Vector3<T> {
    pub x: T,
    pub y: T,
    pub z: T,
}
// 4次元ベクトル
pub struct Vector4<T> {
    pub x: T,
    pub y: T,
    pub z: T,
    pub w: T,
}
// 5次元ベクトル以上
pub struct ArrayWrapper<T, const N: usize> {
    pub elements: [T; N],
}

つまり Matrix<f32, 1, 2>であれば

struct Matrix {
    elements: [Vector2<f32>; 1],
}

Matrix<f32, 3, 4>であれば

struct Matrix {
    elements: [Vector4<f32>; 3],
}

Matrix<f32, 3, 5>であれば

struct Matrix {
    elements: [ArrayWrapper<f32, 5>; 2],
}

のようになる。

Vector型の定義

これにはまずVector<T, D>を定義する必要があるが、これは冒頭に書いた Rustのconst genericsの引数によって型の定義を変える(っぽいこと) の内容でほぼ解決できます。 リンク先の内容ではD > 5の場合にエラーとなるのですが、以下のようにVectorTypeHolderUSE_ARRAY_WRAPPERという引数を追加して、D >= 5の場合にArrayWrapperを使うようにすれば解決できます。

// #![feature(generic_const_exprs)]が必要
pub trait VectorTypeHolder<T, const D: usize, const USE_ARRAY_WRAPPER: bool> {
    type Vector;
}

pub struct VectorTypeResolver<T, const D: usize> {
    _marker: std::marker::PhantomData<fn() -> [T; D]>,
}

impl<T> VectorTypeHolder<T, 1, false> for VectorTypeResolver<T, 1> {
    type Vector = Vector1<T>;
}
impl<T> VectorTypeHolder<T, 2, false> for VectorTypeResolver<T, 2> {
    type Vector = Vector2<T>;
}
impl<T> VectorTypeHolder<T, 3, false> for VectorTypeResolver<T, 3> {
    type Vector = Vector3<T>;
}
impl<T> VectorTypeHolder<T, 4, false> for VectorTypeResolver<T, 4> {
    type Vector = Vector4<T>;
}
impl<T, const N: usize> VectorTypeHolder<T, N, true> for VectorTypeResolver<T, N> {
    type Vector = ArrayWrapper<T, N>;
}

// D >= 5の場合に`USE_ARRAY_WRAPPER`をtrueにすることでArrayWrapperを使うようにする
pub type Vector<T, const D: usize> = <VectorTypeResolver<T, D> as VectorTypeHolder<T, D, {D >= 5}>>::Vector;

以下のようなコードを書くと確認できます。

fn main() {
    use std::any::type_name;
    println!("{}", type_name::<Vector<f32, 1>>());
    println!("{}", type_name::<Vector<f32, 2>>());
    println!("{}", type_name::<Vector<f32, 3>>());
    println!("{}", type_name::<Vector<f32, 4>>());
    println!("{}", type_name::<Vector<f32, 5>>());
    println!("{}", type_name::<Vector<f32, 6>>());
}

出力

playground::Vector1<f32>
playground::Vector2<f32>
playground::Vector3<f32>
playground::Vector4<f32>
playground::ArrayWrapper<f32, 5>
playground::ArrayWrapper<f32, 6>

Rust Playground

Matrix型の定義

ここまでできたのであればあとは簡単に見えます。Vector<T, D>は既に定義されているので

struct Matrix<T, const ROW: usize, const COL: usize> {
    elements: [Vector<T, COL>; ROW],
}

でOKに感じますが以下のようなコンパイルエラーが発生します。

error[E0308]: mismatched types
  --> src/main.rs:52:15
   |
52 |     elements: [Vector<T, COL>; ROW],
   |               ^^^^^^^^^^^^^^^^^^^^^ expected `{D >= 5}`, found `true`
   |
   = note: expected constant `{D >= 5}`
              found constant `true`

For more information about this error, try `rustc --explain E0308`.
warning: `playground` (bin "playground") generated 1 warning
error: could not compile `playground` (bin "playground") due to 1 previous error; 1 warning emitted

Rust Playground

型が不一致だそうです。何故……
何もわからないですができなかったことを覚えておくために記事を書きました。うまいことやる方法をご存知の方がいれば教えてください。