// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/providers/rocm/rocm_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace rocm {

using namespace onnxruntime::rocm;

// ROCm kernel for the ImageScaler contrib operator, templated on the element
// type T.
//
// The operator attributes (a scalar scale and a list of bias values) are held
// as members, together with a device-side copy of the bias. NOTE(review): per
// the ONNX ImageScaler definition, this presumably computes
// `scale * input + bias[c]` for each channel c — confirm against
// ComputeInternal in the corresponding .cpp.
template <typename T>
class ImageScaler final : public RocmKernel {
 public:
  // Builds the kernel from its node info. It is explicit so that an
  // OpKernelInfo never implicitly converts into a kernel instance.
  explicit ImageScaler(const OpKernelInfo& info);

  // Executes the kernel for one invocation and reports the outcome as a Status.
  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  float scale_{0.0f};                  // "scale" attribute; zero until the constructor assigns it
  std::vector<float> bias_;            // bias values, host-side copy
  IAllocatorUniquePtr<float> b_data_;  // gpu copy of bias
};

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
