use std::sync::Mutex; use crate::body::Body; use rayon::prelude::*; const G: f64 = 6.67430e-11; pub struct Simulator { pub bodies: Vec, pub time: f64, pub timestep: f64, } pub fn distance_squared(a: [f64; 2], b: [f64; 2]) -> f64 { let dx = a[0] - b[0]; let dy = a[1] - b[1]; dx * dx + dy * dy } impl Simulator { pub fn new(timestep: f64) -> Self { Self { bodies: Vec::new(), time: 0.0, timestep, } } pub fn add_body(&mut self, body: Body) { self.bodies.push(body); } pub fn step(&mut self) { let dt = self.timestep; let n = self.bodies.len(); #[derive(Clone)] struct State { position: [f64; 2], velocity: [f64; 2], } let original_states: Vec = self .bodies .iter() .map(|b| State { position: b.position, velocity: b.velocity, }) .collect(); let masses: Vec = self.bodies.iter().map(|b| b.mass).collect(); fn compute_accelerations(states: &[State], masses: &[f64]) -> Vec<[f64; 2]> { let n = states.len(); let accels = (0..n).map(|_| Mutex::new([0.0, 0.0])).collect::>(); (0..n).into_par_iter().for_each(|i| { for j in (i + 1)..n { let dx = states[j].position[0] - states[i].position[0]; let dy = states[j].position[1] - states[i].position[1]; let dist_sq = dx * dx + dy * dy; let dist = dist_sq.sqrt(); if dist < 1e-3 { continue; } let force = G * masses[i] * masses[j] / dist_sq; let ax = force * dx / dist; let ay = force * dy / dist; { let mut a_i_lock = accels[i].lock().unwrap(); a_i_lock[0] += ax / masses[i]; a_i_lock[1] += ay / masses[i]; } { let mut a_j_lock = accels[j].lock().unwrap(); a_j_lock[0] -= ax / masses[j]; a_j_lock[1] -= ay / masses[j]; } } }); accels .into_iter() .map(|mutex| mutex.into_inner().unwrap()) .collect() } let k1_pos = original_states.iter().map(|s| s.velocity).collect::>(); let k1_vel = compute_accelerations(&original_states, &masses); let mut temp_states = original_states .iter() .enumerate() .map(|(i, s)| State { position: [ s.position[0] + k1_pos[i][0] * dt / 2.0, s.position[1] + k1_pos[i][1] * dt / 2.0, ], velocity: [ s.velocity[0] + k1_vel[i][0] * dt / 2.0, s.velocity[1] + k1_vel[i][1] * dt / 2.0, ], }) .collect::>(); let k2_pos = temp_states.iter().map(|s| s.velocity).collect::>(); let k2_vel = compute_accelerations(&temp_states, &masses); for i in 0..n { temp_states[i].position[0] = original_states[i].position[0] + k2_pos[i][0] * dt / 2.0; temp_states[i].position[1] = original_states[i].position[1] + k2_pos[i][1] * dt / 2.0; temp_states[i].velocity[0] = original_states[i].velocity[0] + k2_vel[i][0] * dt / 2.0; temp_states[i].velocity[1] = original_states[i].velocity[1] + k2_vel[i][1] * dt / 2.0; } let k3_pos = temp_states.iter().map(|s| s.velocity).collect::>(); let k3_vel = compute_accelerations(&temp_states, &masses); for i in 0..n { temp_states[i].position[0] = original_states[i].position[0] + k3_pos[i][0] * dt; temp_states[i].position[1] = original_states[i].position[1] + k3_pos[i][1] * dt; temp_states[i].velocity[0] = original_states[i].velocity[0] + k3_vel[i][0] * dt; temp_states[i].velocity[1] = original_states[i].velocity[1] + k3_vel[i][1] * dt; } let k4_pos = temp_states.iter().map(|s| s.velocity).collect::>(); let k4_vel = compute_accelerations(&temp_states, &masses); for i in 0..n { let body = &mut self.bodies[i]; body.position[0] += (dt / 6.0) * (k1_pos[i][0] + 2.0 * k2_pos[i][0] + 2.0 * k3_pos[i][0] + k4_pos[i][0]); body.position[1] += (dt / 6.0) * (k1_pos[i][1] + 2.0 * k2_pos[i][1] + 2.0 * k3_pos[i][1] + k4_pos[i][1]); body.velocity[0] += (dt / 6.0) * (k1_vel[i][0] + 2.0 * k2_vel[i][0] + 2.0 * k3_vel[i][0] + k4_vel[i][0]); body.velocity[1] += (dt / 6.0) * (k1_vel[i][1] + 2.0 * k2_vel[i][1] + 2.0 * k3_vel[i][1] + k4_vel[i][1]); } self.time += dt; } }