我用rust编写了一些通用的ODE求解器。我希望通过泛型获得两件事:高效的代码生成和编译时检查。然后,我有一个异构的初始值问题(IVP)集合,我想用这些求解器来处理。不同的IVP使用不同类型的浮点数,具有不同形状的向量等。我现在想在这组异构的IVP上对我的求解器进行基准测试。
经过一些edt、重写等,我将代码简化为大约200行的示例代码库。它不求解ODE,而是进行矩阵乘法,因为这需要更少的代码行。类型泛型算法的一般问题是相同的。请看下面或在操场上:
rust playground
我决定将所有问题规范封装在枚举中,这样我就可以有一个集合来迭代。问题集有一个
type_id
方法来反思与此问题相关的泛型类型参数。然后,有一个哈希表,其中包含该特定类型id的通用算法的实例化。
该解决方案在技术上可行,但很脆弱。我希望Rust中有一个更好的解决方案来实现这一点。也许是一些不安全的选角、宏魔法什么的?!
当前方法的问题:
-
我需要根据我使用的问题添加枚举变量。
-
我需要在解算器和问题枚举中添加变量,并且没有编译时检查,我有所有需要的实例化。
-
我需要在HashMap中为每个枚举添加实例,并且没有编译时检查,我拥有所有需要的实例。
-
的实施
Problem::type_id
有一个从枚举变量的名称到类型id的手动映射。这很容易出错。
是否有一种方法可以通过使用其他Rust机制来解决这些问题,允许在问题集上迭代并对问题应用所有算法,同时仍然保留单态化和泛型?
use nalgebra as na; // 0.32.3
#[allow(non_camel_case_types,non_snake_case)]
mod bench_utils{
#[allow(non_snake_case)]
use nalgebra as na;
/// Generic problem description
/// They all represent two matrices A and B that needs to be multiplied
#[derive(Clone)]
pub struct GenProb<T,D>
where
D: na::Dim,
na::DefaultAllocator: na::allocator::Allocator<T, na::Const<3>, D>
+ na::allocator::Allocator<T, D,na::Const<3>>
{
pub name: String,
pub A: na::OMatrix<T,na::Const<3>,D>,
pub B: na::OMatrix<T,D,na::Const<3>>
}
/// Generic solver description
/// A matrix multiplication algorithm
#[derive(Clone)]
pub struct GenSolv<T,D>
where
T: na::Scalar,
D: na::Dim,
na::DefaultAllocator: na::allocator::Allocator<T, na::Const<3>, D>
+ na::allocator::Allocator<T, D,na::Const<3>>
{
pub name: String,
pub solver: fn(na::OMatrix<T,na::Const<3>,D>,na::OMatrix<T,D,na::Const<3>>) -> na::OMatrix<T,na::Const<3>,na::Const<3>>
}
/// Instantiate all thinkable GenProb that you might need
pub enum Problem {
P_1_64(GenProb<f64,na::Const<1>>),
P_1_32(GenProb<f32,na::Const<1>>),
P_5_64(GenProb<f64,na::Const<5>>),
P_5_32(GenProb<f32,na::Const<5>>),
P_dyn_64(GenProb<f64,na::Dyn>),
P_dyn_32(GenProb<f32,na::Dyn>),
}
impl Problem {
pub fn type_id(&self) -> std::any::TypeId {
match self {
Problem::P_1_64(_) => std::any::TypeId::of::<GenProb<f64,na::Const<1>>>(),
Problem::P_1_32(_) => std::any::TypeId::of::<GenProb<f32,na::Const<1>>>(),
Problem::P_5_64(_) => std::any::TypeId::of::<GenProb<f64,na::Const<5>>>(),
Problem::P_5_32(_) => std::any::TypeId::of::<GenProb<f32,na::Const<5>>>(),
Problem::P_dyn_64(_) => std::any::TypeId::of::<GenProb<f64,na::Dyn>>(),
Problem::P_dyn_32(_) => std::any::TypeId::of::<GenProb<f32,na::Dyn>>(),
}
}
}
/// Instantiate all thinkable GenProb that you might need
pub enum Solver {
S_1_64(GenSolv<f64,na::Const<1>>),
S_1_32(GenSolv<f32,na::Const<1>>),
S_5_64(GenSolv<f64,na::Const<5>>),
S_5_32(GenSolv<f32,na::Const<5>>),
S_dyn_64(GenSolv<f64,na::Dyn>),
S_dyn_32(GenSolv<f32,na::Dyn>),
}
pub fn process<T,D>(prob:GenProb<T,D>,solv:GenSolv<T,D>) where
D: na::Dim,
T: na::RealField,
na::DefaultAllocator: na::allocator::Allocator<T, na::Const<3>, D>
+ na::allocator::Allocator<T, D,na::Const<3>>,
{
let now = std::time::Instant::now();
let _C = (solv.solver)(prob.A,prob.B);
let elapsed = now.elapsed();
println!("Running {:10} on {:15} took {:?} seconds",solv.name,prob.name,elapsed)
}
}
mod multipliers{
//! A module containing matrix multiplication methods
use nalgebra as na;
use nalgebra; // 0.32.5
/// The built in matrix multiply
#[allow(non_snake_case)]
pub fn na_star<T,D>(
A: na::OMatrix<T,na::Const<3>,D>,
B: na::OMatrix<T,D,na::Const<3>>
) -> na::OMatrix<T,na::Const<3>,na::Const<3>>
where
T: na::RealField,
D: na::Dim,
na::DefaultAllocator: na::allocator::Allocator<T, na::Const<3>, D>
+ na::allocator::Allocator<T, D,na::Const<3>> {
A * B
}
/// An alternative implementation
#[allow(non_snake_case)]
pub fn manual<T,D>(
A: na::OMatrix<T,na::Const<3>,D>,
B: na::OMatrix<T,D,na::Const<3>>
) -> na::OMatrix<T,na::Const<3>,na::Const<3>>
where
T: na::RealField + std::iter::Sum + Copy,
D: na::Dim,
na::DefaultAllocator: na::allocator::Allocator<T, na::Const<3>, D>
+ na::allocator::Allocator<T, D,na::Const<3>> {
let mut out = vec![];
for r in 0..3 {
for c in 0..3 {
out.push(
A.row(r).iter().zip(B.column(c).iter()).map(|(&a,&b)| a*b).sum()
)
}
}
na::Matrix3::from_vec(out)
}
}
use bench_utils::*;
pub fn main() {
let problems = vec![
Problem::P_1_64(GenProb{name:"1 64 static".into(), A: na::Matrix3x1::new(0.0, 1.0, 2.0),B: na::Matrix1x3::new(2.0, 1.0, 0.0)}),
Problem::P_5_64(GenProb{name:"5 64 static".into(), A: na::Matrix3x5::new(1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0),B: na::Matrix5x3::new(1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0)}),
Problem::P_1_32(GenProb{name:"1 32 static".into(), A: na::Matrix3x1::new(0.0, 1.0, 2.0),B: na::Matrix1x3::new(2.0, 1.0, 0.0)}),
Problem::P_5_32(GenProb{name:"5 32 static".into(), A: na::Matrix3x5::new(1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0),B: na::Matrix5x3::new(1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0)}),
Problem::P_dyn_64(GenProb{name:"1 64 dyn".into(), A: na::Matrix::<f64,na::Const<3>,na::Dyn,_>::from_vec(vec![0.0,1.0,2.0]), B: na::Matrix::<f64,na::Dyn,na::Const<3>,_>::from_vec(vec![0.0,1.0,2.0]) }),
Problem::P_dyn_64(GenProb{name:"5 64 dyn".into(), A: na::Matrix::<f64,na::Const<3>,na::Dyn,_>::from_vec(vec![1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0]), B: na::Matrix::<f64,na::Dyn,na::Const<3>,_>::from_vec(vec![1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0]) }),
Problem::P_dyn_32(GenProb{name:"1 32 dyn".into(), A: na::Matrix::<_,na::Const<3>,na::Dyn,_>::from_vec(vec![0.0,1.0,2.0]), B: na::Matrix::<_,na::Dyn,na::Const<3>,_>::from_vec(vec![0.0,1.0,2.0]) }),
Problem::P_dyn_32(GenProb{name:"5 32 dyn".into(), A: na::Matrix::<_,na::Const<3>,na::Dyn,_>::from_vec(vec![1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0]), B: na::Matrix::<_,na::Dyn,na::Const<3>,_>::from_vec(vec![1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0,1.0,2.0,3.0,4.0,5.0]) }),
];
let solvers: std::collections::HashMap<_, Vec<_>> = std::collections::HashMap::from([
(std::any::TypeId::of::<GenProb<f64,na::Const<1>>>(),vec![
Solver::S_1_64(GenSolv{name:"nalgebra*".into(),solver:multipliers::na_star}),
Solver::S_1_64(GenSolv{name:"manual".into(),solver:multipliers::manual}),
]),
(std::any::TypeId::of::<GenProb<f64,na::Const<5>>>(),vec![
Solver::S_5_64(GenSolv{name:"nalgebra*".into(),solver:multipliers::na_star}),
Solver::S_5_64(GenSolv{name:"manual".into(),solver:multipliers::manual}),
]),
(std::any::TypeId::of::<GenProb<f64,na::Dyn>>(),vec![
Solver::S_dyn_64(GenSolv{name:"nalgebra*".into(),solver:multipliers::na_star}),
Solver::S_dyn_64(GenSolv{name:"manual".into(),solver:multipliers::manual}),
]),
(std::any::TypeId::of::<GenProb<f32,na::Const<1>>>(),vec![
Solver::S_1_32(GenSolv{name:"nalgebra*".into(),solver:multipliers::na_star}),
Solver::S_1_32(GenSolv{name:"manual".into(),solver:multipliers::manual}),
]),
(std::any::TypeId::of::<GenProb<f32,na::Const<5>>>(),vec![
Solver::S_5_32(GenSolv{name:"nalgebra*".into(),solver:multipliers::na_star}),
Solver::S_5_32(GenSolv{name:"manual".into(),solver:multipliers::manual}),
]),
(std::any::TypeId::of::<GenProb<f32,na::Dyn>>(),vec![
Solver::S_dyn_32(GenSolv{name:"nalgebra*".into(),solver:multipliers::na_star}),
Solver::S_dyn_32(GenSolv{name:"manual".into(),solver:multipliers::manual}),
]),
]);
for prob in &problems{
let id = prob.type_id();
for solv in &solvers[&id] {
match (prob,solv) {
( Problem::P_1_64(p), Solver::S_1_64(s))=>process(p.clone(),s.clone()),
( Problem::P_5_64(p), Solver::S_5_64(s))=>process(p.clone(),s.clone()),
( Problem::P_dyn_64(p), Solver::S_dyn_64(s))=>process(p.clone(),s.clone()),
( Problem::P_1_32(p), Solver::S_1_32(s))=>process(p.clone(),s.clone()),
( Problem::P_5_32(p), Solver::S_5_32(s))=>process(p.clone(),s.clone()),
( Problem::P_dyn_32(p), Solver::S_dyn_32(s))=>process(p.clone(),s.clone()),
_ => eprintln!("This problem/solver pair is not compatible")
}
}
}
}