// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/shared_library/provider_api.h"
#include "core/providers/rocm/rocm_kernel.h"

namespace onnxruntime {
namespace rocm {

// Kernel class for the Expand operator on the ROCm execution provider.
// NOTE(review): presumably implements the ONNX `Expand` op; kernel
// registration is not visible in this header — confirm.
class Expand final : public RocmKernel {
 public:
  // Marked explicit so an OpKernelInfo can never be silently converted into
  // an Expand kernel; construction only forwards `info` to RocmKernel.
  explicit Expand(const OpKernelInfo& info) : RocmKernel(info) {}

  // Per-invocation entry point; overrides the RocmKernel hook and is defined
  // out of line.
  Status ComputeInternal(OpKernelContext* context) const override;
};

// Expand helper that writes into a caller-provided output tensor.
//
// @param rocm_kernel        Kernel on whose behalf the work is done.
// @param ctx                Kernel context of the current invocation.
// @param input_data_tensor  Tensor whose data is expanded.
// @param input_shape_tensor Target-shape tensor. NOTE(review): an earlier
//                           revision commented out this name, which suggests
//                           the definition may ignore it and rely on
//                           output_tensor's shape instead — confirm.
// @param output_tensor      Destination tensor, written by the callee.
// @return Status of the operation.
Status FuncExpand(
    const RocmKernel* rocm_kernel,
    OpKernelContext* ctx,
    const Tensor* input_data_tensor,
    const Tensor* input_shape_tensor,
    Tensor* output_tensor);

// Expand helper that returns its result as a newly owned tensor.
//
// @param rocm_kernel        Kernel on whose behalf the work is done.
// @param ctx                Kernel context of the current invocation.
// @param input_data_tensor  Tensor whose data is expanded.
// @param input_shape_tensor Tensor describing the target shape
//                           (presumably a 1-D shape tensor — confirm).
// @return Owning pointer to the produced tensor. Marked [[nodiscard]]
//         because dropping the result discards all the work performed.
[[nodiscard]] std::unique_ptr<Tensor> FuncExpand(
    const RocmKernel* rocm_kernel,
    OpKernelContext* ctx,
    const Tensor* input_data_tensor,
    const Tensor* input_shape_tensor);

}  // namespace rocm
}  // namespace onnxruntime
