14. Project: Implementing GARs in Rust
Put the Rust concepts (traits, generics, concurrency, tests) into practice by implementing a robust gradient aggregation library, which is the core of the thesis. Goal: progressively replace the Python prototype with performant Rust, exposed to Python through PyO3 bindings.
Project Structure
gradient-core/
├── Cargo.toml
├── src/
│   ├── lib.rs              # public API
│   ├── gar.rs              # Gar trait
│   ├── gar/
│   │   ├── median.rs
│   │   ├── krum.rs
│   │   ├── trimmed_mean.rs
│   │   └── bulyan.rs
│   ├── types.rs            # Gradient, GarError, etc.
│   ├── attack.rs           # Byzantine attacks
│   └── utils/
│       ├── stats.rs        # statistical helpers
│       └── linalg.rs       # distance, norm, flatten
├── tests/
│   ├── test_gar.rs
│   └── test_attack.rs
├── benches/
│   └── gar_bench.rs
└── examples/
    └── simple_sgd.rs
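The tree pairs `src/gar.rs` with a `src/gar/` directory, which is the Rust 2018 way to give a module both its own file and submodules. The sketch below flattens that layout into inline modules so it compiles as a single file; the `pub(crate)` field and the `tolerated_byzantines` helper are hypothetical and only illustrate crate-internal visibility.

```rust
// Inline stand-in for src/gar.rs plus the src/gar/ directory.
// In the real crate each inner `mod` body would live in its own file.
pub mod gar {
    // Would be src/gar/median.rs
    pub mod median {
        /// Placeholder type; the aggregator itself is written in Milestone 2.
        pub struct Median;
    }

    // Would be src/gar/krum.rs
    pub mod krum {
        pub struct Krum {
            // pub(crate): readable anywhere in this crate, hidden from users.
            pub(crate) f: usize,
        }

        impl Krum {
            pub fn new(f: usize) -> Self {
                Self { f }
            }
        }
    }

    // Re-exports, so callers write `gar::Median` instead of `gar::median::Median`.
    pub use self::krum::Krum;
    pub use self::median::Median;
}

/// Hypothetical helper at the crate root, showing that `f` is reachable
/// inside the crate even though it is not part of the public API.
pub fn tolerated_byzantines(k: &gar::Krum) -> usize {
    k.f
}
```

With these re-exports, `pub use gar::{Median, Krum}` in `lib.rs` resolves without exposing the file layout to users.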
Milestones
Milestone 1: Types and the Gar trait
Rust concepts: structs, enums, traits, generics, Result, From/Into
use thiserror::Error;

/// Errors shared by every aggregation rule.
#[derive(Error, Debug)]
pub enum GarError {
    #[error("empty gradients")]
    EmptyInput,
    #[error("dimension mismatch: expected {expected}, got {got}")]
    DimensionMismatch { expected: usize, got: usize },
    #[error("too many Byzantine workers: f={f}, n={n} (Krum requires n >= 2f+3)")]
    TooManyByzantines { f: usize, n: usize },
    #[error("invalid parameter: {0}")]
    InvalidParameter(String),
}

/// A gradient aggregation rule. Send + Sync so a rule can be shared across threads.
pub trait Gar: Send + Sync {
    /// Aggregates n gradients of equal dimension d into one gradient.
    fn aggregate(&self, grads: &[Vec<f64>]) -> Result<Vec<f64>, GarError>;
    fn name(&self) -> &'static str;
    /// Fraction of Byzantine inputs the rule tolerates.
    fn breakdown_point(&self) -> f64;
    fn complexity(&self) -> &'static str;
}

Links: 5-Traits et Generics, 3-Error Handling, 2-Types Structs Enums
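As a rough illustration of how the trait and the error enum fit together, here is a std-only sketch. It hand-writes `Display` instead of deriving it with thiserror, keeps only two error variants and two trait methods, and adds two hypothetical pieces that are not part of the project plan: a `check_shape` helper and a non-robust `Mean` baseline.

```rust
use std::fmt;

// Std-only stand-in for the thiserror-derived enum (two variants only).
#[derive(Debug, PartialEq)]
pub enum GarError {
    EmptyInput,
    DimensionMismatch { expected: usize, got: usize },
}

impl fmt::Display for GarError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            GarError::EmptyInput => write!(f, "empty gradients"),
            GarError::DimensionMismatch { expected, got } => {
                write!(f, "dimension mismatch: expected {expected}, got {got}")
            }
        }
    }
}

impl std::error::Error for GarError {}

// Reduced copy of the trait (aggregate and name only).
pub trait Gar: Send + Sync {
    fn aggregate(&self, grads: &[Vec<f64>]) -> Result<Vec<f64>, GarError>;
    fn name(&self) -> &'static str;
}

/// Hypothetical helper: returns (n, d), or the first shape error found.
pub fn check_shape(grads: &[Vec<f64>]) -> Result<(usize, usize), GarError> {
    let first = grads.first().ok_or(GarError::EmptyInput)?;
    let d = first.len();
    for g in grads {
        if g.len() != d {
            return Err(GarError::DimensionMismatch { expected: d, got: g.len() });
        }
    }
    Ok((grads.len(), d))
}

/// Hypothetical baseline: plain coordinate-wise mean (not Byzantine-robust).
pub struct Mean;

impl Gar for Mean {
    fn aggregate(&self, grads: &[Vec<f64>]) -> Result<Vec<f64>, GarError> {
        let (n, d) = check_shape(grads)?;
        let mut out = vec![0.0; d];
        for g in grads {
            for (o, x) in out.iter_mut().zip(g) {
                *o += x / n as f64;
            }
        }
        Ok(out)
    }

    fn name(&self) -> &'static str {
        "Mean"
    }
}
```

Because the trait has no generic methods, rules can be stored as `Vec<Box<dyn Gar>>` and swapped at runtime, which is convenient when comparing several GARs in one experiment.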
Milestone 2: Coordinate-wise median
Rust concepts: iterators, slices, sort_unstable_by, partial_cmp
pub struct Median;

impl Gar for Median {
    fn aggregate(&self, grads: &[Vec<f64>]) -> Result<Vec<f64>, GarError> {
        if grads.is_empty() {
            return Err(GarError::EmptyInput);
        }
        let d = grads[0].len();
        // Reject ragged input instead of panicking on g[j] below.
        if let Some(g) = grads.iter().find(|g| g.len() != d) {
            return Err(GarError::DimensionMismatch { expected: d, got: g.len() });
        }
        let mut result = Vec::with_capacity(d);
        for j in 0..d {
            // Gather coordinate j across all gradients.
            let mut col: Vec<f64> = grads.iter().map(|g| g[j]).collect();
            // partial_cmp returns None for NaN, so this unwrap panics on NaN input.
            col.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
            let mid = col.len() / 2;
            result.push(if col.len() % 2 == 0 {
                (col[mid - 1] + col[mid]) / 2.0
            } else {
                col[mid]
            });
        }
        Ok(result)
    }

    fn name(&self) -> &'static str { "Median" }
    fn breakdown_point(&self) -> f64 { 0.5 }
    fn complexity(&self) -> &'static str { "O(n log n · d)" }
}

Variant: a version based on ndarray, for more efficient matrix computation, is planned for later.
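Independently of ndarray, the full sort per coordinate can be replaced by a selection, since only the middle element or elements matter. The sketch below uses the standard library's `select_nth_unstable_by`, which runs in O(n) on average; `median_select` is a hypothetical helper name, and the function assumes a non-empty, NaN-free column.

```rust
/// Median of one column via selection instead of a full sort.
/// Assumes `col` is non-empty (select_nth_unstable_by panics otherwise).
/// total_cmp gives NaN a fixed position instead of panicking, but the
/// result is only meaningful for NaN-free input.
pub fn median_select(col: &mut [f64]) -> f64 {
    let n = col.len();
    let mid = n / 2;
    // After the call, `upper_mid` is the value a full sort would place at
    // index `mid`, and every element of `lower` is <= that value.
    let (lower, &mut upper_mid, _) = col.select_nth_unstable_by(mid, f64::total_cmp);
    if n % 2 == 1 {
        upper_mid
    } else {
        // Even length: the other middle value is the maximum of the lower part.
        let lower_mid = lower.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        (lower_mid + upper_mid) / 2.0
    }
}
```

Inside `aggregate`, this would replace the `sort_unstable_by` call and the index arithmetic for each column, bringing the per-coordinate cost from O(n log n) down to O(n) on average.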
Milestone 3: Trimmed Mean and Krum
Rust concepts: advanced iterators, filter_map, floats, O(n²d) algorithms
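pub struct TrimmedMean {
    trim_ratio: f64,
}

impl TrimmedMean {
    pub fn new(trim_ratio: f64) -> Result<Self, GarError> {
        if !(0.0..0.5).contains(&trim_ratio) {
            return Err(GarError::InvalidParameter(
                "trim_ratio must be in [0, 0.5)".into()
            ));
        }
        Ok(Self { trim_ratio })
    }
}

impl Gar for TrimmedMean {
    fn aggregate(&self, grads: &[Vec<f64>]) -> Result<Vec<f64>, GarError> {
        if grads.is_empty() { return Err(GarError::EmptyInput); }
        let n = grads.len();
        let d = grads[0].len();
        if let Some(g) = grads.iter().find(|g| g.len() != d) {
            return Err(GarError::DimensionMismatch { expected: d, got: g.len() });
        }
        // Number of values dropped at each end of every sorted column.
        let k = (n as f64 * self.trim_ratio).ceil() as usize;
        // ceil can push 2k up to n or beyond (n = 3, trim_ratio = 0.4 gives k = 2),
        // which would leave nothing to average and underflow n - 2k.
        if 2 * k >= n {
            return Err(GarError::InvalidParameter(
                format!("trim_ratio {} trims all {} gradients", self.trim_ratio, n)
            ));
        }
        let mut result = Vec::with_capacity(d);
        for j in 0..d {
            let mut col: Vec<f64> = grads.iter().map(|g| g[j]).collect();
            // Panics on NaN input (partial_cmp returns None).
            col.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
            let trimmed: f64 = col[k..n - k].iter().sum();
            result.push(trimmed / (n - 2 * k) as f64);
        }
        Ok(result)
    }

    fn name(&self) -> &'static str { "TrimmedMean" }
    // Breakdown point = trim_ratio.
    fn breakdown_point(&self) -> f64 { self.trim_ratio }
    fn complexity(&self) -> &'static str { "O(n log n · d)" }
}

pub struct Krum {
    f: usize,
}

impl Krum {
    /// `f` is the number of Byzantine workers to tolerate.
    pub fn new(f: usize) -> Self {
        Self { f }
    }
}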
impl Gar for Krum {
    fn aggregate(&self, grads: &[Vec<f64>]) -> Result<Vec<f64>, GarError> {
        // Check emptiness before indexing grads[0].
        if grads.is_empty() { return Err(GarError::EmptyInput); }
        let n = grads.len();
        let d = grads[0].len();
        if let Some(g) = grads.iter().find(|g| g.len() != d) {
            return Err(GarError::DimensionMismatch { expected: d, got: g.len() });
        }
        if n < 2 * self.f + 3 {
            return Err(GarError::TooManyByzantines { f: self.f, n });
        }
        // scores[i] = sum of the squared distances from i to its n-f-2 nearest neighbours
        let mut scores = vec![0.0; n];
        for i in 0..n {
            let mut dists: Vec<f64> = (0..n)
                .filter(|&j| j != i)
                .map(|j| {
                    grads[i].iter()
                        .zip(grads[j].iter())
                        .map(|(a, b)| (a - b).powi(2))
                        .sum::<f64>()
                })
                .collect();
            dists.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
            scores[i] = dists.iter().take(n - self.f - 2).sum::<f64>();
        }
        // Krum returns the single gradient with the lowest score.
        let best = scores.iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(i, _)| i)
            .ok_or(GarError::EmptyInput)?;
        Ok(grads[best].clone())
    }

    fn name(&self) -> &'static str { "Krum" }
    // Asymptotic value: n >= 2f + 3 gives f <= (n - 3) / 2, so f/n stays below 1/2.
    fn breakdown_point(&self) -> f64 { 0.5 }
    fn complexity(&self) -> &'static str { "O(n²d)" }
}

Milestone 4: Modular organization and tests
Rust concepts: modules, pub(crate) visibility, unit tests, integration tests, doc-tests
// src/lib.rs
pub mod gar;
pub mod attack;
mod types;
mod utils;
pub use gar::{Median, Krum, TrimmedMean, Bulyan};
pub use types::GarError;

// tests/test_gar.rs
use gradient_core::gar::{Median, Krum, Gar};

#[test]
fn test_median_known_values() {
    let median = Median;
    let grads = vec![
        vec![1.0, 10.0],
        vec![2.0, 20.0],
        vec![3.0, 30.0],
    ];
    let result = median.aggregate(&grads).unwrap();
    assert_eq!(result, vec![2.0, 20.0]);
}

#[test]
fn test_krum_rejects_byzantine() {
    // f = 1 requires n >= 2f + 3 = 5 gradients: 4 honest + 1 Byzantine.
    let krum = Krum::new(1);
    let honest = vec![
        vec![1.0, 2.0],
        vec![1.1, 2.1],
        vec![0.9, 1.9],
        vec![1.05, 2.05],
    ];
    let byzantine = vec![vec![1000.0, -1000.0]];
    let all: Vec<_> = honest.iter().cloned().chain(byzantine).collect();
    let result = krum.aggregate(&all).unwrap();
    assert!((result[0] - 1.0).abs() < 0.2,
        "Krum selected a Byzantine gradient: {:?}", result);
}

Milestone 5: Parallelism and benchmarks
Rust concepts: rayon, Arc, criterion benchmarks, comparison with Python
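Before reaching for rayon, the underlying idea can be sketched with the standard library alone: coordinates are independent, so the d columns can be split into contiguous ranges and each range handed to a scoped thread. The names `median_range` and `median_scoped` and the `workers` parameter are hypothetical. The sketch assumes non-empty input, equal-length gradients and `workers >= 1`.

```rust
use std::thread;

/// Coordinate-wise median of columns lo..hi (sort-based).
fn median_range(grads: &[Vec<f64>], lo: usize, hi: usize) -> Vec<f64> {
    (lo..hi)
        .map(|j| {
            let mut col: Vec<f64> = grads.iter().map(|g| g[j]).collect();
            col.sort_unstable_by(f64::total_cmp);
            let mid = col.len() / 2;
            if col.len() % 2 == 0 {
                (col[mid - 1] + col[mid]) / 2.0
            } else {
                col[mid]
            }
        })
        .collect()
}

/// Splits the d columns into at most `workers` contiguous chunks,
/// one scoped thread per chunk.
/// Assumes: grads is non-empty, all gradients have equal length, workers >= 1.
pub fn median_scoped(grads: &[Vec<f64>], workers: usize) -> Vec<f64> {
    let d = grads[0].len();
    // Ceiling division; at least 1 so that step_by never receives 0.
    let chunk = ((d + workers - 1) / workers).max(1);
    thread::scope(|s| {
        // Scoped threads may borrow `grads` directly: no Arc, no cloning.
        let handles: Vec<_> = (0..d)
            .step_by(chunk)
            .map(|lo| {
                let hi = (lo + chunk).min(d);
                s.spawn(move || median_range(grads, lo, hi))
            })
            .collect();
        // Join in spawn order so the output keeps the column order.
        handles
            .into_iter()
            .flat_map(|h| h.join().expect("worker panicked"))
            .collect()
    })
}
```

rayon's work-stealing scheduler balances load better than fixed chunks and removes the manual bookkeeping, which is the main reason to prefer it once the dependency is acceptable.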
// Parallel version with rayon: one task per coordinate
use rayon::prelude::*;

impl Median {
    /// Assumes every gradient has the same length as grads[0].
    pub fn aggregate_parallel(&self, grads: &[Vec<f64>]) -> Result<Vec<f64>, GarError> {
        if grads.is_empty() { return Err(GarError::EmptyInput); }
        let d = grads[0].len();
        let result: Vec<f64> = (0..d).into_par_iter().map(|j| {
            let mut col: Vec<f64> = grads.iter().map(|g| g[j]).collect();
            col.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
            let mid = col.len() / 2;
            if col.len() % 2 == 0 {
                (col[mid - 1] + col[mid]) / 2.0
            } else {
                col[mid]
            }
        }).collect();
        Ok(result)
    }
}

// benches/gar_bench.rs
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use gradient_core::gar::{Median, Gar};

fn bench_median(c: &mut Criterion) {
    // 50 random gradients of dimension 10,000.
    let grads: Vec<Vec<f64>> = (0..50)
        .map(|_| (0..10_000).map(|_| rand::random()).collect())
        .collect();
    let median = Median;
    c.bench_function("median 50×10k", |b| {
        b.iter(|| median.aggregate(black_box(&grads)))
    });
}

criterion_group!(benches, bench_median);
criterion_main!(benches);

Bonus: PyO3 bindings
// gradient-pyo3/src/lib.rs
use pyo3::prelude::*;
use gradient_core::gar::{Median, Gar};

#[pyclass]
struct PyMedian {
    inner: Median,
}

#[pymethods]
impl PyMedian {
    #[new]
    fn new() -> Self { Self { inner: Median } }

    fn aggregate(&self, grads: Vec<Vec<f64>>) -> PyResult<Vec<f64>> {
        // Surface GarError to Python as a ValueError.
        self.inner.aggregate(&grads)
            .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
    }

    fn __repr__(&self) -> String {
        format!("Median(breakdown={})", self.inner.breakdown_point())
    }
}

// Module signature for PyO3 releases before 0.21; later releases take
// `m: &Bound<'_, PyModule>` instead.
#[pymodule]
fn gradient_gar(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<PyMedian>()?;
    Ok(())
}

Thesis Concepts ↔ Rust Mapping
| Thesis concept | Rust | Related note |
|---|---|---|
| GAR (breakdown point) | trait Gar { fn breakdown_point(&self) -> f64 } | 5-Traits et Generics |
| Gradient as a vector | Vec<f64>, ndarray::Array1<f64> | 9-String et Collections |
| Distance between gradients | (a - b).powi(2), summed via iterators | 4-Iterateurs et Closures |
| Parallel aggregation | rayon: par_iter(), into_par_iter() | 6-Smart Pointers et Concurrence |
| Worker communication | tokio::net::TcpStream | 11-Async et Tokio |
| Config/run serialization | serde_json, toml | 10-IO et Serde |
| Python bindings | PyO3 / maturin | 12-Unsafe et FFI |
| GPU acceleration | cudarc, CUDA kernels | 12-Unsafe et FFI |
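The "Distance between gradients" row can be made concrete with a small std-only sketch. It also shows one way to compute each pairwise distance once by exploiting symmetry, rather than twice as in a direct double loop. `sq_dist` and `krum_scores` are hypothetical names, possibly candidates for `utils/linalg.rs`, and equal-length gradients are assumed.

```rust
/// Squared Euclidean distance between two gradients.
/// Assumes equal lengths (zip silently stops at the shorter slice otherwise).
pub fn sq_dist(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum()
}

/// Krum-style scores: for each gradient, the sum of its n - f - 2 smallest
/// squared distances to the other gradients.
pub fn krum_scores(grads: &[Vec<f64>], f: usize) -> Vec<f64> {
    let n = grads.len();
    // Symmetric distance matrix: each pair (i, j) is computed once.
    let mut dist = vec![vec![0.0; n]; n];
    for i in 0..n {
        for j in (i + 1)..n {
            let d = sq_dist(&grads[i], &grads[j]);
            dist[i][j] = d;
            dist[j][i] = d;
        }
    }
    // saturating_sub avoids underflow when n < f + 2 (all scores are then 0).
    let keep = n.saturating_sub(f + 2);
    dist.iter()
        .enumerate()
        .map(|(i, row)| {
            // Distances from i to everyone else, smallest first.
            let mut others: Vec<f64> = row
                .iter()
                .enumerate()
                .filter(|&(j, _)| j != i)
                .map(|(_, &d)| d)
                .collect();
            others.sort_unstable_by(f64::total_cmp);
            others.iter().take(keep).sum::<f64>()
        })
        .collect()
}
```

The matrix costs O(n²) memory, which is negligible next to the n·d gradients themselves for typical worker counts, and it halves the number of distance evaluations.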
Resources
- Définitions: GAR
- Concepts avancés: theoretical understanding of GARs
- Librairie MLOSS: context of the thesis project
- Rust Book, Ch. 10: Generic Types, Traits, and Lifetimes
- Rust Book, Ch. 12: An I/O Project
🔗 See also
- Rust: index of the Rust notes
- 5-Traits et Generics: definition of the Gar trait
- 8-Tests et Documentation: benchmarks and tests
- Définitions: Byzantine glossary
- Library: specs for the thesis library
- Concepts avancés