package cz.seznam.euphoria.beam;

import cz.seznam.euphoria.beam.coder.PairCoder;
import cz.seznam.euphoria.beam.io.KryoCoder;
import cz.seznam.euphoria.core.client.accumulators.AccumulatorProvider;
import cz.seznam.euphoria.core.client.dataset.Dataset;
import cz.seznam.euphoria.core.client.functional.ReduceFunctor;
import cz.seznam.euphoria.core.client.functional.UnaryFunction;
import cz.seznam.euphoria.core.client.functional.UnaryFunctor;
import cz.seznam.euphoria.core.client.operator.FlatMap;
import cz.seznam.euphoria.core.client.operator.Operator;
import cz.seznam.euphoria.core.client.operator.ReduceByKey;
import cz.seznam.euphoria.core.client.operator.ReduceStateByKey;
import cz.seznam.euphoria.core.client.operator.Union;
import cz.seznam.euphoria.core.client.type.TypeAwareReduceFunctor;
import cz.seznam.euphoria.core.client.type.TypeAwareUnaryFunction;
import cz.seznam.euphoria.core.client.type.TypeAwareUnaryFunctor;
import cz.seznam.euphoria.core.client.type.TypeHint;
import cz.seznam.euphoria.core.executor.graph.DAG;
import cz.seznam.euphoria.core.util.Settings;
import cz.seznam.euphoria.shadow.com.google.common.collect.Iterables;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.joda.time.Duration;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:cz/seznam/euphoria/beam/BeamExecutorContext.class */
/**
 * Mutable state shared while an operator {@link DAG} is translated into a Beam {@link Pipeline}.
 *
 * <p>The context remembers which {@link PCollection} materializes each {@link Dataset}, resolves
 * the inputs of an operator from the outputs of its parents in the DAG, and chooses the
 * {@link Coder} assigned to each newly registered output.
 *
 * <p>Nothing in this class synchronizes access to its state.
 */
public class BeamExecutorContext {
    /** DAG currently being translated; replaceable through {@link #setTranslationDAG(DAG)}. */
    private DAG<Operator<?, ?>> dag;

    /** Materialized output of every dataset registered so far. */
    private final Map<Dataset<?>, PCollection<?>> outputs = new HashMap<>();

    private final Pipeline pipeline;
    private final Duration allowedLateness;
    private final Settings settings;
    private final AccumulatorProvider.Factory accumulatorFactory;

    /**
     * Creates a context for translating {@code dag} into {@code pipeline}.
     *
     * @param dag the operator DAG to translate
     * @param factory factory handed out by {@link #getAccumulatorFactory()}
     * @param pipeline the Beam pipeline whose coder registry is used for type hints
     * @param settings settings handed out by {@link #getSettings()}
     * @param duration allowed lateness reported for every operator
     */
    public BeamExecutorContext(DAG<Operator<?, ?>> dag, AccumulatorProvider.Factory factory, Pipeline pipeline, Settings settings, Duration duration) {
        this.dag = dag;
        this.accumulatorFactory = factory;
        this.pipeline = pipeline;
        this.settings = settings;
        this.allowedLateness = duration;
    }

    /**
     * Returns the single input of {@code operator}.
     *
     * <p>Delegates to {@code Iterables.getOnlyElement}, so it fails when the operator does not
     * have exactly one input.
     *
     * @param operator the operator whose input is requested
     * @param <IN> element type of the input
     * @return the collection produced by the operator's only parent
     * @throws IllegalArgumentException if the parent's output has not been registered
     */
    public <IN> PCollection<IN> getInput(Operator<IN, ?> operator) {
        return Iterables.getOnlyElement(getInputs(operator));
    }

