Upper Face Dynamics Deviation (FDD)

Module Interface

class torchmetrics.multimodal.fdd.UpperFaceDynamicsDeviation(template, upper_face_map, **kwargs)[source]

Implements the Upper Face Dynamics Deviation (FDD) metric for 3D talking head evaluation.

The Upper Face Dynamics Deviation (FDD) metric evaluates the quality of facial expressions in the upper face region for 3D talking head models. It quantifies the deviation in vertex motion dynamics between the predicted and ground truth sequences by comparing the temporal variation (standard deviation) of per-vertex squared displacements relative to a neutral template. Lower values of FDD indicate closer alignment of the predicted upper-face motion dynamics with the ground truth.

The metric is defined as:

\[\text{FDD} = \frac{1}{|S_U|} \sum_{v \in S_U} \Big( \text{std}(\| x_{1:T,v} - \text{template}_v \|_2^2) - \text{std}(\| \hat{x}_{1:T,v} - \text{template}_v \|_2^2) \Big)\]

where \(T\) is the number of frames, \(S_U\) is the set of upper-face vertices with \(M = |S_U|\), \(x_{t,v}\) are the 3D coordinates of vertex \(v\) at frame \(t\) in the ground truth sequence, and \(\hat{x}_{t,v} \in \mathbb{R}^3\) are the corresponding predicted vertices. The neutral template coordinate of vertex \(v\) is denoted as \(\text{template}_v \in \mathbb{R}^3\). The operator \(\text{std}(\cdot)\) computes the standard deviation of the temporal sequence.
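The definition above can be traced term by term in a dependency-free sketch. This is not the library's implementation: the helper name fdd_reference is hypothetical, inputs are nested lists rather than tensors, and \(\text{std}\) is assumed to be the population standard deviation (dividing by \(T\)), which the text above does not specify.

```python
from statistics import pstdev


def fdd_reference(vertices_pred, vertices_gt, template, upper_face_map):
    """Signed FDD on nested lists, following the formula term by term.

    vertices_pred, vertices_gt: [T][V][3] nested lists; template: [V][3].
    Assumption: std is the population standard deviation (divide by T).
    """

    def squared_displacements(sequence, v):
        # One value per frame: squared L2 distance of vertex v from its template position.
        return [
            sum((c - t) ** 2 for c, t in zip(frame[v], template[v]))
            for frame in sequence
        ]

    per_vertex = [
        pstdev(squared_displacements(vertices_gt, v))
        - pstdev(squared_displacements(vertices_pred, v))
        for v in upper_face_map
    ]
    return sum(per_vertex) / len(per_vertex)


# Two frames, two vertices; only vertex 0 is treated as upper face.
template = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
moving = [[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [[3.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]
static = [[[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]

same = fdd_reference(moving, moving, template, [0])   # identical dynamics
under = fdd_reference(static, moving, template, [0])  # prediction moves too little
over = fdd_reference(moving, static, template, [0])   # prediction moves too much
```

Under these assumptions, same is zero, under is positive, and over is negative. As the formula is written, the result is a signed difference, so its sign shows whether the prediction has less or more temporal variation than the ground truth.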

As input to forward and update, the metric accepts the following input:

  • vertices_pred (Tensor): Predicted vertices tensor of shape (T, V, 3), where T is the number of frames, V is the number of vertices, and 3 represents XYZ coordinates.

  • vertices_gt (Tensor): Ground truth vertices tensor of shape (T, V, 3), where T is the number of frames, V is the number of vertices, and 3 represents XYZ coordinates.

As output of forward and compute, the metric returns the following output:

  • fdd_score (Tensor): A scalar tensor containing the mean Upper Face Dynamics Deviation across all upper-face vertices.

Parameters:
  • template (Tensor) – Template mesh tensor of shape (V, 3) representing the neutral face.

  • upper_face_map (List[int]) – List of vertex indices for the upper-face region.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:

ValueError – Raised in any of the following cases:

  • The number of dimensions of vertices_pred or vertices_gt is not 3.

  • template does not have shape (V, 3).

  • vertices_pred and vertices_gt do not have the same vertex and coordinate dimensions.

  • The template shape does not match the vertex and coordinate dimensions of vertices_pred (and vertices_gt).

  • upper_face_map is empty or contains invalid vertex indices.
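The conditions listed above can be sketched as a pre-flight check on shapes alone. The helper check_fdd_shapes is hypothetical and is not part of torchmetrics; it also assumes that "invalid" indices means indices outside the range 0 to V - 1, which is one reading of the description rather than confirmed library behaviour.

```python
def check_fdd_shapes(pred_shape, gt_shape, template_shape, upper_face_map):
    """Raise ValueError for the documented error conditions (hypothetical helper).

    Shapes may be plain tuples or tensor.shape objects.
    """
    pred_shape, gt_shape = tuple(pred_shape), tuple(gt_shape)
    template_shape = tuple(template_shape)

    if len(pred_shape) != 3 or len(gt_shape) != 3:
        raise ValueError("vertices_pred and vertices_gt must have 3 dimensions (T, V, 3).")
    if len(template_shape) != 2 or template_shape[1] != 3:
        raise ValueError("template must have shape (V, 3).")
    # Only the vertex and coordinate dimensions are compared, not the frame count.
    if pred_shape[1:] != gt_shape[1:]:
        raise ValueError("vertices_pred and vertices_gt must share vertex and coordinate dimensions.")
    if template_shape != pred_shape[1:]:
        raise ValueError("template shape must match the (V, 3) dimensions of the vertices.")
    if len(upper_face_map) == 0:
        raise ValueError("upper_face_map must not be empty.")
    # Assumption: an index is invalid when it falls outside 0 .. V - 1.
    num_vertices = template_shape[0]
    if any(i < 0 or i >= num_vertices for i in upper_face_map):
        raise ValueError("upper_face_map contains indices outside the vertex range.")


# Passes silently: shapes agree and all indices are in range.
check_fdd_shapes((10, 100, 3), (10, 100, 3), (100, 3), [0, 1, 2, 3, 4])
```

Running a check like this before calling the metric gives an early, readable error in data-loading code, at the cost of duplicating validation the metric already performs.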

Example

>>> import torch
>>> from torchmetrics.multimodal.fdd import UpperFaceDynamicsDeviation
>>> template = torch.randn(100, 3, generator=torch.manual_seed(41))
>>> metric = UpperFaceDynamicsDeviation(template=template, upper_face_map=[0, 1, 2, 3, 4])
>>> vertices_pred = torch.randn(10, 100, 3, generator=torch.manual_seed(42))
>>> vertices_gt = torch.randn(10, 100, 3, generator=torch.manual_seed(43))
>>> metric(vertices_pred, vertices_gt)
tensor(0.2131)

compute()[source]

Compute the Upper Face Dynamics Deviation over all accumulated states.

Returns:

A scalar tensor with the mean FDD value

Return type:

torch.Tensor

plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.

Return type:

tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.multimodal.fdd import UpperFaceDynamicsDeviation
>>> metric = UpperFaceDynamicsDeviation(template=torch.randn(100, 3), upper_face_map=[0, 1, 2, 3, 4])
>>> vertices_pred = torch.randn(10, 100, 3, generator=torch.manual_seed(42))
>>> vertices_gt = torch.randn(10, 100, 3, generator=torch.manual_seed(43))
>>> metric.update(vertices_pred, vertices_gt)
>>> fig_, ax_ = metric.plot()
[Figure: plot of a single FDD value]

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.multimodal.fdd import UpperFaceDynamicsDeviation
>>> metric = UpperFaceDynamicsDeviation(template=torch.randn(100, 3), upper_face_map=[0, 1, 2, 3, 4])
>>> values = []
>>> for i in range(10):
...     vertices_pred = torch.randn(10, 100, 3, generator=torch.manual_seed(42 + i))
...     vertices_gt = torch.randn(10, 100, 3, generator=torch.manual_seed(43 + i))
...     values.append(metric(vertices_pred, vertices_gt))
>>> fig_, ax_ = metric.plot(values)
[Figure: plot of multiple FDD values]

update(vertices_pred, vertices_gt)[source]

Update metric states with predictions and targets.

Parameters:
  • vertices_pred (Tensor) – Predicted vertices tensor of shape (T, V, 3), where T is the number of frames, V is the number of vertices, and 3 represents XYZ coordinates.

  • vertices_gt (Tensor) – Ground truth vertices tensor of shape (T’, V, 3), where T’ is the number of frames, V is the number of vertices, and 3 represents XYZ coordinates.

Return type:

None

Functional Interface

torchmetrics.functional.multimodal.fdd.upper_face_dynamics_deviation(vertices_pred, vertices_gt, template, upper_face_map)[source]

Compute Upper Face Dynamics Deviation (FDD) for 3D talking head evaluation.

The Upper Face Dynamics Deviation (FDD) metric evaluates the quality of facial expressions in the upper face region for 3D talking head models. It quantifies the deviation in vertex motion dynamics between the predicted and ground truth sequences by comparing the temporal variation (standard deviation) of per-vertex squared displacements relative to a neutral template. Lower values of FDD indicate closer alignment of the predicted upper-face motion dynamics with the ground truth.

The metric is defined as:

\[\text{FDD} = \frac{1}{|S_U|} \sum_{v \in S_U} \Big( \text{std}(\| x_{1:T,v} - \text{template}_v \|_2^2) - \text{std}(\| \hat{x}_{1:T,v} - \text{template}_v \|_2^2) \Big)\]

where \(T\) is the number of frames, \(S_U\) is the set of upper-face vertices with \(M = |S_U|\), \(x_{t,v}\) are the 3D coordinates of vertex \(v\) at frame \(t\) in the ground truth sequence, and \(\hat{x}_{t,v} \in \mathbb{R}^3\) are the corresponding predicted vertices. The neutral template coordinate of vertex \(v\) is denoted as \(\text{template}_v \in \mathbb{R}^3\). The operator \(\text{std}(\cdot)\) computes the standard deviation of the temporal sequence.
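As a hand-checkable illustration of the definition, consider a toy case with two upper-face vertices \(S_U = \{a, b\}\) and \(T = 2\) frames. The numbers are invented for illustration, and \(\text{std}\) is assumed to be the population standard deviation (dividing by \(T\)), which the definition above does not specify; a sample estimator would give different values. Suppose the per-frame squared displacements from the template are \((1, 9)\) for ground truth and \((4, 4)\) for the prediction at vertex \(a\), and \((2, 2)\) for ground truth and \((0, 4)\) for the prediction at vertex \(b\). Then:

\[\text{std}(1, 9) = 4, \quad \text{std}(4, 4) = 0, \quad \text{std}(2, 2) = 0, \quad \text{std}(0, 4) = 2\]

\[\text{FDD} = \frac{1}{2} \Big( (4 - 0) + (0 - 2) \Big) = 1\]

The prediction moves too little at vertex \(a\) and too much at vertex \(b\). Because the per-vertex terms are signed, the two deviations partly cancel in the mean.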

Parameters:
  • vertices_pred (Tensor) – Predicted vertices tensor of shape (T, V, 3), where T is the number of frames, V is the number of vertices, and 3 represents XYZ coordinates.

  • vertices_gt (Tensor) – Ground truth vertices tensor of shape (T, V, 3), where T is the number of frames, V is the number of vertices, and 3 represents XYZ coordinates.

  • template (Tensor) – Template mesh tensor of shape (V, 3) representing the neutral face.

  • upper_face_map (List[int]) – List of vertex indices corresponding to the upper face region.

Returns:

Scalar tensor containing the mean FDD value across upper-face vertices.

Return type:

torch.Tensor

Raises:

ValueError – Raised in any of the following cases:

  • The number of dimensions of vertices_pred or vertices_gt is not 3.

  • template does not have shape (V, 3).

  • vertices_pred and vertices_gt do not have the same vertex and coordinate dimensions.

  • The template shape does not match the vertex and coordinate dimensions of vertices_pred (and vertices_gt).

  • upper_face_map is empty or contains invalid vertex indices.

Example

>>> import torch
>>> from torchmetrics.functional.multimodal import upper_face_dynamics_deviation
>>> vertices_pred = torch.randn(10, 100, 3, generator=torch.manual_seed(41))
>>> vertices_gt = torch.randn(10, 100, 3, generator=torch.manual_seed(42))
>>> upper_face_map = [10, 11, 12, 13, 14]
>>> template = torch.randn(100, 3, generator=torch.manual_seed(43))
>>> upper_face_dynamics_deviation(vertices_pred, vertices_gt, template, upper_face_map)
tensor(1.0385)