5. Traits et Generics
Les traits sont l’équivalent des interfaces en Rust. Combinés aux generics, ils permettent d’écrire du code polymorphe avec résolution à la compilation (monomorphisation) — pas de surcoût d’exécution.
5.1 Traits — Définir un comportement partagé
Définition et implémentation
trait Aggregate {
fn aggregate(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>, String>;
fn name(&self) -> &'static str;
}
struct Median;
impl Aggregate for Median {
fn aggregate(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>, String> {
if gradients.is_empty() {
return Err("pas de gradients".into());
}
let d = gradients[0].len();
let mut result = Vec::with_capacity(d);
for j in 0..d {
let mut col: Vec<f64> = gradients.iter().map(|g| g[j]).collect();
col.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
let mid = col.len() / 2;
result.push(if col.len() % 2 == 0 {
(col[mid - 1] + col[mid]) / 2.0
} else {
col[mid]
});
}
Ok(result)
}
fn name(&self) -> &'static str {
"Median"
}
}
let median = Median;
let grads = vec![vec![1.0, 5.0], vec![2.0, 3.0], vec![9.0, 4.0]];
println!("{}: {:?}", median.name(), median.aggregate(&grads));Traits par défaut
trait Aggregate {
fn aggregate(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>, String>;
fn name(&self) -> &'static str {
"Générique" // implémentation par défaut
}
fn validate(&self, gradients: &[Vec<f64>]) -> Result<(), String> {
if gradients.is_empty() {
return Err("gradients vides".into());
}
let d = gradients[0].len();
for (i, g) in gradients.iter().enumerate() {
if g.len() != d {
return Err(format!("gradient {i}: taille {} != {d}", g.len()));
}
}
Ok(())
}
}Traits dérivés (#[derive(...)])
#[derive(Debug, Clone, PartialEq)]
struct Gradient {
values: Vec<f64>,
worker_id: usize,
}
// Génère automatiquement :
// - Debug : {:?}
// - Clone : .clone()
// - PartialEq : ==Autres derives : Copy, Eq, Hash, Ord, PartialOrd, Default.
Opérateurs et surcharge
Les opérateurs (+, *, [], etc.) sont des traits dans std::ops :
use std::ops::{Add, Mul};
#[derive(Debug, Clone)]
struct Vector(Vec<f64>);
impl Add for Vector {
type Output = Self;
fn add(self, other: Self) -> Self {
Vector(self.0.iter().zip(other.0.iter())
.map(|(a, b)| a + b).collect())
}
}
impl Mul<f64> for Vector {
type Output = Self;
fn mul(self, scalar: f64) -> Self {
Vector(self.0.iter().map(|a| a * scalar).collect())
}
}
let a = Vector(vec![1.0, 2.0]);
let b = Vector(vec![3.0, 4.0]);
println!("{:?}", a + b * 2.0); // Vector([7.0, 10.0])5.2 Generics — Polymorphisme paramétrique
Fonctions génériques
fn first<T>(slice: &[T]) -> Option<&T> {
slice.first()
}
fn median_wrapper<T: Aggregate>(agg: &T, grads: &[Vec<f64>]) -> Vec<f64> {
agg.aggregate(grads).expect("agrégation échouée")
}Structs génériques
struct GradientBatch<T> {
workers: Vec<Vec<T>>, // [worker_i][coordonnée_j]
n_workers: usize,
dimension: usize,
}
impl<T: Copy + std::ops::Add<Output = T> + From<f64>> GradientBatch<T> {
fn new(gradients: Vec<Vec<T>>) -> Result<Self, String> {
let n = gradients.len();
if n == 0 {
return Err("batch vide".into());
}
let d = gradients[0].len();
Ok(Self {
workers: gradients,
n_workers: n,
dimension: d,
})
}
}Contraintes de traits (where)
fn compute_median<T>(values: &mut [T]) -> Option<T>
where
T: Ord + Copy + From<f64>,
{
let n = values.len();
if n == 0 {
return None;
}
values.sort();
let mid = n / 2;
if n % 2 == 0 {
Some((values[mid - 1] + values[mid]) / 2.0.into())
} else {
Some(values[mid])
}
}
// Équivalent sans where :
fn compute_median<T: Ord + Copy + From<f64>>(values: &mut [T]) -> Option<T> { ... }5.3 Trait Bounds et Polymorphisme
impl Trait — syntaxe concise (args)
fn process(agg: &impl Aggregate) {
println!("using {}", agg.name());
}impl Trait — syntaxe concise (retour)
fn create_median() -> impl Aggregate {
Median // retourne n'importe quel type qui implémente Aggregate
}Limite : on ne peut retourner qu’un seul type concret. Pour retourner plusieurs types possibles, il faut Box<dyn Trait>.
dyn Trait — dispatch dynamique
fn run_gar(agg: &dyn Aggregate, grads: &[Vec<f64>]) -> Vec<f64> {
agg.aggregate(grads).unwrap()
// dispatch dynamique (vtable) : overhead d'un pointeur indirect
}
// Collection hétérogène
let gars: Vec<Box<dyn Aggregate>> = vec![
Box::new(Median),
Box::new(Krum),
Box::new(TrimmedMean),
];Static dispatch vs Dynamic dispatch
| Approche | Syntaxe | Résolution | Performance |
|---|---|---|---|
Generics (impl Trait, <T: Trait>) | fn f(t: &impl Aggregate) | Compilation (monomorphisation) | Zéro overhead |
Trait objects (dyn Trait) | fn f(t: &dyn Aggregate) | Runtime (vtable lookup) | 1 indirection |
// STATIC DISPATCH : une copie de `compute` pour chaque type T
fn compute<T: Aggregate>(agg: &T) { ... }
// DYNAMIC DISPATCH : une seule fonction, résolution à l'exécution
fn compute_dyn(agg: &dyn Aggregate) { ... }5.4 Traits Courants de la Std
| Trait | Description | Exemple |
|---|---|---|
Clone | Dupliquer (clone()) | grads.clone() |
Copy | Copie implicite par bit | Type scalaire |
Debug | Formatage {:?} | println!("{:?}", x) |
Display | Formatage {} | println!("{x}") |
Default | Valeur par défaut | Vec::default() |
PartialEq / Eq | Test d’égalité == | a == b |
PartialOrd / Ord | Comparaison <, > | a < b |
Hash | Hachage | Utilisé dans HashMap |
Iterator | Production d’éléments | iter.next() |
From<T> / Into<T> | Conversion | x.into() |
Deref | Déréférencement *x | Box, Rc, Arc |
Drop | Nettoyage à la sortie | Libération mémoire |
Error | Erreur (avec source()) | Types d’erreur personnalisés |
Send / Sync | Sécurité concurrente | Types thread-safe |
5.5 Exemple : Architecture GAR avec Traits
// Trait central de la librairie
pub trait Gar: Send + Sync {
fn aggregate(&self, grad_matrix: &[Vec<f64>]) -> Result<Vec<f64>, GarError>;
fn name(&self) -> &'static str;
fn breakdown_point(&self) -> f64;
fn complexity(&self) -> &'static str;
}
#[derive(Debug)]
pub enum GarError {
EmptyInput,
DimensionMismatch { expected: usize, got: usize },
TooManyByzantines { f: usize, n: usize },
}
impl std::fmt::Display for GarError { ... }
impl std::error::Error for GarError {}
// Implémentation Krum
pub struct Krum {
f: usize, // nombre de byzantins supposés
}
impl Krum {
pub fn new(f: usize) -> Self {
Self { f }
}
}
impl Gar for Krum {
fn aggregate(&self, grad_matrix: &[Vec<f64>]) -> Result<Vec<f64>, GarError> {
if grad_matrix.is_empty() {
return Err(GarError::EmptyInput);
}
let n = grad_matrix.len();
let d = grad_matrix[0].len();
if n < 2 * self.f + 3 {
return Err(GarError::TooManyByzantines {
f: self.f, n,
});
}
// Calcul des distances entre toutes les paires
let mut scores = vec![0.0; n];
for i in 0..n {
let mut distances: Vec<f64> = Vec::new();
for j in 0..n {
if i == j { continue; }
let dist: f64 = grad_matrix[i].iter()
.zip(grad_matrix[j].iter())
.map(|(a, b)| (a - b).powi(2))
.sum();
distances.push(dist);
}
distances.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
scores[i] = distances.iter().take(n - self.f - 1).sum();
}
// Sélectionner l'index avec le plus petit score
let best_idx = scores.iter()
.enumerate()
.min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.map(|(i, _)| i)
.unwrap();
Ok(grad_matrix[best_idx].clone())
}
fn name(&self) -> &'static str { "Krum" }
fn breakdown_point(&self) -> f64 { 0.5 }
fn complexity(&self) -> &'static str { "O(n²d)" }
}5.6 Résumé
| Concept | But | Coût |
|---|---|---|
| Trait | Définir un comportement commun | Zéro (sauf dyn) |
Generic <T> | Code réutilisable pour plusieurs types | Monomorphisation (1 copie par type) |
impl Trait | Syntaxe concise pour args/retour | Static dispatch |
dyn Trait | Polymorphisme dynamique | Dynamic dispatch (vtable) |
#[derive] | Implémentation automatique | Zéro |
where clause | Contraintes lisibles sur les generics | Zéro |