5. Traits et Generics

Les traits sont l’équivalent des interfaces en Rust. Combinés aux generics, ils permettent d’écrire du code polymorphe avec résolution à la compilation (monomorphisation) — pas de surcoût d’exécution.


5.1 Traits — Définir un comportement partagé

Définition et implémentation

trait Aggregate {
    fn aggregate(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>, String>;
    fn name(&self) -> &'static str;
}
 
struct Median;
 
impl Aggregate for Median {
    fn aggregate(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>, String> {
        if gradients.is_empty() {
            return Err("pas de gradients".into());
        }
        let d = gradients[0].len();
        let mut result = Vec::with_capacity(d);
        for j in 0..d {
            let mut col: Vec<f64> = gradients.iter().map(|g| g[j]).collect();
            col.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
            let mid = col.len() / 2;
            result.push(if col.len() % 2 == 0 {
                (col[mid - 1] + col[mid]) / 2.0
            } else {
                col[mid]
            });
        }
        Ok(result)
    }
 
    fn name(&self) -> &'static str {
        "Median"
    }
}
 
let median = Median;
let grads = vec![vec![1.0, 5.0], vec![2.0, 3.0], vec![9.0, 4.0]];
println!("{}: {:?}", median.name(), median.aggregate(&grads));

Traits par défaut

trait Aggregate {
    fn aggregate(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>, String>;
 
    fn name(&self) -> &'static str {
        "Générique"  // implémentation par défaut
    }
 
