
Add conditional sampling#68

Open
anahitamansouri wants to merge 6 commits into dwavesystems:main from anahitamansouri:feature/conditional-sampling

Conversation

@anahitamansouri
Collaborator

This PR adds:

  1. Conditional sampling feature for block spin sampling.
  2. BipartiteSampler for sampling bipartite GRBMs.
  3. An example of using the BipartiteSampler.
  4. Tests for the new functionalities.
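
To make the conditional-sampling convention concrete, here is a minimal, self-contained sketch (not the PR's actual sampler API) of how partially observed inputs appear to work based on the review discussion: observed spins are ±1, unobserved spins are `torch.nan`, and a clamp mask is derived from the NaN pattern.

```python
import torch

# Hypothetical illustration of the NaN convention used for conditional
# sampling: observed spins are +/-1, unobserved spins are NaN.
num_chains, n_nodes = 2, 4

# Chain 0 clamps the two hidden spins; chain 1 clamps the two visible spins.
x = torch.tensor([[float("nan"), float("nan"), 1.0, -1.0],
                  [1.0, -1.0, float("nan"), float("nan")]])

# True marks clamped (observed) variables, False marks variables to sample.
clamp_mask = ~torch.isnan(x)

# Current sampler state (here: all ones); clamped entries are overwritten
# with the observed values, unclamped entries keep the sampler's state.
state = torch.ones(num_chains, n_nodes)
state = torch.where(clamp_mask, x, state)
```

The unclamped (NaN) entries would then be resampled by the Gibbs updates while the clamped entries keep their observed values.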

@anahitamansouri anahitamansouri self-assigned this Mar 9, 2026
@anahitamansouri anahitamansouri added the enhancement New feature or request label Mar 9, 2026
@anahitamansouri anahitamansouri marked this pull request as ready for review March 9, 2026 21:18
Collaborator

@kevinchern kevinchern left a comment

Excellent!!!! Did a quick first pass with minor requests. The implementation is clean and efficient, the documentation is well written, and the tests are thorough.
The DimodSampler implementation is missing, but let's add that in a separate PR.

Comment thread dwave/plugins/torch/samplers/bipartite_sampler.py Outdated
Comment thread dwave/plugins/torch/samplers/bipartite_sampler.py Outdated
Comment thread dwave/plugins/torch/samplers/bipartite_sampler.py Outdated
Comment thread dwave/plugins/torch/samplers/bipartite_sampler.py Outdated
Comment thread dwave/plugins/torch/samplers/bipartite_sampler.py Outdated
Comment thread tests/test_bipartite_sampler.py
Comment thread dwave/plugins/torch/samplers/bipartite_sampler.py Outdated
Comment thread tests/test_bipartite_sampler.py Outdated
Comment thread tests/test_bipartite_sampler.py Outdated
Comment thread tests/test_bipartite_sampler.py
Collaborator

@kevinchern kevinchern left a comment

Added a couple of minor requests

Comment thread dwave/plugins/torch/samplers/bipartite_sampler.py
torch.Tensor: A tensor of shape (num_chains, n_nodes) of +/-1 values sampled from the model.
"""
if x is not None:
mask = self._validate_input_and_generate_mask(x)
Collaborator

Suggested change
mask = self._validate_input_and_generate_mask(x)
self._validate_input(x)
mask = ~torch.isnan(x)

h = self._grbm.hidden_idx
self._x[:, h] = torch.where(mask[:, h], x[:, h], self._x[:, h])

def _validate_input_and_generate_mask(self, x: torch.Tensor) -> torch.Tensor:
Collaborator

Suggested change
def _validate_input_and_generate_mask(self, x: torch.Tensor) -> torch.Tensor:
def _validate_input(self, x: torch.Tensor) -> None:
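
Taken together, the suggested diffs above separate validation from mask generation. A standalone sketch of that pattern (the helper below is illustrative, not the PR's exact code):

```python
import torch

def validate_input(x: torch.Tensor, num_chains: int, n_nodes: int) -> None:
    """Validation-only helper per the suggestion: it raises on bad input
    and returns nothing; the mask is built separately by the caller."""
    if x.shape != (num_chains, n_nodes):
        raise ValueError(
            f"Input must match the sampler state shape ({num_chains}, {n_nodes})."
        )
    observed = ~torch.isnan(x)
    # Every non-NaN entry must be a valid spin value.
    if not torch.isin(x[observed], torch.tensor([-1.0, 1.0])).all():
        raise ValueError("Input may contain only +/-1 or NaN values.")

x = torch.tensor([[1.0, -1.0, float("nan"), float("nan")]])
validate_input(x, num_chains=1, n_nodes=4)
mask = ~torch.isnan(x)  # mask generation now lives at the call site
```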

self._x[:, h] = torch.where(mask[:, h], x[:, h], self._x[:, h])

def _validate_input_and_generate_mask(self, x: torch.Tensor) -> torch.Tensor:
"""Validate conditional sampling input and construct a boolean mask.
Collaborator

Suggested change
"""Validate conditional sampling input and construct a boolean mask.
"""Validate conditional sampling input.

Comment on lines +281 to +285

Returns:
torch.Tensor: Boolean mask of shape ``(num_chains, n_nodes)`` where
``True`` indicates clamped variables (observed in ``x``) and
``False`` indicates variables that should be sampled (``NaN`` in x).
Collaborator

Suggested change
Returns:
torch.Tensor: Boolean mask of shape ``(num_chains, n_nodes)`` where
``True`` indicates clamped variables (observed in ``x``) and
``False`` indicates variables that should be sampled (``NaN`` in x).

"The input must be unclamped for visible or hidden but not both."
)

return mask
Collaborator

Suggested change
return mask


Args:
x (torch.Tensor): A tensor of shape (``num_chains``, ``dim``) or (``num_chains``, ``n_nodes``)
interpreted as a batch of partially-observed spins. Entries marked with ``torch.nan`` will
Collaborator

Suggested change
interpreted as a batch of partially-observed spins. Entries marked with ``torch.nan`` will
interpreted as a batch of partially observed spins. Entries marked with ``torch.nan`` will

if mask is not None:
self._x[:, block] = torch.where(mask[:, block], x[:, block], self._x[:, block])

def _validate_input_and_generate_mask(self, x: torch.Tensor) -> torch.Tensor:
Collaborator

Same suggestion here as in bipartite sampler (docstring, type hints, returns, and defining mask outside)

Comment thread dwave/plugins/torch/samplers/block_spin_sampler.py Outdated
Contributor

@thisac thisac left a comment

I recall we talked about this, but it seems like BipartiteGibbsSampler and BlockSampler could share a lot of methods and do with some deduplication. If they're not general enough to fit into TorchSampler, there should either be a hierarchy between them or another common class that they inherit from, or, especially if you foresee some of these methods being used in other samplers, you could create one (or several) mixin classes.
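
A sketch of the mixin option mentioned above (all names here are illustrative, not from the PR):

```python
import torch

class ConditionalSamplingMixin:
    """Hypothetical mixin holding conditional-sampling logic shared by
    the two samplers; assumes the host class defines self._x."""

    def _validate_input(self, x: torch.Tensor) -> None:
        if x.shape != self._x.shape:
            raise ValueError("Input must match the sampler state shape.")

    def _clamp(self, x: torch.Tensor, mask: torch.Tensor) -> None:
        # Overwrite clamped entries of the chain state with observed values.
        self._x = torch.where(mask, x, self._x)

class ToySampler(ConditionalSamplingMixin):
    def __init__(self, num_chains: int, n_nodes: int):
        self._x = torch.ones(num_chains, n_nodes)

sampler = ToySampler(num_chains=1, n_nodes=4)
x = torch.tensor([[-1.0, -1.0, float("nan"), float("nan")]])
sampler._validate_input(x)
sampler._clamp(x, ~torch.isnan(x))
```

Both concrete samplers would then inherit the mixin alongside their existing base class, keeping the shared plumbing in one place.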

Comment thread dwave/plugins/torch/samplers/bipartite_sampler.py
Comment thread dwave/plugins/torch/samplers/bipartite_sampler.py
Comment thread dwave/plugins/torch/samplers/bipartite_sampler.py
Comment thread dwave/plugins/torch/samplers/bipartite_sampler.py
Comment thread dwave/plugins/torch/samplers/bipartite_sampler.py
Collaborator

@kevinchern kevinchern left a comment

Almost there!

Comment thread tests/test_block_sampler.py Outdated
grbm = GRBM(nodes, edges, hidden_nodes=["h1", "h2"])

def crayon(n):
return 0 if n in ["v1", "v2"] else 1
Collaborator

Suggested change
return 0 if n in ["v1", "v2"] else 1
return n in ["v1", "v2"]

Comment thread tests/test_block_sampler.py
Comment thread releasenotes/notes/conditional-sampling-94c0d541df3e3fa8.yaml
if mask is not None:
self._x[:, block] = torch.where(mask[:, block], x[:, block], self._x[:, block])

def _validate_input_and_generate_mask(self, x: torch.Tensor) -> torch.Tensor:
Collaborator

Bumping this suggestion to separate validation and mask generation

Comment thread tests/test_block_sampler.py Outdated
Comment on lines +289 to +295
self.assertEqual(mask.shape, x_valid.shape)

# Chain 0: visible unclamped
self.assertTrue(mask[0, 2:].all()) # First chain: hidden spins are clamped

# Chain 1: hidden unclamped
self.assertTrue(mask[1, :2].all()) # Second chain: visible spins are clamped
Collaborator

IF we keep the signature of validate_and_generate..., THEN can we combine these tests into one where the mask is hard-coded, like expected_mask = torch.tensor([[False, ...], [...]])?

e.g., torch.testing.assertEqual(mask, expected_mask) or self.assertListEqual(mask.tolist(), expected_mask.tolist())
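
A self-contained sketch of the hard-coded-mask comparison being suggested (note that the stable torch.testing comparison helper is assert_close; plain tolist() comparison also works):

```python
import torch

# Hard-coded expectation for two chains over (v1, v2, h1, h2):
# chain 0 clamps the hidden spins, chain 1 clamps the visible spins.
x = torch.tensor([[float("nan"), float("nan"), 1.0, -1.0],
                  [1.0, -1.0, float("nan"), float("nan")]])
mask = ~torch.isnan(x)

expected_mask = torch.tensor([[False, False, True, True],
                              [True, True, False, False]])

# One combined check instead of per-slice assertions:
assert mask.tolist() == expected_mask.tolist()
```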

Comment thread tests/test_bipartite_sampler.py Outdated
# Gibbs update for hidden block (block=1)
with self.subTest("hidden block Gibbs update"):
sampler._gibbs_update(0.0, hidden_block, ones*zero_field)
torch.testing.assert_close(torch.tensor(0.0), sampler._x.mean(), atol=1e-2, rtol=1e-2)
Collaborator

Why does this one have a looser tolerance 1e-2 than the previous 1e-3?

Collaborator Author

@anahitamansouri anahitamansouri Mar 24, 2026

Yeah, I noticed that there are fewer random variables in this test than in the one above, so the estimate has higher variance and I needed a looser tolerance (1e-2). I could avoid this by setting sampler._x.data[:] = 1.0, just like the earlier example. I can update the test if you'd like.

Collaborator

I might be missing something---don't both tests use sampler._x.mean() so the sample size should be the same(?)

Comment thread tests/test_bipartite_sampler.py Outdated
def test_sample_conditional(self):
nodes = ["v1", "v2", "h1", "h2"]
edges = [["v1", "h1"], ["v1", "h2"], ["v2", "h1"], ["v2", "h2"]]
grbm = GRBM(nodes, edges, hidden_nodes=["h1", "h2"])
Collaborator

consider setting the linear fields to be very large so the result is ~= deterministic.
Then, in the test cases, hard-code the expected results per conditional sampling step.

e.g.,
grbm.linear.data[:] = 99999999999999
and sampler._x.data[:] = 1
in one conditional step, everything but the clamped-states should become -1.
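
A minimal sketch of why a huge linear field makes the update deterministic (the sign convention below is an assumption; the PR's energy convention may flip which spin value saturates):

```python
import torch

torch.manual_seed(0)
beta = 1.0
field = torch.full((4, 4), 1e9)  # huge effective field

# A standard Gibbs update for +/-1 spins samples s = +1 with probability
# sigmoid(2 * beta * field). With an enormous field this saturates at 1.0,
# so the update is effectively deterministic and results can be hard-coded.
p_plus = torch.sigmoid(2 * beta * field)
spins = torch.where(torch.rand(4, 4) < p_plus,
                    torch.ones(4, 4), -torch.ones(4, 4))
```

In a conditional step, only the clamped entries would escape this saturation, which is what makes the expected output easy to write down.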

Comment thread tests/test_bipartite_sampler.py
Comment thread dwave/plugins/torch/samplers/bipartite_sampler.py
Comment thread dwave/plugins/torch/samplers/bipartite_sampler.py
@kevinchern kevinchern self-requested a review March 26, 2026 18:40
Add conditional sampling functionality for the ``BlockSampler``.
- |
Add ``.clone()`` to the return of ``BlockSampler.sample`` to prevent
unintended in-place modification of the sampler’s internal state due to
Collaborator

Suggested change
unintended in-place modification of the sampler’s internal state due to
unintended in-place modification of the sampler's internal state due to

@anahitamansouri anahitamansouri requested a review from thisac April 2, 2026 22:46
Contributor

@thisac thisac left a comment

Just a couple of minor comments.

from dwave.plugins.torch.models.boltzmann_machine import (
GraphRestrictedBoltzmannMachine as GRBM,
)
from torch._prims_common import DeviceLikeType
Contributor

Better to avoid importing from a "private" module; either declare the type alias here directly (DeviceLikeType: TypeAlias = str | torch.device | int), or just use the combined type or *args, **kwargs in the signature, as PyTorch does for the to() method.

Collaborator

@kevinchern kevinchern Apr 14, 2026

have the type alias declared here directly

@thisac what does this mean 🤔? Is this something you declare at import?

Contributor

You could just add

DeviceLikeType: TypeAlias = str | torch.device | int

to the top of the file (or any file, like utils.py or base.py, and import it from there).

Comment on lines +169 to +170
"""
Computes the effective field for all vertices in ``block``.
Contributor

Suggested change
"""
Computes the effective field for all vertices in ``block``.
"""Computes the effective field for all vertices in ``block``.

Comment on lines +238 to +240
mask (torch.Tensor, optional): Boolean tensor of shape
``(num_chains, n_nodes)`` indicating which variables are clamped.
Entries set to ``True`` will keep their values during sampling.
Contributor

Could this be named something clarifying, like clamp_mask, instead of generically mask?

Comment on lines +282 to +285
ValueError: If ``x`` does not match the sampler state shape
``(num_chains, n_nodes)``, contains values other than ``±1``
or ``NaN``, or if both visible and hidden variables are
simultaneously unclamped within the same chain.
Contributor

Raises should be indented similarly to Args (not Returns).

Suggested change
ValueError: If ``x`` does not match the sampler state shape
``(num_chains, n_nodes)``, contains values other than ``±1``
or ``NaN``, or if both visible and hidden variables are
simultaneously unclamped within the same chain.
ValueError: If ``x`` does not match the sampler state shape
``(num_chains, n_nodes)``, contains values other than ``±1``
or ``NaN``, or if both visible and hidden variables are
simultaneously unclamped within the same chain.

mask = None
for beta in self._schedule:
self._step(beta, mask=mask, x=x)
return self._x.clone() No newline at end of file
Contributor

Final empty line missing.

Suggested change
return self._x.clone()
return self._x.clone()

Collaborator

I never bothered to google why this convention is adopted and simply trusted my autoformatter 😆. Just googled it and thought I'd share a few here (omitted a couple):
[screenshots omitted]

Collaborator Author

@kevinchern what autoformatter are you using?

Contributor

@anahitamansouri I'd recommend using Black with a line-length set to 100 (black -l 100). Just make sure to only touch files and lines that you've added.

Collaborator Author

@anahitamansouri anahitamansouri Apr 15, 2026

Yeah, I used it on our code in a PR and noticed it was changing some lines that weren't mine, so I figured people weren't using Black here and stopped using it :) I'll try it again with 100.

Contributor

It's the recommended tool, but we don't enforce it, so there's a fair bit of code that isn't formatted accordingly.

Collaborator

@kevinchern kevinchern Apr 15, 2026

Sorry, I missed the notification @anahitamansouri. I use autopep8 with some customization, following dwave's contributor guidelines.

Comment on lines +334 to +337
else:
mask = None
for beta in self._schedule:
self._step(beta, mask=mask, x=x)
Contributor

IMO makes it much more legible to separate the if-else from for and return with empty lines.

Suggested change
else:
mask = None
for beta in self._schedule:
self._step(beta, mask=mask, x=x)
else:
mask = None
for beta in self._schedule:
self._step(beta, mask=mask, x=x)


Args:
beta (torch.Tensor): Inverse temperature to sample at.
mask (torch.Tensor, optional): Boolean tensor of shape
Contributor

Same comment here re renaming.

Comment on lines +387 to +390
mask = None
for beta in self._schedule:
self._step(beta)
return self._x
self._step(beta, mask, x)
return self._x.clone() No newline at end of file
Contributor

Same comment here also about empty lines between if-else and for.

Comment on lines +8 to +11
- |
Add ``.clone()`` to the return of ``BlockSampler.sample`` to prevent
unintended in-place modification of the sampler's internal state due to
returning a reference to the underlying tensor.
Contributor

Less a feature and more a fix or upgrade, no?

Collaborator Author

Actually, I thought it was a mix of both: BipartiteSampler is a feature and conditional sampling is an upgrade. How would you characterize a feature? :)

Collaborator Author

So do you suggest moving both under upgrade?

Contributor

Ah, sorry if I was unclear. I only meant moving the last bullet. The other two are fine.

Collaborator Author

Ah sorry. This makes sense :)

Comment on lines +49 to +63
def test_prepare_initial_states(self):
nodes = ["v1", "v2", "h1", "h2"]
edges = [["v1", "h1"], ["v1", "h2"], ["v2", "h1"], ["v2", "h2"]]
grbm = GRBM(nodes, edges, hidden_nodes=["h1", "h2"])

sampler = BipartiteGibbsSampler(grbm, num_chains=2, schedule=[1.0],)
# Invalid spins
with self.subTest("Non-spin initial states."):
self.assertRaisesRegex(ValueError, "contain nonspin values", sampler._prepare_initial_states,
initial_states=torch.tensor([[0, 1, -1, 1]]), num_chains=1)

# Incorrect shape
with self.subTest("Testing initial states with incorrect shape."):
self.assertRaisesRegex(ValueError, "Initial states should be of shape", sampler._prepare_initial_states,
num_chains=2, initial_states=torch.tensor([[-1, 1, 1, 1, -1]]))
Contributor

Only tests exceptions raised. Perhaps rename this test_prepare_initial_states_exceptions and have another test_prepare_initial_states with a test for valid arguments.
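
A sketch of the proposed split: one test for the happy path and one for the exceptions. The helper below is a stand-in mirroring the validation messages visible in the diff, not the sampler's actual method.

```python
import unittest

import torch

def prepare_initial_states(initial_states: torch.Tensor, num_chains: int) -> torch.Tensor:
    """Stand-in for the sampler method (illustrative only)."""
    if initial_states.shape[0] != num_chains:
        raise ValueError("Initial states should be of shape (num_chains, n_nodes).")
    if not torch.isin(initial_states, torch.tensor([-1, 1])).all():
        raise ValueError("Initial states contain nonspin values.")
    return initial_states.clone()

class TestPrepareInitialStates(unittest.TestCase):
    def test_prepare_initial_states(self):
        # Valid arguments: states come back unchanged (as a copy).
        states = torch.tensor([[-1, 1, 1, -1]])
        out = prepare_initial_states(states, num_chains=1)
        self.assertTrue(torch.equal(out, states))

    def test_prepare_initial_states_exceptions(self):
        with self.assertRaisesRegex(ValueError, "nonspin"):
            prepare_initial_states(torch.tensor([[0, 1, -1, 1]]), num_chains=1)
        with self.assertRaisesRegex(ValueError, "shape"):
            prepare_initial_states(torch.tensor([[-1, 1, 1, 1, -1]]), num_chains=2)
```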
