Comprehensive testing is essential for maintaining the reliability and stability of TorchEBM. This guide outlines our testing approach and best practices.
# Run all tests
pytest

# Run specific tests (a directory or a single file)
pytest tests/unit/core/
pytest tests/unit/samplers/test_langevin.py

# Run a specific test class
pytest tests/unit/core/test_energy.py::TestGaussianEnergy

# Run a specific test method
pytest tests/unit/core/test_energy.py::TestGaussianEnergy::test_energy_computation
import pytest
import torch
from torchebm.core import GaussianEnergy


class TestGaussianEnergy:
    """Unit tests for ``GaussianEnergy`` using a standard 2-D Gaussian."""

    @pytest.fixture
    def energy_fn(self):
        """Fixture to create a standard Gaussian energy function.

        Zero mean, identity covariance, dimension 2.
        """
        return GaussianEnergy(mean=torch.zeros(2), cov=torch.eye(2))

    def test_energy_computation(self, energy_fn):
        """Test that energy is correctly computed for known inputs."""
        # At the mean the expected energy is 0. Use a tolerance-based check
        # (like the assertion below) instead of exact float equality.
        energy = energy_fn(torch.zeros(2))
        assert torch.isclose(energy, torch.tensor(0.0))

        # At x = (1, 1) the expected energy is 1.0, i.e. 0.5 * ||x||^2 for
        # this zero-mean, identity-covariance setup.
        energy = energy_fn(torch.ones(2))
        assert torch.isclose(energy, torch.tensor(1.0))
import pytest
import torch


@pytest.fixture
def device():
    """Return the default device for testing.

    Picks CUDA whenever ``torch.cuda.is_available()`` reports a usable GPU,
    and falls back to the CPU otherwise.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


@pytest.fixture
def precision():
    """Return the default precision for comparison.

    A single numeric tolerance shared by tests that compare floating-point
    results.
    """
    tolerance = 1e-5
    return tolerance
import torch


def test_with_mock(mocker):
    """Example: replace an expensive function with a mock during a test.

    Assumes the pytest-mock plugin, which provides the ``mocker`` fixture.
    """
    # Mock an expensive function. Patch the name where the code under test
    # looks it up: a module that did ``from torchebm.utils import
    # compute_expensive_function`` holds its own reference and will not see
    # this patch.
    mock_compute = mocker.patch("torchebm.utils.compute_expensive_function")
    mock_compute.return_value = torch.tensor(1.0)

    # Test code that uses the mocked function
    # ...

    # Verify the mock was called exactly once with a single tensor argument.
    # assert_called_once_with(tensor) would compare arguments via
    # Tensor.__eq__, which is element-wise and raises when a multi-element
    # result is coerced to bool, so compare the recorded call explicitly.
    mock_compute.assert_called_once()
    args, kwargs = mock_compute.call_args
    assert len(args) == 1 and not kwargs
    assert torch.equal(args[0], torch.tensor(0.0))
import torch
from hypothesis import given, strategies as st
from torchebm.core import GaussianEnergy


@given(
    x=st.lists(
        st.floats(min_value=-10, max_value=10),
        min_size=2,
        max_size=2,
    ).map(torch.tensor)
)
def test_gaussian_energy_properties(x):
    """Check invariants of a standard 2-D Gaussian energy at random points.

    ``x`` is a length-2 tensor built from two floats drawn from [-10, 10].
    """
    standard_gaussian = GaussianEnergy(mean=torch.zeros(2), cov=torch.eye(2))

    value = standard_gaussian(x)
    # Invariant 1: for the standard Gaussian the energy is never negative.
    assert value >= 0

    # Invariant 2: no sampled point has lower energy than the mean itself.
    value_at_mean = standard_gaussian(torch.zeros(2))
    assert value >= value_at_mean