14. Project: Implementing GARs in Rust

Put the Rust concepts (traits, generics, concurrency, tests) into practice by implementing a robust gradient aggregation library, which is the core of the thesis. Goal: gradually replace the Python prototype with performant Rust, exposed to Python through PyO3 bindings.


Project Structure

gradient-core/
├── Cargo.toml
├── src/
│   ├── lib.rs          # public API
│   ├── gar.rs          # Gar trait
│   ├── gar/
│   │   ├── median.rs
│   │   ├── krum.rs
│   │   ├── trimmed_mean.rs
│   │   └── bulyan.rs
│   ├── types.rs        # Gradient, GarError, etc.
│   ├── attack.rs       # Byzantine attacks
│   ├── utils.rs        # declares the stats and linalg submodules
│   └── utils/
│       ├── stats.rs    # statistical helpers
│       └── linalg.rs   # distance, norm, flatten
├── tests/
│   ├── test_gar.rs
│   └── test_attack.rs
├── benches/
│   └── gar_bench.rs
└── examples/
    └── simple_sgd.rs

Milestones

Milestone 1: Types and the Gar trait

Rust concepts: structs, enums, traits, generics, Result, From/Into

use thiserror::Error;
 
#[derive(Error, Debug)]
pub enum GarError {
    #[error("empty gradients")]
    EmptyInput,
    #[error("dimension mismatch: expected {expected}, got {got}")]
    DimensionMismatch { expected: usize, got: usize },
    #[error("too many Byzantine workers: f={f}, n={n} (Krum requires n >= 2f+3)")]
    TooManyByzantines { f: usize, n: usize },
    #[error("invalid parameter: {0}")]
    InvalidParameter(String),
}
 
pub trait Gar: Send + Sync {
    fn aggregate(&self, grads: &[Vec<f64>]) -> Result<Vec<f64>, GarError>;
    fn name(&self) -> &'static str;
    fn breakdown_point(&self) -> f64;
    fn complexity(&self) -> &'static str;
}

Links: 5-Traits et Generics, 3-Error Handling, 2-Types Structs Enums
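To make the contract of the trait concrete, here is a minimal sketch of a non-robust baseline, a plain coordinate-wise mean. The `Mean` struct and the `check_dims` helper are hypothetical names that do not appear in the project layout above. `GarError` and `Gar` are redeclared in reduced form (hand-written `Display`, no `complexity`) only so that the snippet compiles without `thiserror`.

```rust
use std::fmt;

// Reduced copy of GarError so the sketch compiles on its own (no thiserror).
#[derive(Debug, PartialEq)]
pub enum GarError {
    EmptyInput,
    DimensionMismatch { expected: usize, got: usize },
}

impl fmt::Display for GarError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            GarError::EmptyInput => write!(f, "empty gradients"),
            GarError::DimensionMismatch { expected, got } => {
                write!(f, "dimension mismatch: expected {}, got {}", expected, got)
            }
        }
    }
}

impl std::error::Error for GarError {}

// Same trait as in the note, minus `complexity`, to keep the sketch short.
pub trait Gar: Send + Sync {
    fn aggregate(&self, grads: &[Vec<f64>]) -> Result<Vec<f64>, GarError>;
    fn name(&self) -> &'static str;
    fn breakdown_point(&self) -> f64;
}

/// Hypothetical helper: returns the common dimension, or an error if the
/// input is empty or the gradients do not all have the same length.
pub fn check_dims(grads: &[Vec<f64>]) -> Result<usize, GarError> {
    let d = grads.first().ok_or(GarError::EmptyInput)?.len();
    for g in grads {
        if g.len() != d {
            return Err(GarError::DimensionMismatch { expected: d, got: g.len() });
        }
    }
    Ok(d)
}

/// Hypothetical non-robust baseline: plain average, breakdown point 0.
pub struct Mean;

impl Gar for Mean {
    fn aggregate(&self, grads: &[Vec<f64>]) -> Result<Vec<f64>, GarError> {
        let d = check_dims(grads)?;
        let n = grads.len() as f64;
        let mut acc = vec![0.0; d];
        for g in grads {
            for (a, x) in acc.iter_mut().zip(g) {
                *a += x;
            }
        }
        Ok(acc.into_iter().map(|a| a / n).collect())
    }

    fn name(&self) -> &'static str { "Mean" }
    fn breakdown_point(&self) -> f64 { 0.0 }
}
```

Because the trait only takes `&self` and is `Send + Sync`, rules can be stored as `Box<dyn Gar>` and swapped at run time, which is convenient when comparing a baseline like this one against the robust rules of the later milestones.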

Milestone 2: Coordinate-wise median

Rust concepts: iterators, slices, sort_unstable_by, partial_cmp / total_cmp

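pub struct Median;

impl Gar for Median {
    fn aggregate(&self, grads: &[Vec<f64>]) -> Result<Vec<f64>, GarError> {
        // The first gradient fixes the expected dimension.
        let d = grads.first().ok_or(GarError::EmptyInput)?.len();
        if let Some(g) = grads.iter().find(|g| g.len() != d) {
            return Err(GarError::DimensionMismatch { expected: d, got: g.len() });
        }
        let mut result = Vec::with_capacity(d);
        for j in 0..d {
            // Gather coordinate j across all workers, then sort it.
            let mut col: Vec<f64> = grads.iter().map(|g| g[j]).collect();
            // total_cmp is a total order on f64, so a NaN sent by a Byzantine
            // worker cannot cause a panic (partial_cmp(..).unwrap() would).
            col.sort_unstable_by(f64::total_cmp);
            let mid = col.len() / 2;
            result.push(if col.len() % 2 == 0 {
                (col[mid - 1] + col[mid]) / 2.0
            } else {
                col[mid]
            });
        }
        Ok(result)
    }

    fn name(&self) -> &'static str { "Median" }
    fn breakdown_point(&self) -> f64 { 0.5 }
    fn complexity(&self) -> &'static str { "O(n log n · d)" }
}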

Variant: a later version could use ndarray for more efficient matrix computation.

Links: 4-Iterateurs et Closures, 9-String et Collections
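The O(n log n · d) cost above comes from fully sorting each column, although only the middle value or values are needed. A possible std-only variant, sketched below under the assumption that reordering a scratch buffer is acceptable, uses `select_nth_unstable_by`, which runs in linear time on average. The names `median_by_selection` and `coordinate_median` are hypothetical and not part of the API described in this note.

```rust
/// Median of one column via selection instead of a full sort.
/// Reorders `col` in place and returns None for an empty column.
pub fn median_by_selection(col: &mut [f64]) -> Option<f64> {
    let n = col.len();
    if n == 0 {
        return None;
    }
    let mid = n / 2;
    // After this call, the element at index `mid` is the one a full sort would
    // put there; `lower` holds values <= it, the rest holds values >= it.
    let (lower, upper_mid, _) = col.select_nth_unstable_by(mid, f64::total_cmp);
    let upper_mid = *upper_mid;
    if n % 2 == 1 {
        Some(upper_mid)
    } else {
        // Even length: the other middle value is the maximum of the lower part.
        let lower_mid = lower.iter().copied().max_by(f64::total_cmp)?;
        Some((lower_mid + upper_mid) / 2.0)
    }
}

/// Coordinate-wise median that reuses a single scratch buffer for all columns.
/// Returns None on empty input or mismatched dimensions.
pub fn coordinate_median(grads: &[Vec<f64>]) -> Option<Vec<f64>> {
    let d = grads.first()?.len();
    if grads.iter().any(|g| g.len() != d) {
        return None;
    }
    let mut col = vec![0.0; grads.len()];
    (0..d)
        .map(|j| {
            // Overwrite the scratch buffer with coordinate j of every gradient.
            for (c, g) in col.iter_mut().zip(grads) {
                *c = g[j];
            }
            median_by_selection(&mut col)
        })
        .collect()
}
```

Whether this beats the sort-based version for the small n typical of a worker pool is something the criterion benchmarks of Milestone 5 would have to confirm.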

Milestone 3: Trimmed Mean and Krum

Rust concepts: advanced iterators, filter_map, floating-point numbers, O(n²d) algorithms

pub struct TrimmedMean {
    trim_ratio: f64,
}

impl TrimmedMean {
    pub fn new(trim_ratio: f64) -> Result<Self, GarError> {
        // NaN fails `contains`, so it is rejected here too.
        if !(0.0..0.5).contains(&trim_ratio) {
            return Err(GarError::InvalidParameter(
                "trim_ratio must be in [0, 0.5)".into()
            ));
        }
        Ok(Self { trim_ratio })
    }
}

impl Gar for TrimmedMean {
    fn aggregate(&self, grads: &[Vec<f64>]) -> Result<Vec<f64>, GarError> {
        let n = grads.len();
        let d = grads.first().ok_or(GarError::EmptyInput)?.len();
        if let Some(g) = grads.iter().find(|g| g.len() != d) {
            return Err(GarError::DimensionMismatch { expected: d, got: g.len() });
        }
        // k values are dropped at each end of every sorted column.
        let k = (n as f64 * self.trim_ratio).ceil() as usize;
        // Because of ceil, 2k can reach n for small n (n = 3, ratio 0.4 gives k = 2);
        // without this guard the slice below would panic or the mean would be 0/0.
        if 2 * k >= n {
            return Err(GarError::InvalidParameter(format!(
                "trim_ratio {} leaves nothing to average for n = {}",
                self.trim_ratio, n
            )));
        }

        let mut result = Vec::with_capacity(d);
        for j in 0..d {
            let mut col: Vec<f64> = grads.iter().map(|g| g[j]).collect();
            col.sort_unstable_by(f64::total_cmp);
            let trimmed: f64 = col[k..n - k].iter().sum();
            result.push(trimmed / (n - 2 * k) as f64);
        }
        Ok(result)
    }

    fn name(&self) -> &'static str { "TrimmedMean" }
    // Breakdown point = trim_ratio.
    fn breakdown_point(&self) -> f64 { self.trim_ratio }
    fn complexity(&self) -> &'static str { "O(n log n · d)" }
}

pub struct Krum {
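    f: usize,
}

impl Krum {
    /// `f` is the assumed maximum number of Byzantine workers.
    pub fn new(f: usize) -> Self {
        Self { f }
    }
}

impl Gar for Krum {
    fn aggregate(&self, grads: &[Vec<f64>]) -> Result<Vec<f64>, GarError> {
        let n = grads.len();
        // Check emptiness before indexing: grads[0] would panic on empty input.
        let d = grads.first().ok_or(GarError::EmptyInput)?.len();
        if let Some(g) = grads.iter().find(|g| g.len() != d) {
            return Err(GarError::DimensionMismatch { expected: d, got: g.len() });
        }
        if n < 2 * self.f + 3 {
            return Err(GarError::TooManyByzantines { f: self.f, n });
        }

        // scores[i] = sum of squared distances from i to its n-f-2 nearest neighbours
        let mut scores = vec![0.0; n];
        for i in 0..n {
            let mut dists: Vec<f64> = (0..n)
                .filter(|&j| j != i)
                .map(|j| {
                    grads[i].iter()
                        .zip(grads[j].iter())
                        .map(|(a, b)| (a - b).powi(2))
                        .sum::<f64>()
                })
                .collect();
            // total_cmp never panics, even if a gradient contains NaN.
            dists.sort_unstable_by(f64::total_cmp);
            scores[i] = dists.iter().take(n - self.f - 2).sum::<f64>();
        }

        // Select the gradient with the smallest score.
        let best = scores.iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(i, _)| i)
            .ok_or(GarError::EmptyInput)?;

        Ok(grads[best].clone())
    }

    fn name(&self) -> &'static str { "Krum" }
    // Asymptotic value: n >= 2f + 3 means f/n approaches 1/2 as n grows.
    fn breakdown_point(&self) -> f64 { 0.5 }
    fn complexity(&self) -> &'static str { "O(n²d)" }
}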

Links: GAR — Gradient Aggregation Rule, Concepts avancés
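The Krum loop above computes every pairwise distance twice, once as d(i, j) and once as d(j, i). A possible refinement, sketched here with hypothetical helper names (`squared_distance` could live in `utils/linalg.rs`), fills a symmetric distance matrix once and derives the scores from it. The sketch assumes the caller has already validated the input (equal dimensions, n >= 2f + 3).

```rust
/// Squared Euclidean distance between two gradients of equal length.
pub fn squared_distance(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum()
}

/// Krum scores from a symmetric distance matrix: each pair is computed once.
/// Assumes equal dimensions and n >= 2f + 3 (checked by the caller).
pub fn krum_scores(grads: &[Vec<f64>], f: usize) -> Vec<f64> {
    let n = grads.len();
    let mut dist = vec![vec![0.0; n]; n];
    for i in 0..n {
        for j in (i + 1)..n {
            let dij = squared_distance(&grads[i], &grads[j]);
            dist[i][j] = dij;
            dist[j][i] = dij;
        }
    }
    // Number of nearest neighbours kept per gradient; saturating_sub avoids an
    // underflow panic if the precondition on n is violated.
    let k = n.saturating_sub(f + 2);
    (0..n)
        .map(|i| {
            let mut row: Vec<f64> = (0..n)
                .filter(|&j| j != i)
                .map(|j| dist[i][j])
                .collect();
            row.sort_unstable_by(f64::total_cmp);
            row.iter().take(k).sum::<f64>()
        })
        .collect()
}
```

With four points on the unit square and one outlier at (100, 100), and f = 1, each honest point scores 2 (its two nearest neighbours are at squared distance 1), while the outlier scores 19602 + 19801 = 39403, so Krum would never select it. The trade-off is O(n²) extra memory for the matrix.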

Milestone 4: Module organization and tests

Rust concepts: modules, pub(crate) visibility, unit tests, integration tests, doc-tests

// src/lib.rs
pub mod gar;
pub mod attack;
mod types;
mod utils;
 
pub use gar::{Median, Krum, TrimmedMean, Bulyan};
pub use types::GarError;

// tests/test_gar.rs
use gradient_core::gar::{Median, Krum, Gar};

#[test]
fn test_median_known_values() {
    let median = Median;
    let grads = vec![
        vec![1.0, 10.0],
        vec![2.0, 20.0],
        vec![3.0, 30.0],
    ];
    let result = median.aggregate(&grads).unwrap();
    assert_eq!(result, vec![2.0, 20.0]);
}

#[test]
fn test_krum_rejects_byzantine() {
    // Krum requires n >= 2f + 3, so f = 1 needs at least 5 gradients;
    // with only 4, aggregate returns TooManyByzantines and unwrap panics.
    let krum = Krum::new(1);
    let honest = vec![
        vec![1.0, 2.0],
        vec![1.1, 2.1],
        vec![0.9, 1.9],
        vec![1.0, 2.1],
    ];
    let byzantine = vec![vec![1000.0, -1000.0]];
    let all: Vec<_> = honest.iter().cloned().chain(byzantine).collect();

    let result = krum.aggregate(&all).unwrap();
    assert!((result[0] - 1.0).abs() < 0.2,
        "Krum selected a Byzantine gradient: {:?}", result);
}

Links: 7-Modules et Organisation, 8-Tests et Documentation
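The project tree lists `attack.rs` and `tests/test_attack.rs`, but this note does not show their content. As an illustration only, the sketch below shows what a simple attack helper could look like, using a sign-flipping strategy (the Byzantine worker sends the negated, rescaled mean of the honest gradients). Both function names, the `scale` parameter and the choice of attack are assumptions, not the actual content of `attack.rs`.

```rust
/// Hypothetical sign-flipping attack: returns -scale * mean(honest).
/// Returns None on empty input or mismatched dimensions.
pub fn sign_flip(honest: &[Vec<f64>], scale: f64) -> Option<Vec<f64>> {
    let d = honest.first()?.len();
    if honest.iter().any(|g| g.len() != d) {
        return None;
    }
    let n = honest.len() as f64;
    let mut mean = vec![0.0; d];
    for g in honest {
        for (m, x) in mean.iter_mut().zip(g) {
            *m += x / n;
        }
    }
    Some(mean.into_iter().map(|m| -scale * m).collect())
}

/// Builds the input seen by the server: the honest gradients followed by
/// `f` copies of the attack vector.
pub fn with_byzantines(honest: &[Vec<f64>], attack: &[f64], f: usize) -> Vec<Vec<f64>> {
    honest
        .iter()
        .cloned()
        .chain(std::iter::repeat(attack.to_vec()).take(f))
        .collect()
}
```

An integration test could then feed the output of `with_byzantines` to each `Gar` implementation and assert that the aggregate stays close to the honest mean, in the same spirit as `test_krum_rejects_byzantine`.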

Milestone 5: Parallelism and Benchmarks

Rust concepts: rayon, Arc, criterion benchmarks, comparison with Python

// Parallelized version with rayon
use rayon::prelude::*;

impl Median {
    pub fn aggregate_parallel(&self, grads: &[Vec<f64>]) -> Result<Vec<f64>, GarError> {
        let d = grads.first().ok_or(GarError::EmptyInput)?.len();
        if let Some(g) = grads.iter().find(|g| g.len() != d) {
            return Err(GarError::DimensionMismatch { expected: d, got: g.len() });
        }

        // Columns are independent, so each coordinate can run on any worker thread.
        let result: Vec<f64> = (0..d).into_par_iter().map(|j| {
            let mut col: Vec<f64> = grads.iter().map(|g| g[j]).collect();
            // total_cmp never panics, even if a gradient contains NaN.
            col.sort_unstable_by(f64::total_cmp);
            let mid = col.len() / 2;
            if col.len() % 2 == 0 {
                (col[mid - 1] + col[mid]) / 2.0
            } else {
                col[mid]
            }
        }).collect();

        Ok(result)
    }
}

// benches/gar_bench.rs
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use gradient_core::gar::{Median, Gar};
 
fn bench_median(c: &mut Criterion) {
    let grads: Vec<Vec<f64>> = (0..50)
        .map(|_| (0..10_000).map(|_| rand::random()).collect())
        .collect();
    let median = Median;
 
    c.bench_function("median 50×10k", |b| {
        b.iter(|| median.aggregate(black_box(&grads)))
    });
}
 
criterion_group!(benches, bench_median);
criterion_main!(benches);

Links: 6-Smart Pointers et Concurrence, 8.4 Benchmarks
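To see roughly what the rayon version delegates to the library, the same column-wise split can be written with the standard library alone. The sketch below uses `std::thread::scope`, which lets threads borrow `grads` directly, so no `Arc` is needed. The names `column_median` and `median_scoped` and the `threads` parameter are hypothetical, and the function assumes a non-empty input with equal dimensions. Unlike rayon, it does no work stealing: each thread gets one fixed contiguous chunk of coordinates.

```rust
use std::thread;

/// Median of column `j`; sorts a private copy of the column.
fn column_median(grads: &[Vec<f64>], j: usize) -> f64 {
    let mut col: Vec<f64> = grads.iter().map(|g| g[j]).collect();
    col.sort_unstable_by(f64::total_cmp);
    let mid = col.len() / 2;
    if col.len() % 2 == 0 {
        (col[mid - 1] + col[mid]) / 2.0
    } else {
        col[mid]
    }
}

/// Coordinate-wise median with the output split into contiguous chunks,
/// one scoped thread per chunk.
/// Assumes non-empty input whose gradients all have the same length.
pub fn median_scoped(grads: &[Vec<f64>], threads: usize) -> Vec<f64> {
    let d = grads[0].len();
    let mut out = vec![0.0; d];
    let t = threads.max(1);
    // Ceiling division, clamped to 1 because chunks_mut rejects a size of 0.
    let chunk = ((d + t - 1) / t).max(1);
    thread::scope(|s| {
        for (c, slot) in out.chunks_mut(chunk).enumerate() {
            // Each thread owns a disjoint mutable slice of the output.
            s.spawn(move || {
                for (k, v) in slot.iter_mut().enumerate() {
                    *v = column_median(grads, c * chunk + k);
                }
            });
        }
    });
    out
}
```

Benchmarking this against `aggregate_parallel` with criterion would show how much rayon's scheduling buys over a static split for a given n and d.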

Bonus: PyO3 bindings

// gradient-pyo3/src/lib.rs
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use gradient_core::gar::{Median, Gar};

// Exposed to Python as `Median`, which matches the __repr__ below.
#[pyclass(name = "Median")]
struct PyMedian {
    inner: Median,
}

#[pymethods]
impl PyMedian {
    #[new]
    fn new() -> Self { Self { inner: Median } }

    fn aggregate(&self, grads: Vec<Vec<f64>>) -> PyResult<Vec<f64>> {
        // Convert GarError into a Python ValueError.
        self.inner.aggregate(&grads)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    fn __repr__(&self) -> String {
        format!("Median(breakdown={})", self.inner.breakdown_point())
    }
}

// Module signature for PyO3 < 0.21; newer releases take `m: &Bound<'_, PyModule>`.
#[pymodule]
fn gradient_gar(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<PyMedian>()?;
    Ok(())
}

Links: 12.4 Exposer Rust en C (the reverse direction), Library


Thesis Concepts ↔ Rust Mapping

| Thesis concept | Rust | Related note |
|---|---|---|
| GAR (breakdown point) | trait Gar { fn breakdown_point() } | 5-Traits et Generics |
| Gradient as a vector | Vec<f64>, ndarray::Array1<f64> | 9-String et Collections |
| Distance between gradients | (a-b).powi(2).sum() via iterators | 4-Iterateurs et Closures |
| Parallel aggregation | rayon par_iter() | 6-Smart Pointers et Concurrence |
| Worker communication | tokio::net::TcpStream | 11-Async et Tokio |
| Config/run serialization | serde_json, toml | 10-IO et Serde |
| Python bindings | PyO3 / maturin | 12-Unsafe et FFI |
| GPU acceleration | cudarc, CUDA kernels | 12-Unsafe et FFI |

Resources


🔗 See also