/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor.rerank;

import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.SimilarityInferenceRequest;
import org.opensearch.neuralsearch.processor.TextInferenceRequest;
import org.opensearch.neuralsearch.processor.rerank.RerankType;
import org.opensearch.neuralsearch.processor.rerank.RescoringRerankProcessor;
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher;
import org.opensearch.neuralsearch.stats.events.EventStatName;
import org.opensearch.neuralsearch.stats.events.EventStatsManager;

/**
 * Rerank processor of type {@link RerankType#ML_OPENSEARCH} that rescores search hits by asking
 * an ML Commons model (identified by {@link #modelId}) for similarity scores between the query
 * text and each hit's document context.
 */
public class MLOpenSearchRerankProcessor extends RescoringRerankProcessor {
    /** Configuration key under which the model id is supplied. */
    public static final String MODEL_ID_FIELD = "model_id";

    /** Id of the model used for similarity inference. */
    protected final String modelId;
    /** Accessor used to issue the similarity inference call. */
    protected final MLCommonsClientAccessor mlCommonsClientAccessor;

    /**
     * Creates the processor.
     *
     * @param description processor description, passed to the superclass
     * @param tag processor tag, passed to the superclass
     * @param ignoreFailure ignore-failure flag, passed to the superclass
     * @param modelId id of the model used for similarity inference
     * @param contextSourceFetchers fetchers that populate the reranking context, passed to the superclass
     * @param mlCommonsClientAccessor accessor used to run the inference request
     */
    public MLOpenSearchRerankProcessor(
        String description,
        String tag,
        boolean ignoreFailure,
        String modelId,
        List<ContextSourceFetcher> contextSourceFetchers,
        MLCommonsClientAccessor mlCommonsClientAccessor
    ) {
        super(RerankType.ML_OPENSEARCH, description, tag, ignoreFailure, contextSourceFetchers);
        this.modelId = modelId;
        this.mlCommonsClientAccessor = mlCommonsClientAccessor;
    }

    /**
     * Requests similarity scores for the document contexts against the query text.
     *
     * <p>Reads {@code "document_context_list"} (expected to be a list of strings) and
     * {@code "query_text"} (expected to be a string) from {@code rerankingContext}. If the document
     * context is absent or not a {@link List}, {@code listener} is failed with an
     * {@link IllegalStateException}. Otherwise a {@link SimilarityInferenceRequest} is built and
     * handed, together with {@code listener}, to
     * {@code mlCommonsClientAccessor.inferenceSimilarity}.
     *
     * @param response the search response being reranked (not read by this implementation)
     * @param rerankingContext context map populated by the context source fetchers
     * @param listener receives the scores, or the failure
     * @throws ClassCastException if a context element or the query text is not a {@code String}
     */
    @Override
    public void rescoreSearchResponse(
        SearchResponse response,
        Map<String, Object> rerankingContext,
        ActionListener<List<Float>> listener
    ) {
        EventStatsManager.increment(EventStatName.RERANK_ML_PROCESSOR_EXECUTIONS);
        Object ctxObj = rerankingContext.get("document_context_list");
        if (!(ctxObj instanceof List<?>)) {
            listener.onFailure(
                new IllegalStateException(
                    String.format(
                        Locale.ROOT,
                        "No document context found! Perhaps \"%s.%s\" is missing from the pipeline definition?",
                        "context",
                        "document_fields"
                    )
                )
            );
            return;
        }
        List<?> ctxList = (List<?>) ctxObj;
        // Checked cast per element: a non-String entry fails fast with ClassCastException, as before.
        List<String> contexts = ctxList.stream().map(String.class::cast).collect(Collectors.toList());
        String queryText = (String) rerankingContext.get("query_text");
        SimilarityInferenceRequest request = SimilarityInferenceRequest.builder()
            .modelId(this.modelId)
            .queryText(queryText)
            .inputTexts(contexts)
            .build();
        this.mlCommonsClientAccessor.inferenceSimilarity(request, listener);
    }
}