    /**
     * Returns the inputs of {@code operator}: the registered outputs of its parents in the DAG,
     * in the order the DAG lists the parents.
     *
     * @param operator the operator whose inputs are requested
     * @param <IN> element type of the inputs
     * @return one collection per parent node
     * @throws IllegalArgumentException if any parent's output has not been registered
     */
    @SuppressWarnings("unchecked")
    public <IN> List<PCollection<IN>> getInputs(Operator<IN, ?> operator) {
        return this.dag.getNode(operator).getParents().stream()
            .map(node -> node.get())
            .map(parent -> {
                PCollection<?> materialized = this.outputs.get(parent.output());
                if (materialized == null) {
                    throw new IllegalArgumentException("Output missing for operator " + parent.getName());
                }
                // The map is keyed by Dataset<?>; the element type is only known to the caller.
                return (PCollection<IN>) materialized;
            })
            .collect(Collectors.toList());
    }

    /**
     * Looks up the collection registered for {@code dataset}.
     *
     * @param dataset the dataset to look up
     * @param <T> element type of the dataset
     * @return the registered collection, or an empty {@link Optional} if none was registered
     */
    @SuppressWarnings("unchecked")
    public <T> Optional<PCollection<T>> getPCollection(Dataset<T> dataset) {
        // Values are stored as PCollection<?>; setPCollection only accepts matching element types.
        return Optional.ofNullable((PCollection<T>) this.outputs.get(dataset));
    }

    /**
     * Registers {@code pCollection} as the materialization of {@code dataset}.
     *
     * <p>On first registration the collection is given the coder derived from the dataset's
     * producer. Registering the very same collection again is a no-op. The registry is only
     * modified once every check and the coder assignment have succeeded, so a failed call leaves
     * earlier registrations intact.
     *
     * @param dataset the dataset being materialized
     * @param pCollection the collection holding the dataset's elements
     * @param <T> element type of the dataset
     * @throws IllegalStateException if a different collection is already registered for the dataset
     */
    public <T> void setPCollection(Dataset<T> dataset, PCollection<T> pCollection) {
        PCollection<?> existing = this.outputs.get(dataset);
        if (existing != null) {
            if (existing != pCollection) {
                throw new IllegalStateException("Dataset(" + dataset + ") already materialized.");
            }
            return;
        }
        // Assign the coder before publishing, so a failure here does not leave a half-registered entry.
        pCollection.setCoder(getOutputCoder(dataset));
        this.outputs.put(dataset, pCollection);
    }

    /**
     * @return the Beam pipeline this context translates into
     */
    public Pipeline getPipeline() {
        return this.pipeline;
    }

    /**
     * Tells whether functions without type information are rejected instead of falling back to
     * {@link KryoCoder}. Currently always {@code false}.
     */
    boolean strongTypingEnabled() {
        return false;
    }

    /**
     * Chooses a coder for the result of {@code unaryFunction}.
     *
     * @return the registry coder for the function's type hint when it is a
     *     {@link TypeAwareUnaryFunction}, otherwise a {@link KryoCoder}
     * @throws IllegalArgumentException if the registry cannot provide a coder for the hint, or if
     *     strong typing is enabled and the function carries no type hint
     */
    public <IN, OUT> Coder<OUT> getCoder(UnaryFunction<IN, OUT> unaryFunction) {
        if (unaryFunction instanceof TypeAwareUnaryFunction) {
            return getCoder(((TypeAwareUnaryFunction) unaryFunction).getTypeHint());
        }
        return untypedCoder(unaryFunction);
    }

    /**
     * Chooses a coder for the elements emitted by {@code unaryFunctor}.
     *
     * @return the registry coder for the functor's type hint when it is a
     *     {@link TypeAwareUnaryFunctor}, otherwise a {@link KryoCoder}
     * @throws IllegalArgumentException if the registry cannot provide a coder for the hint, or if
     *     strong typing is enabled and the functor carries no type hint
     */
    <IN, OUT> Coder<OUT> getCoder(UnaryFunctor<IN, OUT> unaryFunctor) {
        if (unaryFunctor instanceof TypeAwareUnaryFunctor) {
            return getCoder(((TypeAwareUnaryFunctor) unaryFunctor).getTypeHint());
        }
        return untypedCoder(unaryFunctor);
    }

