//===-- CUDA.cpp -- CUDA Fortran specific lowering ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//

#include "flang/Lower/CUDA.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"

#define DEBUG_TYPE "flang-lower-cuda"

/// Build the fir.field_index coordinates addressing the component named by
/// \p sym inside \p recTy, and return the type of that component.
///
/// The component is first looked up by name directly in \p recTy; on success
/// one field index is appended to \p coordinates. Otherwise, every component
/// of \p recTy that is itself a record type is searched (one level deep only)
/// and, for the first match, two field indices are appended: the enclosing
/// component within \p recTy followed by the component within that record.
///
/// \param builder     builder used to create the fir.field_index operations.
/// \param loc         location used for created operations and diagnostics.
/// \param sym         symbol whose name identifies the component.
/// \param recTy       record type in which the lookup starts.
/// \param coordinates output vector; created field indices are appended to it.
/// \return the type of the matched component. When no component matches, a
///         TODO diagnostic is emitted at \p loc.
mlir::Type Fortran::lower::gatherDeviceComponentCoordinatesAndType(
    fir::FirOpBuilder &builder, mlir::Location loc,
    const Fortran::semantics::Symbol &sym, fir::RecordType recTy,
    llvm::SmallVector<mlir::Value> &coordinates) {
  // Value returned by getFieldIndex when the name is not a field of the type.
  constexpr unsigned notFound = std::numeric_limits<unsigned>::max();
  // Materialize the component name once; it is reused for every lookup.
  const auto componentName = sym.name().ToString();

  // Create a fir.field_index for `fieldName` in `ownerTy` and append it to
  // the coordinates.
  auto appendFieldIndex = [&](llvm::StringRef fieldName,
                              fir::RecordType ownerTy) {
    mlir::Value fieldIndex = fir::FieldIndexOp::create(
        builder, loc, fir::FieldType::get(ownerTy.getContext()), fieldName,
        ownerTy,
        /*typeParams=*/mlir::ValueRange{});
    coordinates.push_back(fieldIndex);
  };

  mlir::Type fieldTy;
  if (unsigned fieldIdx = recTy.getFieldIndex(componentName);
      fieldIdx != notFound) {
    // Field found in the base record type.
    const auto &field = recTy.getTypeList()[fieldIdx];
    fieldTy = field.second;
    appendFieldIndex(field.first, recTy);
  } else {
    // Field not found in base record type, search in potential
    // record type components. Iterate by reference to avoid copying the
    // (name, type) pair of every component.
    for (const auto &component : recTy.getTypeList()) {
      auto childRecTy = mlir::dyn_cast<fir::RecordType>(component.second);
      if (!childRecTy)
        continue;
      unsigned childFieldIdx = childRecTy.getFieldIndex(componentName);
      if (childFieldIdx == notFound)
        continue;
      // First the component of the base type that holds the nested record,
      // then the requested component inside that nested record.
      appendFieldIndex(component.first, recTy);
      const auto &field = childRecTy.getTypeList()[childFieldIdx];
      fieldTy = field.second;
      appendFieldIndex(field.first, childRecTy);
      break;
    }
  }
  // Test the lookup result itself rather than `coordinates.empty()`: the
  // caller may pass a vector that already holds values, in which case a failed
  // lookup would otherwise go unnoticed and a null type would be returned.
  if (!fieldTy)
    TODO(loc, "device resident component in complex derived-type hierarchy");
  return fieldTy;
}

/// Return the cuf data attribute corresponding to the CUDA data attribute of
/// the ultimate symbol of \p sym, as produced by cuf::getDataAttribute in
/// \p mlirContext.
cuf::DataAttributeAttr Fortran::lower::translateSymbolCUFDataAttribute(
    mlir::MLIRContext *mlirContext, const Fortran::semantics::Symbol &sym) {
  // The attribute is queried on the ultimate symbol, not on `sym` itself.
  const Fortran::semantics::Symbol &ultimate = sym.GetUltimate();
  return cuf::getDataAttribute(mlirContext,
                               Fortran::semantics::GetCUDADataAttr(&ultimate));
}

/// Return the hlfir.elemental operation producing \p rhs when its body
/// contains exactly one hlfir.designate, exactly one fir.load and exactly one
/// fir.convert; return a null operation otherwise.
///
/// Two shapes are recognized:
///  - \p rhs is the result of an hlfir.declare whose memref comes from an
///    hlfir.associate whose source is such an hlfir.elemental;
///  - \p rhs is directly the result of such an hlfir.elemental.
///
/// Values without a defining operation (e.g. block arguments) are handled and
/// yield a null operation.
hlfir::ElementalOp Fortran::lower::isTransferWithConversion(mlir::Value rhs) {
  auto isConversionElementalOp = [](hlfir::ElementalOp elOp) {
    return llvm::hasSingleElement(
               elOp.getBody()->getOps<hlfir::DesignateOp>()) &&
           llvm::hasSingleElement(elOp.getBody()->getOps<fir::LoadOp>()) &&
           llvm::hasSingleElement(elOp.getBody()->getOps<fir::ConvertOp>());
  };
  // getDefiningOp() is null for block arguments, so every cast below must
  // tolerate a null operation.
  mlir::Operation *rhsDefOp = rhs.getDefiningOp();
  if (auto declOp = mlir::dyn_cast_or_null<hlfir::DeclareOp>(rhsDefOp)) {
    if (auto associateOp = mlir::dyn_cast_or_null<hlfir::AssociateOp>(
            declOp.getMemref().getDefiningOp()))
      if (auto elOp = mlir::dyn_cast_or_null<hlfir::ElementalOp>(
              associateOp.getSource().getDefiningOp()))
        if (isConversionElementalOp(elOp))
          return elOp;
  }
  if (auto elOp = mlir::dyn_cast_or_null<hlfir::ElementalOp>(rhsDefOp))
    if (isConversionElementalOp(elOp))
      return elOp;
  return {};
}

/// Return true when \p addr is the result of an hlfir.declare whose memref is
/// produced by a fir.address_of and whose data attribute is not
/// cuf::DataAttribute::Pinned. Return false in every other case, including
/// when \p addr has no defining operation.
bool Fortran::lower::hasDoubleDescriptor(mlir::Value addr) {
  auto declOp = mlir::dyn_cast_or_null<hlfir::DeclareOp>(addr.getDefiningOp());
  if (!declOp)
    return false;

  // Only declarations made directly on a fir.address_of result qualify.
  mlir::Operation *memrefDefOp = declOp.getMemref().getDefiningOp();
  if (!mlir::isa_and_nonnull<fir::AddrOfOp>(memrefDefOp))
    return false;

  // Pinned data is explicitly excluded.
  const auto dataAttr = declOp.getDataAttr();
  const bool isPinned = dataAttr && *dataAttr == cuf::DataAttribute::Pinned;
  return !isPinned;
}