    fn validate(&self, gradients: &[Vec<f64>]) -> Result<(), String> {
        if gradients.is_empty() {
            return Err("gradients vides".into());
        }
        let d = gradients[0].len();
        for (i, g) in gradients.iter().enumerate() {
            if g.len() != d {
                return Err(format!("gradient {i}: taille {} != {d}", g.len()));
            }
        }
        Ok(())
    }
}

Traits dérivés (#[derive(...)])

#[derive(Debug, Clone, PartialEq)]
struct Gradient {
    values: Vec<f64>,
    worker_id: usize,
}
// Génère automatiquement :
// - Debug : {:?}
// - Clone : .clone()
// - PartialEq : ==

Autres derives : Copy, Eq, Hash, Ord, PartialOrd, Default.

Opérateurs et surcharge

Les opérateurs (+, *, [], etc.) sont des traits dans std::ops :

use std::ops::{Add, Mul};
 
#[derive(Debug, Clone)]
struct Vector(Vec<f64>);
 
impl Add for Vector {
    type Output = Self;
    fn add(self, other: Self) -> Self {
        Vector(self.0.iter().zip(other.0.iter())
            .map(|(a, b)| a + b).collect())
    }
}
 
impl Mul<f64> for Vector {
    type Output = Self;
    fn mul(self, scalar: f64) -> Self {
        Vector(self.0.iter().map(|a| a * scalar).collect())
    }
}
 
let a = Vector(vec![1.0, 2.0]);
let b = Vector(vec![3.0, 4.0]);
println!("{:?}", a + b * 2.0);  // Vector([7.0, 10.0])

5.2 Generics — Polymorphisme paramétrique

Fonctions génériques

fn first<T>(slice: &[T]) -> Option<&T> {
    slice.first()
}
 
fn median_wrapper<T: Aggregate>(agg: &T, grads: &[Vec<f64>]) -> Vec<f64> {
    agg.aggregate(grads).expect("agrégation échouée")
}

Structs génériques

struct GradientBatch<T> {
    workers: Vec<Vec<T>>,  // [worker_i][coordonnée_j]
    n_workers: usize,
    dimension: usize,
}
 
impl<T: Copy + std::ops::Add<Output = T> + From<f64>> GradientBatch<T> {
    fn new(gradients: Vec<Vec<T>>) -> Result<Self, String> {
        let n = gradients.len();
        if n == 0 {
            return Err("batch vide".into());
        }
        let d = gradients[0].len();
        Ok(Self {
            workers: gradients,
            n_workers: n,
            dimension: d,
        })
    }
}

Contraintes de traits (where)

fn compute_median<T>(values: &mut [T]) -> Option<T>
where
    T: Ord + Copy + From<f64>,
{
    let n = values.len();
    if n == 0 {
        return None;
    }
    values.sort();
    let mid = n / 2;
    if n % 2 == 0 {
        Some((values[mid - 1] + values[mid]) / 2.0.into())
    } else {
        Some(values[mid])
    }
}
 
// Équivalent sans where :
fn compute_median<T: Ord + Copy + From<f64>>(values: &mut [T]) -> Option<T> { ... }

5.3 Trait Bounds et Polymorphisme

impl Trait — syntaxe concise (args)

fn process(agg: &impl Aggregate) {
    println!("using {}", agg.name());
}

impl Trait — syntaxe concise (retour)

fn create_median() -> impl Aggregate {
    Median  // retourne n'importe quel type qui implémente Aggregate
}

Limite : on ne peut retourner qu’un seul type concret. Pour retourner plusieurs types possibles, il faut Box<dyn Trait>.

dyn Trait — dispatch dynamique

fn run_gar(agg: &dyn Aggregate, grads: &[Vec<f64>]) -> Vec<f64> {
    agg.aggregate(grads).unwrap()
    // dispatch dynamique (vtable) : overhead d'un pointeur indirect
}
 
// Collection hétérogène
let gars: Vec<Box<dyn Aggregate>> = vec![
    Box::new(Median),
    Box::new(Krum),
    Box::new(TrimmedMean),
];

Static dispatch vs Dynamic dispatch

ApprocheSyntaxeRésolutionPerformance
Generics (impl Trait, <T: Trait>)fn f(t: &impl Aggregate)Compilation (monomorphisation)Zéro overhead
Trait objects (dyn Trait)fn f(t: &dyn Aggregate)Runtime (vtable lookup)1 indirection
// STATIC DISPATCH : une copie de `compute` pour chaque type T
fn compute<T: Aggregate>(agg: &T) { ... }
 
// DYNAMIC DISPATCH : une seule fonction, résolution à l'exécution
fn compute_dyn(agg: &dyn Aggregate) { ... }

5.4 Traits Courants de la Std

TraitDescriptionExemple
CloneDupliquer (clone())grads.clone()
CopyCopie implicite par bitType scalaire
DebugFormatage {:?}println!("{:?}", x)
DisplayFormatage {}println!("{x}")
DefaultValeur par défautVec::default()
PartialEq / EqTest d’égalité ==a == b
PartialOrd / OrdComparaison <, >a < b
HashHachageUtilisé dans HashMap
IteratorProduction d’élémentsiter.next()
From<T> / Into<T>Conversionx.into()
DerefDéréférencement *xBox, Rc, Arc
DropNettoyage à la sortieLibération mémoire
ErrorErreur (avec source())Types d’erreur personnalisés
Send / SyncSécurité concurrenteTypes thread-safe

5.5 Exemple : Architecture GAR avec Traits

// Trait central de la librairie
pub trait Gar: Send + Sync {
    fn aggregate(&self, grad_matrix: &[Vec<f64>]) -> Result<Vec<f64>, GarError>;
    fn name(&self) -> &'static str;
    fn breakdown_point(&self) -> f64;
    fn complexity(&self) -> &'static str;
}
 
#[derive(Debug)]
pub enum GarError {
    EmptyInput,
    DimensionMismatch { expected: usize, got: usize },
    TooManyByzantines { f: usize, n: usize },
}
 
impl std::fmt::Display for GarError { ... }
impl std::error::Error for GarError {}
 
// Implémentation Krum
pub struct Krum {
    f: usize,  // nombre de byzantins supposés
}
 
impl Krum {
    pub fn new(f: usize) -> Self {
        Self { f }
    }
}
 
impl Gar for Krum {
    fn aggregate(&self, grad_matrix: &[Vec<f64>]) -> Result<Vec<f64>, GarError> {
        if grad_matrix.is_empty() {
            return Err(GarError::EmptyInput);
        }
        let n = grad_matrix.len();
        let d = grad_matrix[0].len();
        if n < 2 * self.f + 3 {
            return Err(GarError::TooManyByzantines {
                f: self.f, n,
            });
        }
        // Calcul des distances entre toutes les paires
        let mut scores = vec![0.0; n];
        for i in 0..n {
            let mut distances: Vec<f64> = Vec::new();
            for j in 0..n {
                if i == j { continue; }
                let dist: f64 = grad_matrix[i].iter()
                    .zip(grad_matrix[j].iter())
                    .map(|(a, b)| (a - b).powi(2))
                    .sum();
                distances.push(dist);
            }
            distances.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
            scores[i] = distances.iter().take(n - self.f - 1).sum();
        }
        // Sélectionner l'index avec le plus petit score
        let best_idx = scores.iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(i, _)| i)
            .unwrap();
        Ok(grad_matrix[best_idx].clone())
    }
 
    fn name(&self) -> &'static str { "Krum" }
    fn breakdown_point(&self) -> f64 { 0.5 }
    fn complexity(&self) -> &'static str { "O(n²d)" }
}

5.6 Résumé

ConceptButCoût
TraitDéfinir un comportement communZéro (sauf dyn)
Generic <T>Code réutilisable pour plusieurs typesMonomorphisation (1 copie par type)
impl TraitSyntaxe concise pour args/retourStatic dispatch
dyn TraitPolymorphisme dynamiqueDynamic dispatch (vtable)
#[derive]Implémentation automatiqueZéro
where clauseContraintes lisibles sur les genericsZéro

🔗 Voir aussi