    /**
     * Chooses a coder for the elements emitted by {@code reduceFunctor}.
     *
     * @return the registry coder for the functor's type hint when it is a
     *     {@link TypeAwareReduceFunctor}, otherwise a {@link KryoCoder}
     * @throws IllegalArgumentException if the registry cannot provide a coder for the hint, or if
     *     strong typing is enabled and the functor carries no type hint
     */
    <IN, OUT> Coder<OUT> getCoder(ReduceFunctor<IN, OUT> reduceFunctor) {
        if (reduceFunctor instanceof TypeAwareReduceFunctor) {
            return getCoder(((TypeAwareReduceFunctor) reduceFunctor).getTypeHint());
        }
        return untypedCoder(reduceFunctor);
    }

    /**
     * Shared fallback for functions that carry no type hint.
     *
     * @param function the function lacking type information; only used in the error message
     * @return a fresh {@link KryoCoder}
     * @throws IllegalArgumentException if strong typing is enabled
     */
    private <T> Coder<T> untypedCoder(Object function) {
        if (strongTypingEnabled()) {
            throw new IllegalArgumentException("Missing type information for function " + function);
        }
        return new KryoCoder();
    }

    /**
     * Resolves the coder for {@code typeHint} through the pipeline's coder registry.
     *
     * @throws IllegalArgumentException if the registry cannot provide a coder; the original
     *     {@link CannotProvideCoderException} is kept as the cause
     */
    private <T> Coder<T> getCoder(TypeHint<T> typeHint) {
        try {
            return this.pipeline.getCoderRegistry().getCoder(TypeDescriptor.of(typeHint.getType()));
        } catch (CannotProvideCoderException e) {
            throw new IllegalArgumentException("Unable to provide coder for type hint.", e);
        }
    }

    /**
     * @return the accumulator provider factory given at construction
     */
    public AccumulatorProvider.Factory getAccumulatorFactory() {
        return this.accumulatorFactory;
    }

    /**
     * @return the settings given at construction
     */
    public Settings getSettings() {
        return this.settings;
    }

    /**
     * Derives the coder for {@code dataset} from the operator that produces it.
     *
     * <ul>
     *   <li>{@link FlatMap}: coder of its functor.
     *   <li>{@link Union}: output coder of its first input dataset.
     *   <li>{@link ReduceByKey}: {@link PairCoder} of the key extractor's and the reducer's coders.
     *   <li>{@link WrappedPCollectionOperator}: coder of the wrapped collection.
     *   <li>Anything else, including {@link ReduceStateByKey} and a missing producer:
     *       {@link KryoCoder}.
     * </ul>
     *
     * @throws NullPointerException if the producer is a {@link Union} without inputs
     */
    private <T> Coder<T> getOutputCoder(Dataset<T> dataset) {
        Operator<?, T> producer = dataset.getProducer();
        if (producer instanceof FlatMap) {
            return getCoder(((FlatMap) producer).getFunctor());
        }
        if (producer instanceof Union) {
            Object firstInput = Iterables.getFirst(((Union) producer).listInputs(), null);
            return getOutputCoder((Dataset) Objects.requireNonNull(firstInput, "Union operator has no inputs to derive a coder from"));
        }
        if (producer instanceof ReduceByKey) {
            ReduceByKey reduceByKey = (ReduceByKey) producer;
            return PairCoder.of(getCoder((UnaryFunction) reduceByKey.getKeyExtractor()), getCoder(reduceByKey.getReducer()));
        }
        if (producer instanceof WrappedPCollectionOperator) {
            return ((WrappedPCollectionOperator) producer).input.getCoder();
        }
        // No specific coder is derived for ReduceStateByKey, unknown operators or datasets
        // without a producer.
        return new KryoCoder();
    }

    /**
     * Returns the allowed lateness for {@code operator}. The operator is currently ignored and
     * the single value given at construction is returned.
     */
    public Duration getAllowedLateness(Operator<?, ?> operator) {
        return this.allowedLateness;
    }

    /**
     * Replaces the DAG used to resolve operator inputs. Already registered outputs are kept.
     *
     * @param dag the DAG to use from now on
     */
    public void setTranslationDAG(DAG<Operator<?, ?>> dag) {
        this.dag = dag;
    }
}
