Source code for promis.models.gaussian_mixture

"""This module implements Gaussian Mixture Models (GMM)."""

#
# Copyright (c) Simon Kohaut, Honda Research Institute Europe GmbH
#
# This file is part of ProMis and licensed under the BSD 3-Clause License.
# You should have received a copy of the BSD 3-Clause License along with ProMis.
# If not, see https://opensource.org/license/bsd-3-clause/.
#

# Standard Library
from copy import deepcopy

# Third Party
from numpy import array, ndarray
from numpy.linalg import inv

# ProMis
from promis.models.gaussian import Gaussian


class GaussianMixture:
    """The Gaussian Mixture Model (GMM) for representing multi-modal probability distributions.

    The mixture is a thin container around a list of weighted Gaussian components.
    It supports iteration, indexing, ``len()`` and concatenation via ``+``.
    Each component is expected to expose a weight ``w``, a mean ``x`` and a covariance ``P``.

    Args:
        components: An initial list of components to consider in this GMM
    """

    def __init__(self, components: "list[Gaussian] | None" = None):
        # A non-empty list is stored as given (not copied);
        # ``None`` or an empty list results in a fresh, empty list of components
        self.components: list[Gaussian] = components if components else []

    def __iter__(self):
        return iter(self.components)

    def __getitem__(self, key: int) -> "Gaussian":
        return self.components[key]

    def __len__(self) -> int:
        return len(self.components)

    def __add__(self, other: "GaussianMixture") -> "GaussianMixture":
        # The new mixture gets its own list but shares the component objects with both operands
        return GaussianMixture(self.components + other.components)

    def append(self, component: "Gaussian"):
        """Appends a new Gaussian to this Mixture's list of components.

        Args:
            component: The new Gaussian to append
        """

        self.components.append(component)

    def modes(self, threshold: float = 0.5) -> list[ndarray]:
        """Extract all modes of the mixture model that are above a set threshold.

        Args:
            threshold: Weight that a component needs to exceed to be considered

        Returns:
            The locations of all modes with weight larger than the threshold,
            where a component's mean is repeated once per (rounded) unit of weight
            and at least once
        """

        # Memory for all estimated modes
        modes: list[ndarray] = []

        # Every component with sufficient weight is considered to be a target
        for component in self.components:
            if component.w > threshold:
                # A component with weight over 1 represents multiple targets.
                # Rounding alone would drop components that pass a low threshold but
                # round to zero (e.g., w=0.4 with threshold=0.3), so report at least one.
                count = max(1, int(round(component.w)))
                modes.extend(component.x for _ in range(count))

        # Return all extracted modes
        return modes

    def prune(self, threshold: float, merge_distance: float, max_components: int) -> None:
        """Reduces the number of gaussian mixture components.

        First, components with weight <= threshold are dropped.
        Then, starting from the heaviest remaining component, all components within
        the merge distance of it are fused into a single moment-matched Gaussian.
        This is repeated until every remaining component was merged.
        Finally, the lightest merged components are discarded until at most
        ``max_components`` are left.

        Args:
            threshold: Truncation threshold s.t. components with weight <= threshold are removed
            merge_distance: Merging threshold s.t. components 'close enough' will be merged,
                compared against the squared Mahalanobis distance to the heaviest component
                (using each candidate's own covariance)
            max_components: Maximum number of gaussians after pruning
        """

        # Select the subset of components that survives truncation
        selected = [component for component in self.components if component.w > threshold]

        # Create new list for pruned mixture model
        pruned: list[Gaussian] = []

        # While candidates for merging exist ...
        while selected:
            # Find component with maximum weight (first one on ties) and its mean
            anchor = max(selected, key=lambda component: component.w)
            mean = anchor.x

            # Split candidates into those merged with the anchor and those kept for later.
            # The anchor always belongs to its own group, which guarantees progress,
            # and membership is decided by identity rather than by comparing Gaussians.
            mergeable = []
            remaining = []
            for candidate in selected:
                if candidate is anchor:
                    mergeable.append(candidate)
                    continue

                offset = candidate.x - mean
                distance = (offset.T @ inv(candidate.P) @ offset).item()
                if distance <= merge_distance:
                    mergeable.append(candidate)
                else:
                    remaining.append(candidate)
            selected = remaining

            # Compute new mixture component by moment matching
            merged_weight = sum(component.w for component in mergeable)
            merged_mean = array(
                sum(component.w * component.x for component in mergeable) / merged_weight
            )

            # The spread term is taken around the merged mean (not the anchor's mean)
            # so that the merged Gaussian preserves the covariance of the merged group.
            # NOTE: the outer product assumes means are column vectors of shape (n, 1)
            # -- TODO confirm against Gaussian
            merged_covariance = array(
                sum(
                    component.w
                    * (
                        component.P
                        + (merged_mean - component.x) @ (merged_mean - component.x).T
                    )
                    for component in mergeable
                )
                / merged_weight
            )

            # Store the component
            pruned.append(Gaussian(merged_mean, merged_covariance, merged_weight))

        # Remove components with minimum weight if maximum number is exceeded;
        # a negative limit is treated as zero instead of searching an empty list
        limit = max(max_components, 0)
        while len(pruned) > limit:
            # Find index of component with minimum weight
            index = min(range(len(pruned)), key=lambda index: pruned[index].w)

            # Remove the component
            del pruned[index]

        # Update GMM with pruned model
        self.components = deepcopy(pruned)