blockの巣

【Rust】ユーザー定義型にrandクレートの乱数生成を実装する

2024/09/10 14:40 公開
Rust

この記事はRust 1.81.0rand 0.8.5で動作確認しています。


下記のようなPoint2Dに対して、randクレートを使って乱数生成を実装する。

#[derive(Debug)]
pub struct Point {
    x: f32,
    y: f32,
}

値の範囲を指定しない乱数生成

値の範囲を指定しない生成であればDistribution<Point>Standardを実装すれば良い。

use rand::prelude::*;
use rand::distributions::Standard;

impl Distribution<Point> for Standard
{
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Point {
        Point { x: rng.gen(), y: rng.gen() }
    }
}

fn main() {
    let mut rng = rand::thread_rng();
    let p: Point = rng.gen();
    println!("{:?}", p);
}

値の範囲を指定する乱数生成

範囲を指定する方は若干複雑ですが、必要なのはUniformSamplerSampleUniformを実装すること。
UniformSamplerUniformを使って範囲を指定するためのトレイトで、SampleUniformUniformSamplerを使って乱数生成を行うためのトレイト。 注意する点として、Pointがprivateだとコンパイルエラーになる。

use rand::distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler};
use rand::distributions::Uniform;
use rand::prelude::*;

pub struct UniformPoint {
    x_sampler: Uniform<f32>,
    y_sampler: Uniform<f32>,
}

impl UniformSampler for UniformPoint {
    type X = Point;
    fn new<B1, B2>(low: B1, high: B2) -> Self
    where
        B1: SampleBorrow<Self::X> + Sized,
        B2: SampleBorrow<Self::X> + Sized,
    {
        Self {
            x_sampler: Uniform::new(low.borrow().x, high.borrow().x),
            y_sampler: Uniform::new(low.borrow().y, high.borrow().y),
        }
    }

    fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
    where
        B1: SampleBorrow<Self::X> + Sized,
        B2: SampleBorrow<Self::X> + Sized,
    {
        Self {
            x_sampler: Uniform::new_inclusive(low.borrow().x, high.borrow().x),
            y_sampler: Uniform::new_inclusive(low.borrow().y, high.borrow().y),
        }
    }

    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
        Point {
            x: self.x_sampler.sample(rng),
            y: self.y_sampler.sample(rng),
        }
    }
}

impl SampleUniform for Point {
    type Sampler = UniformPoint;
}

fn main() {
    let mut rng = rand::thread_rng();
    let uniform = UniformPoint::new(Point { x: -10.0, y: -5.0 }, Point { x: 10.0, y: 5.0 });
    let p = uniform.sample(&mut rng);
    println!("{:?}", p);
}

上記の例ではxyの範囲をPoint2つで受け取っているが、別の型でも問題ないはず。

Pointの要素型をGenericsにした場合のUniformSamplerの実装

細かいコードは省略するが

pub struct Point<T> {
    x: T,
    y: T,
}

という型に対してであれば、UniformPoint<T>TSampleUniformを実装していることを条件にすれば良い。