diff --git a/src/pyrecest/distributions/hypersphere_subset/complex_angular_central_gaussian_distribution.py b/src/pyrecest/distributions/hypersphere_subset/complex_angular_central_gaussian_distribution.py index 8bdb08aca..428fff858 100644 --- a/src/pyrecest/distributions/hypersphere_subset/complex_angular_central_gaussian_distribution.py +++ b/src/pyrecest/distributions/hypersphere_subset/complex_angular_central_gaussian_distribution.py @@ -183,6 +183,14 @@ def estimate_parameter_matrix(Z, n_iterations=100): C : array-like of shape (d, d) Estimated Hermitian parameter matrix. """ + Z = array(Z) + if Z.ndim != 2: + raise ValueError("Z must be a two-dimensional array of samples") + if Z.shape[0] == 0 or Z.shape[1] == 0: + raise ValueError("Z must contain at least one sample and one dimension") + if not _to_python_bool(backend_all(isfinite(Z))): + raise ValueError("Z must contain only finite values") + N = Z.shape[0] D = Z.shape[1] C = eye(D, dtype=complex128) diff --git a/tests/distributions/test_complex_angular_central_gaussian_distribution.py b/tests/distributions/test_complex_angular_central_gaussian_distribution.py index 40e8ed1a8..6b737276a 100644 --- a/tests/distributions/test_complex_angular_central_gaussian_distribution.py +++ b/tests/distributions/test_complex_angular_central_gaussian_distribution.py @@ -227,6 +227,20 @@ def test_fit_returns_distribution(self): self.assertIsInstance(dist, ComplexAngularCentralGaussianDistribution) self.assertEqual(dist.dim, 2) + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", + reason="Not supported on JAX backend", + ) # pylint: disable=no-member + def test_fit_accepts_array_like_samples(self): + """fit() should coerce array-like sample collections before reading shape.""" + Z = [[1.0 + 0.0j, 0.0 + 0.0j], [0.0 + 0.0j, 1.0 + 0.0j]] + + dist = ComplexAngularCentralGaussianDistribution.fit(Z, n_iterations=1) + + self.assertIsInstance(dist, ComplexAngularCentralGaussianDistribution) + self.assertEqual(dist.dim, 2) + npt.assert_allclose(real(dist.C), 0.5 * eye(2), atol=1e-12) + @unittest.skipIf( pyrecest.backend.__backend_name__ == "jax", reason="Not supported on JAX backend",