代码之家  ›  专栏  ›  技术社区  ›  LudvigH

如何在迭代中将类型异构数据与类型泛型算法相结合?

  •  0
  • LudvigH  · 技术社区  · 10 月前

    我用rust编写了一些通用的ODE求解器。我希望通过泛型获得两件事:高效的代码生成和编译时检查。然后,我有一个异构的初始值问题(IVP)集合,我想用这些求解器来处理。不同的IVP使用不同类型的浮点数,具有不同形状的向量等。我现在想在这组异构的IVP上对我的求解器进行基准测试。

    经过一些edt、重写等,我将代码简化为大约200行的示例代码库。它不求解ODE,而是进行矩阵乘法,因为这需要更少的代码行。类型泛型算法的一般问题是相同的。请看下面或在操场上: rust playground

    我决定将所有问题规范封装在枚举中,这样我就可以有一个集合来迭代。问题集有一个 type_id 方法来反思与此问题相关的泛型类型参数。然后,有一个哈希表,其中包含该特定类型id的通用算法的实例化。

    该解决方案在技术上可行,但很脆弱。我希望Rust中有一个更好的解决方案来实现这一点。也许是一些不安全的选角、宏魔法什么的?!

    当前方法的问题:

    • 我需要根据我使用的问题添加枚举变量。
    • 我需要在解算器和问题枚举中添加变量,并且没有编译时检查,我有所有需要的实例化。
    • 我需要在HashMap中为每个枚举添加实例,并且没有编译时检查,我拥有所有需要的实例。
    • 的实施 Problem::type_id 有一个从枚举变量的名称到类型id的手动映射。这很容易出错。

    是否有一种方法可以通过使用其他Rust机制来解决这些问题,允许在问题集上迭代并对问题应用所有算法,同时仍然保留单态化和泛型?

    use nalgebra as na; // 0.32.3
    #[allow(non_camel_case_types,non_snake_case)]
    
    
    mod bench_utils{
        #[allow(non_snake_case)]
        use nalgebra as na;
        /// Generic problem description
        /// They all represent two matrices A and B that needs to be multiplied
        #[derive(Clone)]
        pub struct GenProb<T,D>
        where
        D: na::Dim,
        na::DefaultAllocator: na::allocator::Allocator<T, na::Const<3>, D>
         + na::allocator::Allocator<T, D,na::Const<3>> 
        {
            pub name: String,
            pub A: na::OMatrix<T,na::Const<3>,D>,
            pub B: na::OMatrix<T,D,na::Const<3>>
        }
        
        /// Generic solver description
        /// A matrix multiplication algorithm
        #[derive(Clone)]
        pub struct GenSolv<T,D>
        where 
        T: na::Scalar,
        D: na::Dim,
        na::DefaultAllocator: na::allocator::Allocator<T, na::Const<3>, D>
         + na::allocator::Allocator<T, D,na::Const<3>> 
        {
            pub name: String,
            pub solver: fn(na::OMatrix<T,na::Const<3>,D>,na::OMatrix<T,D,na::Const<3>>) -> na::OMatrix<T,na::Const<3>,na::Const<3>>
        }
        
        /// Instantiate all thinkable GenProb that you might need
        pub enum Problem {
            P_1_64(GenProb<f64,na::Const<1>>),
            P_1_32(GenProb<f32,na::Const<1>>),
            P_5_64(GenProb<f64,na::Const<5>>),
            P_5_32(GenProb<f32,na::Const<5>>),
            P_dyn_64(GenProb<f64,na::Dyn>),
            P_dyn_32(GenProb<f32,na::Dyn>),
        }
        impl Problem {
            pub fn type_id(&self) -> std::any::TypeId {
                match self {
                    Problem::P_1_64(_) => std::any::TypeId::of::<GenProb<f64,na::Const<1>>>(),
                    Problem::P_1_32(_) => std::any::TypeId::of::<GenProb<f32,na::Const<1>>>(),
                    Problem::P_5_64(_) => std::any::TypeId::of::<GenProb<f64,na::Const<5>>>(),
                    Problem::P_5_32(_) => std::any::TypeId::of::<GenProb<f32,na::Const<5>>>(),
                    Problem::P_dyn_64(_) => std::any::TypeId::of::<GenProb<f64,na::Dyn>>(),
                    Problem::P_dyn_32(_) => std::any::TypeId::of::<GenProb<f32,na::Dyn>>(),
                }
            }
        }
        
        /// Instantiate all thinkable GenProb that you might need
        pub enum Solver {
            S_1_64(GenSolv<f64,na::Const<1>>),
            S_1_32(GenSolv<f32,na::Const<1>>),
            S_5_64(GenSolv<f64,na::Const<5>>),
            S_5_32(GenSolv<f32,na::Const<5>>),
            S_dyn_64(GenSolv<f64,na::Dyn>),
            S_dyn_32(GenSolv<f32,na::Dyn>),
        }
        
        pub fn process<T,D>(prob:GenProb<T,D>,solv:GenSolv<T,D>) where
        D: na::Dim,
        T: na::RealField,
        na::DefaultAllocator: na::allocator::Allocator<T, na::Const<3>, D>
         + na::allocator::Allocator<T, D,na::Const<3>>,
        {
            let now = std::time::Instant::now();
            let _C = (solv.solver)(prob.A,prob.B);
            let elapsed = now.elapsed();
            println!("Running {:10} on {:15} took {:?} seconds",solv.name,prob.name,elapsed)
        }
    }
    
    mod multipliers{
        //! A module containing matrix multiplication methods
        
        use nalgebra as na;
        use nalgebra; // 0.32.5
    
        /// The built in matrix multiply
        #[allow(non_snake_case)]
        pub fn na_star<T,D>(
            A: na::OMatrix<T,na::Const<3>,D>,
            B: na::OMatrix<T,D,na::Const<3>>
        ) -> na::OMatrix<T,na::Const<3>,na::Const<3>>
        where
        T: na::RealField,
        D: na::Dim,
        na::DefaultAllocator: na::allocator::Allocator<T, na::Const<3>, D>
         + na::allocator::Allocator<T, D,na::Const<3>> {
                A * B
        }
        
        /// An alternative implementation
        #[allow(non_snake_case)]
        pub fn manual<T,D>(
            A: na::OMatrix<T,na::Const<3>,D>,
            B: na::OMatrix<T,D,na::Const<3>>
        ) -> na::OMatrix<T,na::Const<3>,na::Const<3>>
        where
        T: na::RealField + std::iter::Sum + Copy,
        D: na::Dim,
        na::DefaultAllocator: na::allocator::Allocator<T, na::Const<3>, D>
         + na::allocator::Allocator<T, D,na::Const<3>> {
            let mut out = vec![];
            for r in 0..3 {
                for c in 0..3 {
                    out.push(
                        A.row(r).iter().zip(B.column(c).iter()).map(|(&a,&b)| a*b).sum()
                    )
                }
            }
            na::Matrix3::from_vec(out)
        }
    }
    
    use bench_utils::*;
    
    pub fn main() {
    
        let problems = vec![
            Problem::P_1_64(GenProb{name:"1 64 static".into(), A: na::Matrix3x1::new(0.0, 1.0, 2.0),B: na::Matrix1x3::new(2.0, 1.0, 0.0)}),
            Problem::P_5_64(GenProb{name:"5 64 static".into(), A: na::Matrix3x5::new(1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0),B: na::Matrix5x3::new(1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0)}),
            Problem::P_1_32(GenProb{name:"1 32 static".into(), A: na::Matrix3x1::new(0.0, 1.0, 2.0),B: na::Matrix1x3::new(2.0, 1.0, 0.0)}),
            Problem::P_5_32(GenProb{name:"5 32 static".into(), A: na::Matrix3x5::new(1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0),B: na::Matrix5x3::new(1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0)}),
            Problem::P_dyn_64(GenProb{name:"1 64 dyn".into(), A: na::Matrix::<f64,na::Const<3>,na::Dyn,_>::from_vec(vec![0.0,1.0,2.0]), B: na::Matrix::<f64,na::Dyn,na::Const<3>,_>::from_vec(vec![0.0,1.0,2.0]) }),
            Problem::P_dyn_64(GenProb{name:"5 64 dyn".into(), A: na::Matrix::<f64,na::Const<3>,na::Dyn,_>::from_vec(vec![1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0]), B: na::Matrix::<f64,na::Dyn,na::Const<3>,_>::from_vec(vec![1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0]) }),
            Problem::P_dyn_32(GenProb{name:"1 32 dyn".into(), A: na::Matrix::<_,na::Const<3>,na::Dyn,_>::from_vec(vec![0.0,1.0,2.0]), B: na::Matrix::<_,na::Dyn,na::Const<3>,_>::from_vec(vec![0.0,1.0,2.0]) }),
            Problem::P_dyn_32(GenProb{name:"5 32 dyn".into(), A: na::Matrix::<_,na::Const<3>,na::Dyn,_>::from_vec(vec![1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0]), B: na::Matrix::<_,na::Dyn,na::Const<3>,_>::from_vec(vec![1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0]) }),
            ];
        let solvers: std::collections::HashMap<_, Vec<_>> = std::collections::HashMap::from([
            (std::any::TypeId::of::<GenProb<f64,na::Const<1>>>(),vec![
                Solver::S_1_64(GenSolv{name:"nalgebra*".into(),solver:multipliers::na_star}),
                Solver::S_1_64(GenSolv{name:"manual".into(),solver:multipliers::manual}),
            ]),
            (std::any::TypeId::of::<GenProb<f64,na::Const<5>>>(),vec![
                Solver::S_5_64(GenSolv{name:"nalgebra*".into(),solver:multipliers::na_star}),
                Solver::S_5_64(GenSolv{name:"manual".into(),solver:multipliers::manual}),
            ]),
            (std::any::TypeId::of::<GenProb<f64,na::Dyn>>(),vec![
                Solver::S_dyn_64(GenSolv{name:"nalgebra*".into(),solver:multipliers::na_star}),
                Solver::S_dyn_64(GenSolv{name:"manual".into(),solver:multipliers::manual}),
            ]),
            (std::any::TypeId::of::<GenProb<f32,na::Const<1>>>(),vec![
                Solver::S_1_32(GenSolv{name:"nalgebra*".into(),solver:multipliers::na_star}),
                Solver::S_1_32(GenSolv{name:"manual".into(),solver:multipliers::manual}),
            ]),
            (std::any::TypeId::of::<GenProb<f32,na::Const<5>>>(),vec![
                Solver::S_5_32(GenSolv{name:"nalgebra*".into(),solver:multipliers::na_star}),
                Solver::S_5_32(GenSolv{name:"manual".into(),solver:multipliers::manual}),
            ]),
            (std::any::TypeId::of::<GenProb<f32,na::Dyn>>(),vec![
                Solver::S_dyn_32(GenSolv{name:"nalgebra*".into(),solver:multipliers::na_star}),
                Solver::S_dyn_32(GenSolv{name:"manual".into(),solver:multipliers::manual}),
            ]),
            
        ]);
        for prob in &problems{ 
            let id = prob.type_id();
            for solv in &solvers[&id] {
                match (prob,solv) {
                    ( Problem::P_1_64(p), Solver::S_1_64(s))=>process(p.clone(),s.clone()),
                    ( Problem::P_5_64(p), Solver::S_5_64(s))=>process(p.clone(),s.clone()),
                    ( Problem::P_dyn_64(p), Solver::S_dyn_64(s))=>process(p.clone(),s.clone()),
                    ( Problem::P_1_32(p), Solver::S_1_32(s))=>process(p.clone(),s.clone()),
                    ( Problem::P_5_32(p), Solver::S_5_32(s))=>process(p.clone(),s.clone()),
                    ( Problem::P_dyn_32(p), Solver::S_dyn_32(s))=>process(p.clone(),s.clone()),
                    _ => eprintln!("This problem/solver pair is not compatible")
                }
            }
        }
    }
    
    0 回复  |  直到 9 月前