
Modify Torch Support to handle tensors on GPU #19

Merged (1 commit) on Aug 12, 2024

Conversation

@afspies (Contributor) commented Aug 9, 2024

It seems the torch support assumed tensors live on the CPU, since that is a prerequisite for conversion to numpy arrays with `some_tensor.numpy()`. Perhaps this was an intentional design choice to avoid occupying non-accelerator memory for users who aren't being careful, but automatically moving non-CPU tensors to the CPU for visualization purposes is likely acceptable in most use cases, so I have added this.

Minor note: the `.cpu()` and `.detach()` calls can be combined into one line. Not sure which is preferable.
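For reference, a minimal sketch of the two-step conversion this comment describes (the tensor and function names here are illustrative, not taken from the PR):

```python
import torch

def to_numpy(t: torch.Tensor):
    # .numpy() raises a RuntimeError for CUDA tensors and for tensors
    # that require grad, so detach from the autograd graph and move to
    # the CPU first; the two calls chain naturally into one line.
    return t.detach().cpu().numpy()

arr = to_numpy(torch.ones(3, requires_grad=True))
```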

google-cla bot commented Aug 9, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@danieldjohnson (Collaborator) left a comment

Thanks for the PR! This was an oversight; I must not have tested this thoroughly with GPU tensors.

That said, I think it's better to avoid copying the entire array to CPU memory at the beginning, because the array may be very large and we may only need to visualize a small portion of it. Would you mind changing this to only copy to CPU right before converting to numpy? (Or, I think .numpy(force=True) accomplishes the same thing?)
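A sketch of the slice-then-convert pattern suggested here (variable names are illustrative): slicing produces a view, so no device-to-host copy happens until the small slice is converted, and `numpy(force=True)` is documented as equivalent to `t.detach().cpu().resolve_conj().resolve_neg().numpy()`, so it also covers grad-tracking tensors, non-CPU devices, and conjugated or negated views.

```python
import torch

# Slice first (a view, no copy), then convert only the needed part.
# force=True handles detaching from autograd and moving off-device,
# so this works even for tensors that plain .numpy() would reject.
t = torch.arange(12.0).reshape(3, 4).requires_grad_()
first_row = t[:1].numpy(force=True)  # only one row is materialized
```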

@afspies (Contributor, Author) commented Aug 9, 2024

Hi @danieldjohnson,

Thanks for your response. Indeed, I had only just noticed the slicing! Looking into the `force=True` flag confirms it is the better option, since it also handles some other cases (I didn't know that flag existed!).

Made the changes. I'm not sure whether this is needed in `get_array_data_with_truncation`, but it probably doesn't hurt, as torch claims not to return copies unless a sophisticated conversion is required.
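A quick check of that no-copy claim, assuming a plain CPU tensor (this snippet is illustrative, not from the PR): for such tensors `.numpy()` returns an array that shares memory with the tensor, and `force=True` only introduces a copy when an actual conversion (device move, conjugate resolution, etc.) is required.

```python
import torch

# For a CPU tensor, .numpy() shares storage with the tensor rather
# than copying; a mutation through one side is visible on the other.
t = torch.zeros(3)
n = t.numpy()
t[0] = 7.0
shared = bool(n[0] == 7.0)  # True: no copy was made
```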

@danieldjohnson (Collaborator) commented

Thanks!

One more minor thing: would you mind squashing these commits together into a single commit so that the old version doesn't appear in the commit history?

@afspies (Contributor, Author) commented Aug 10, 2024

Done :)

@@ -74,12 +74,13 @@ def _truncate_and_copy(
ignoring any axes whose slices are already computed in `source_slices`.
"""
assert torch is not None, "PyTorch is not available."

Collaborator

Actually, sorry, could you also amend to remove this line with extra whitespace? It's tripping up the internal lint system and I haven't set up the external lint checks yet.

Contributor Author

Hopefully ok now!

Collaborator

Thanks!

- Conversion to numpy is now done using `some_tensor.numpy(force=True)`
- This handles device conversion as well as special cases such as complex tensors
copybara-service bot merged commit 8465645 into google-deepmind:main on Aug 12, 2024
3 checks passed
@afspies deleted the patch-1 branch on August 13, 2024