#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict


import unittest

import fbgemm_gpu

import hypothesis.strategies as st
import numpy as np
import torch
from hypothesis import given, settings, Verbosity


# Build-flavour switch: read `fbgemm_gpu.open_source` if the attribute exists,
# otherwise fall back to False.
# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)

if open_source:
    # When the flag is set, `gpu_unavailable` comes from a top-level
    # `test_utils` module.
    # pyre-ignore[21]
    from test_utils import gpu_unavailable
else:
    # Otherwise import `fbgemm_gpu.sparse_ops` (unused by name, hence the noqa),
    # take `gpu_unavailable` from the packaged test utilities, and load the
    # merge_pooled_embeddings operator library explicitly.
    import fbgemm_gpu.sparse_ops  # noqa: F401, E402
    from fbgemm_gpu.test.test_utils import gpu_unavailable

    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings")

# Re-bound with an explicit tuple type; it is star-unpacked into
# `unittest.skipIf(...)` on the test class below.
typed_gpu_unavailable: tuple[bool, str] = gpu_unavailable


def make_pitched_tensor(
    height: int,
    width: int,
    dtype: torch.dtype,
    # pyre-fixme[2]: Parameter must be annotated.
    device,
    alignment: int = 256,
) -> torch.Tensor:
    """
    Build a random ``(height, width)`` tensor whose rows are padded in memory.

    The backing storage has shape ``(height, pitch_elems)``, where the row
    size in bytes is ``width * element_size`` rounded up to the next multiple
    of ``alignment``. The returned tensor is the ``[:, :width]`` slice of that
    storage, so it is non-contiguous whenever padding was added.

    Args:
        height: Number of rows.
        width: Number of logical columns.
        dtype: Element type. Floating-point dtypes are filled with
            ``torch.randn``; integer dtypes with ``torch.randint`` in
            ``[0, 128)``, a range every integer dtype can represent.
        device: Device on which the storage is allocated.
        alignment: Row pitch alignment in bytes. ``0`` disables padding and
            yields a contiguous tensor. It is expected to be a multiple of the
            element size, since the pitch in elements is obtained by floor
            division.

    Returns:
        A tensor of shape ``(height, width)`` with row stride ``pitch_elems``.

    Raises:
        ValueError: If ``alignment`` is negative.
    """
    if alignment < 0:
        raise ValueError(f"alignment must be non-negative, got {alignment}")

    if dtype.is_floating_point:
        elem_size = torch.finfo(dtype).bits // 8
    else:
        elem_size = torch.iinfo(dtype).bits // 8

    if alignment == 0:
        # No padding requested: the storage is exactly as wide as the view.
        pitch_elems = width
    else:
        width_bytes = width * elem_size
        # Integer ceil-division; avoids a float round trip for byte counts.
        pitch_bytes = -(-width_bytes // alignment) * alignment
        pitch_elems = pitch_bytes // elem_size

    shape = (height, pitch_elems)
    if dtype.is_floating_point:
        storage = torch.randn(shape, dtype=dtype, device=device)
    else:
        # torch.randn does not support integer dtypes.
        storage = torch.randint(0, 128, shape, dtype=dtype, device=device)

    # Logical shape; keeps the storage's row stride, i.e. the pitch.
    return storage[:, :width]


# @unittest.skipIf(open_source, "Not supported in open source yet")
@unittest.skipIf(*typed_gpu_unavailable)
class MergePooledEmbeddingsTest(unittest.TestCase):
    # pyre-fixme[56]: Pyre was not able to infer the type of argument
    #  `hypothesis.strategies.integers($parameter$min_value = 1, $parameter$max_value =
    #  10)` to decorator factory `hypothesis.given`.
    @given(
        num_ads=st.integers(min_value=1, max_value=10),
        embedding_dimension=st.integers(min_value=1, max_value=32),
        ads_tables=st.integers(min_value=1, max_value=32),
        num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()),
        non_default_stream=st.booleans(),
        r=st.randoms(use_true_random=False),
        dim=st.integers(min_value=0, max_value=1),
        source_from_same_device=st.booleans(),
    )
    # Can instantiate 8 contexts which takes a long time.
    @settings(verbosity=Verbosity.verbose, max_examples=40, deadline=None)
    def test_merge(
        self,
        num_ads: int,
        embedding_dimension: int,
        ads_tables: int,
        num_gpus: int,
        non_default_stream: bool,
        # pyre-fixme[2]: Parameter must be annotated.
        r,
        dim: int,
        source_from_same_device: bool,
    ) -> None:
        """
        Check ``merge_pooled_embeddings`` on CUDA against ``torch.cat`` and
        against the operator's CPU path.

        One float16 ``(num_ads, embedding_dimension * ads_tables)`` tensor is
        created per GPU, either all on the destination GPU or one on each GPU,
        and the list is shuffled. The merge is optionally issued while a
        non-default stream is current on every GPU. The result must live on
        the destination GPU and match both references.
        """
        import contextlib

        dst_device = r.randint(0, num_gpus - 1)
        # Scope the device switch to this test so the process-wide current
        # device is restored afterwards instead of leaking into other tests.
        with torch.cuda.device(dst_device):
            ad_ds = [embedding_dimension * ads_tables for _ in range(num_gpus)]
            batch_indices = torch.zeros(num_ads).long().cuda()
            pooled_ad_embeddings = [
                (
                    torch.randn(num_ads, ad_d, dtype=torch.float16, device=dst_device)
                    if source_from_same_device
                    else torch.randn(
                        num_ads,
                        ad_d,
                        dtype=torch.float16,
                        device=torch.device(f"cuda:{i}"),
                    )
                )
                for i, ad_d in enumerate(ad_ds)
            ]
            r.shuffle(pooled_ad_embeddings)

            # Size of the dimension that is NOT concatenated: the row count
            # when joining along dim 1, the column count when joining along
            # dim 0.
            uncat_size = batch_indices.size(0) if dim == 1 else ad_ds[0]

            with contextlib.ExitStack() as stack:
                if non_default_stream:
                    # Make a fresh stream current on each participating GPU.
                    for device_index in range(num_gpus):
                        stream = torch.cuda.Stream(device=device_index)
                        stack.enter_context(torch.cuda.stream(stream))
                output = torch.ops.fbgemm.merge_pooled_embeddings(
                    pooled_ad_embeddings, uncat_size, batch_indices.device, dim
                )

            # Reference 1: plain concatenation of host copies.
            output_ref = torch.cat([p.cpu() for p in pooled_ad_embeddings], dim=dim)
            # Reference 2: the same operator run entirely on CPU.
            output_cpu = torch.ops.fbgemm.merge_pooled_embeddings(
                [pe.cpu() for pe in pooled_ad_embeddings],
                uncat_size,
                batch_indices.cpu().device,
                dim,
            )
            self.assertEqual(output.device, torch.device(f"cuda:{dst_device}"))
            torch.testing.assert_close(output_ref, output.cpu())
            torch.testing.assert_close(output_ref, output_cpu)

    # pyre-fixme[56]: Pyre was not able to infer the type of argument
    @given(
        num_inputs=st.integers(min_value=1, max_value=10),
        num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()),
        r=st.randoms(use_true_random=False),
        use_pitched=st.booleans(),
    )
    # Can instantiate 8 contexts which takes a long time.
    @settings(verbosity=Verbosity.verbose, max_examples=40, deadline=None)
    def test_all_to_one_device(
        self,
        num_inputs: int,
        num_gpus: int,
        # pyre-fixme[2]: Parameter must be annotated.
        r,
        use_pitched: bool,
    ) -> None:
        """
        Check that ``all_to_one_device`` moves every input to the destination
        GPU without changing its values.

        Inputs are ``(10, 20)`` float32 tensors distributed round-robin over
        ``num_gpus`` devices. With ``use_pitched`` the inputs are row-padded,
        non-contiguous views allocated directly on their source GPU.
        """
        dst_device = torch.device(f"cuda:{r.randint(0, num_gpus - 1)}")
        with torch.cuda.device(dst_device):
            src_devices = [
                torch.device(f"cuda:{i % num_gpus}") for i in range(num_inputs)
            ]
            if use_pitched:
                # Allocate on the source GPU itself. Building the view on CPU
                # and calling `.to()` would re-pack it densely (preserve_format
                # only keeps strides of dense tensors), so the operator would
                # never see a pitched CUDA input.
                cuda_inputs = [
                    make_pitched_tensor(
                        10, 20, torch.float32, src_device, alignment=256
                    )
                    for src_device in src_devices
                ]
                expected = [cuda_input.cpu() for cuda_input in cuda_inputs]
            else:
                expected = [torch.randn(10, 20) for _ in src_devices]
                cuda_inputs = [
                    host.to(src_device)
                    for host, src_device in zip(expected, src_devices)
                ]

            cuda_outputs = torch.ops.fbgemm.all_to_one_device(cuda_inputs, dst_device)
            # zip() below would silently ignore missing outputs.
            self.assertEqual(len(cuda_outputs), len(expected))
            for want, got in zip(expected, cuda_outputs):
                self.assertEqual(got.device, dst_device)
                torch.testing.assert_close(got.cpu(), want)

    def test_merge_pooled_embeddings_gpu_to_cpu(self) -> None:
        """
        Merging ``cuda:0`` inputs into a CPU output must give the same result
        as merging the equivalent CPU inputs.
        """
        cpu = torch.device("cpu")
        host_tensors = [torch.randn(10, 20) for _ in range(4)]
        gpu_tensors = [tensor.to("cuda:0") for tensor in host_tensors]
        uncat_size = host_tensors[0].size(1)

        merged_from_gpu = torch.ops.fbgemm.merge_pooled_embeddings(
            gpu_tensors, uncat_size, cpu, 0
        )
        merged_from_cpu = torch.ops.fbgemm.merge_pooled_embeddings(
            host_tensors, uncat_size, cpu, 0
        )

        torch.testing.assert_close(merged_from_gpu, merged_from_cpu)

    # pyre-fixme[56]: Pyre was not able to infer the type of argument
    @given(
        num_inputs=st.integers(min_value=1, max_value=8),
        num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()),
        r=st.randoms(use_true_random=False),
        target_deivce=st.integers(min_value=0, max_value=torch.cuda.device_count() - 1),
    )
    # Can instantiate 8 contexts which takes a long time.
    @settings(verbosity=Verbosity.verbose, max_examples=40, deadline=None)
    def test_merge_pooled_embeddings_gpu_to_cuda_without_index(
        self,
        # pyre-fixme[2]: Parameter must be annotated.
        num_inputs,
        # pyre-fixme[2]: Parameter must be annotated.
        num_gpus,
        # pyre-fixme[2]: Parameter must be annotated.
        r,
        # pyre-fixme[2]: Parameter must be annotated.
        target_deivce,
    ) -> None:
        """
        Merging to an index-less ``torch.device("cuda")`` while a given GPU is
        current must match merging to that GPU named explicitly.
        """
        out_device = torch.device(f"cuda:{target_deivce}")
        with torch.cuda.device(out_device):
            host_tensors = [torch.randn(10, 20) for _ in range(num_inputs)]

            # Scatter each input onto a randomly chosen GPU.
            scattered = []
            for host_tensor in host_tensors:
                src_index = r.randint(0, num_gpus - 1)
                scattered.append(host_tensor.to(f"cuda:{src_index}"))

            uncat_size = host_tensors[0].size(1)
            unindexed_target = torch.device("cuda")
            output = torch.ops.fbgemm.merge_pooled_embeddings(
                scattered, uncat_size, unindexed_target, 0
            )
            ref_output = torch.ops.fbgemm.merge_pooled_embeddings(
                scattered, uncat_size, out_device, 0
            )
        torch.testing.assert_close(output, ref_output)

    def test_merge_pooled_embeddings_cpu_with_different_target_device(self) -> None:
        """
        CPU inputs merged with a ``meta`` target device must produce a meta
        tensor rather than a CPU tensor.
        """
        uncat_size = 2
        cpu_inputs = [torch.ones(uncat_size, width) for width in (4, 8)]
        meta_target = torch.device("meta")

        output_meta = torch.ops.fbgemm.merge_pooled_embeddings(
            cpu_inputs, uncat_size, meta_target, 1
        )

        self.assertFalse(output_meta.is_cpu)
        self.assertTrue(output_meta.is_meta)

    # pyre-fixme[56]: Pyre was not able to infer the type of argument
    #  `hypothesis.strategies.integers($parameter$min_value = 1, $parameter$max_value =
    #  10)` to decorator factory `hypothesis.given`.
    @given(
        num_inputs=st.integers(min_value=1, max_value=10),
        num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()),
        r=st.randoms(use_true_random=False),
    )
    # Can instantiate 8 contexts which takes a long time.
    @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None)
    def test_sum_reduce_to_one(
        self,
        # pyre-fixme[2]: Parameter must be annotated.
        num_inputs,
        # pyre-fixme[2]: Parameter must be annotated.
        num_gpus,
        # pyre-fixme[2]: Parameter must be annotated.
        r,
    ) -> None:
        """
        ``sum_reduce_to_one`` over inputs spread round-robin across GPUs must
        land on the destination GPU and equal the element-wise sum of the
        inputs.
        """
        dst_index = r.randint(0, num_gpus - 1)
        dst_device = torch.device(f"cuda:{dst_index}")
        with torch.cuda.device(dst_device):
            host_tensors = [torch.randn(10, 20) for _ in range(num_inputs)]
            spread = [
                host_tensor.to(f"cuda:{position % num_gpus}")
                for position, host_tensor in enumerate(host_tensors)
            ]

            reduced = torch.ops.fbgemm.sum_reduce_to_one(spread, dst_device)

            self.assertEqual(reduced.device, dst_device)
            expected_sum = torch.stack(host_tensors).sum(dim=0)
            torch.testing.assert_close(reduced.cpu(), expected_sum)

    def test_merge_pooled_embeddings_meta(self) -> None:
        """
        Test that merge_pooled_embeddings works with meta tensors: the meta
        result must be a meta (not CPU) tensor with the same shape as the CPU
        result for the same inputs.
        """
        uncat_size = 2
        cat_dim = 1
        pooled_embeddings = [torch.ones(uncat_size, 4), torch.ones(uncat_size, 8)]

        # pyre-fixme[53]: Captured variable `cat_dim` is not annotated.
        # pyre-fixme[53]: Captured variable `pooled_embeddings` is not annotated.
        # pyre-fixme[53]: Captured variable `uncat_size` is not annotated.
        # pyre-fixme[3]: Return type must be annotated.
        # pyre-fixme[2]: Parameter must be annotated.
        def fbgemm_merge_pooled_embeddings(device):
            # Move the inputs to `device` and merge them there.
            pooled_embeddings_device = [
                pooled_embedding.to(device) for pooled_embedding in pooled_embeddings
            ]
            return torch.ops.fbgemm.merge_pooled_embeddings(
                pooled_embeddings_device, uncat_size, device, cat_dim
            )

        output_cpu = fbgemm_merge_pooled_embeddings(torch.device("cpu"))
        output_meta = fbgemm_merge_pooled_embeddings(torch.device("meta"))

        self.assertFalse(output_meta.is_cpu)
        self.assertTrue(output_meta.is_meta)

        # Use a TestCase assertion: a bare `assert` is stripped under `-O`
        # and reports no values on failure.
        self.assertEqual(output_meta.shape, output_cpu.shape)

    def test_merge_pooled_embeddings_empty_input_tensors(self) -> None:
        """
        Merging two zero-width int32 inputs must yield an empty int32 output,
        on CPU and then on CUDA.
        """
        uncat_size = 2
        scenarios = (
            (torch.device("cpu"), False),
            (torch.device("cuda"), True),
        )
        for target_device, move_to_gpu in scenarios:
            pooled_embeddings = []
            for _ in range(2):
                empty = torch.ones(uncat_size, 0, dtype=torch.int32)
                pooled_embeddings.append(empty.cuda() if move_to_gpu else empty)

            output = torch.ops.fbgemm.merge_pooled_embeddings(
                pooled_embeddings,
                uncat_size,
                target_device,
                1,
            )
            self.assertEqual(output.numel(), 0)
            self.assertEqual(output.dtype, torch.int32)


if __name__ == "__main__":
    unittest.main()
