// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "orttraining/training_ops/rocm/tensor/flatten_and_unpad.h"

#include <limits>
#include <vector>

#include "orttraining/training_ops/rocm/tensor/flatten_and_unpad_impl.h"
#include "core/providers/rocm/shared_inc/rocm_utils.h"

namespace onnxruntime {
namespace rocm {

// Registers the ROCm implementation of the FlattenAndUnpad op (kMSDomain, since version 1).
//   T     : element type dispatched on in ComputeInternal (read from input 0, written to output 0);
//           int32, int64, MLFloat16, float, double and BFloat16 are accepted.
//   T_INT : restricted to int64; ComputeInternal reads the indices input as int64_t.
// Output 1 is declared as OrtMemTypeCPUOutput, i.e. it is allocated in host memory, which is why
// ComputeInternal fills it with plain host writes instead of a device kernel.
ONNX_OPERATOR_KERNEL_EX(
    FlattenAndUnpad,
    kMSDomain,
    1,
    kRocmExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, MLFloat16, float, double, BFloat16>())
        .TypeConstraint("T_INT", DataTypeImpl::GetTensorType<int64_t>())
        .OutputMemoryType(OrtMemTypeCPUOutput, 1),
    FlattenAndUnpad);

// Keep the implementation in an anonymous namespace so it has internal linkage and cannot collide with identically named symbols in other translation units.
namespace {

template <typename T>
struct FlattenAndUnpadFunctor {
  void operator()(hipStream_t stream,
                  const int64_t output_element_count,
                  const fast_divmod output_element_stride_fdm,
                  const int64_t index_value_upper_bound,
                  const Tensor& input_tensor,
                  const Tensor& indices_tensor,
                  Tensor& output_tensor) const {
    typedef typename ToHipType<T>::MappedType HipT;
    const HipT* input_data = reinterpret_cast<const HipT*>(input_tensor.Data<T>());

    FlattenAndUnpadImpl<HipT>(stream, output_element_count, output_element_stride_fdm, index_value_upper_bound,
                               input_data, indices_tensor.Data<int64_t>(),
                               reinterpret_cast<HipT*>(output_tensor.MutableData<T>()));
  }
};

}  // namespace

// Computes FlattenAndUnpad.
//
// Inputs:
//   0: input   - tensor of rank >= 2, shape [D0, D1, D2, ..., Dk].
//   1: indices - 1-D int64 tensor of length N.
// Outputs:
//   0: output         - shape [N, D2, ..., Dk]; filled on the device by FlattenAndUnpadImpl.
//                       Presumably each index selects one row of the input viewed as
//                       [D0 * D1, D2, ..., Dk] (D0 * D1 is passed to the kernel as the index upper
//                       bound) -- confirm against flatten_and_unpad_impl.
//   1: unflatten_dims - 1-D int64 tensor {D0, D1}, written from host code (registered as a CPU output).
//
// Throws (via ORT_ENFORCE) on bad ranks, or when the product D2 * ... * Dk does not fit in the
// 32-bit int required by fast_divmod.
Status FlattenAndUnpad::ComputeInternal(OpKernelContext* context) const {
  const Tensor* input_tensor = context->Input<Tensor>(0);
  const Tensor* indices_tensor = context->Input<Tensor>(1);
  const auto& input_shape = input_tensor->Shape();
  const auto& indices_shape = indices_tensor->Shape();
  ORT_ENFORCE(input_shape.NumDimensions() >= 2,
              "input_tensor tensor must have at least 2 dimensions. Got: ", input_shape.NumDimensions());
  ORT_ENFORCE(indices_shape.NumDimensions() == 1,
              "indices_tensor tensor must be 1-D. Got rank: ", indices_shape.NumDimensions());

  // Output 0 replaces the leading two dims [D0, D1] with [N] and keeps the trailing dims.
  // element_stride is the number of elements per selected row (D2 * ... * Dk, 1 for rank-2 input).
  std::vector<int64_t> output_shape_vec;
  output_shape_vec.reserve(input_shape.NumDimensions() - 1);
  output_shape_vec.push_back(indices_shape[0]);
  int64_t element_stride = 1;
  for (size_t i = 2; i < input_shape.NumDimensions(); ++i) {
    output_shape_vec.push_back(input_shape[i]);
    element_stride *= input_shape[i];
  }

  // fast_divmod takes an int: check before narrowing so a large stride cannot silently wrap to a
  // different (but in-range) divisor.
  ORT_ENFORCE(element_stride <= static_cast<int64_t>(std::numeric_limits<int>::max()),
              "FlattenAndUnpad: product of input dims after the first two (", element_stride,
              ") exceeds the 32-bit range supported by fast_divmod.");
  const fast_divmod output_element_stride_fdm(static_cast<int>(element_stride));

  const TensorShape output_shape(output_shape_vec);
  Tensor* output_tensor = context->Output(0, output_shape);
  const int64_t output_element_count = output_shape.Size();

  // Indices are expected to address the flattened [D0 * D1] leading axis.
  const int64_t index_value_upper_bound = input_shape[0] * input_shape[1];

  // An empty output would mean launching a zero-sized grid, which is an invalid launch configuration;
  // there is nothing to copy, so skip the device work entirely.
  if (output_element_count > 0) {
    utils::MLTypeCallDispatcher<int32_t, int64_t, float, MLFloat16, double, BFloat16>
        t_disp(input_tensor->GetElementType());
    t_disp.Invoke<FlattenAndUnpadFunctor>(Stream(context),
                                          output_element_count,
                                          output_element_stride_fdm,
                                          index_value_upper_bound,
                                          *input_tensor,
                                          *indices_tensor,
                                          *output_tensor);
  }

  // Output 1 records the original leading dims {D0, D1}. It lives in host memory, so write it directly.
  constexpr int kUnflattenRank = 2;
  Tensor* unflatten_dims_tensor = context->Output(1, {kUnflattenRank});
  int64_t* unflatten_dims = unflatten_dims_tensor->MutableData<int64_t>();
  unflatten_dims[0] = input_shape[0];
  unflatten_dims[1] = input_shape[1];

  return Status::OK();
}

}  // namespace rocm
}  // namespace onnxruntime
