package ai.libs.jaicore.ml.core;

import ai.libs.jaicore.basic.FileUtil;
import ai.libs.jaicore.logging.LoggerUtil;
import ai.libs.jaicore.ml.interfaces.LabeledInstance;
import ai.libs.jaicore.ml.interfaces.LabeledInstances;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/ml/core/SimpleLabeledInstancesImpl.class */
/**
 * A list-backed dataset of {@link LabeledInstance}s carrying {@code String} labels.
 *
 * <p>All instances added through {@link #add(LabeledInstance)} must have the same number of
 * columns; that number is fixed by the first instance added. The distinct labels of instances
 * added through {@link #add(LabeledInstance)} are tracked and exposed via
 * {@link #getOccurringLabels()}.
 *
 * <p>The JSON representation read and written by this class is an object with two parallel
 * arrays: {@code "instances"} (one array of numbers per instance) and {@code "labels"} (one
 * label per instance, at the same index).
 *
 * <p>NOTE(review): only {@link #add(LabeledInstance)} is overridden, so other inherited
 * {@link ArrayList} mutators (e.g. {@code addAll}, {@code add(int, E)}, {@code set},
 * {@code remove}) bypass the column check and do not update the tracked label set.
 */
public class SimpleLabeledInstancesImpl extends ArrayList<LabeledInstance<String>> implements LabeledInstances<String> {
    private static final Logger logger = LoggerFactory.getLogger(SimpleLabeledInstancesImpl.class);

    /** Shared JSON mapper; ObjectMapper is thread-safe once configured and costly to create. */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    /** Column count shared by all instances; {@code -1} until the first instance is added. */
    private int numColumns = -1;

    /** Distinct labels of the instances added via {@link #add(LabeledInstance)}. */
    private final Set<String> occurringLabels = new HashSet<>();

    /** Creates an empty dataset. */
    public SimpleLabeledInstancesImpl() {
    }

    /**
     * Creates a dataset from its JSON text representation.
     *
     * @param str JSON text with parallel {@code "instances"} and {@code "labels"} arrays
     * @throws IOException if {@code str} cannot be parsed as JSON
     */
    public SimpleLabeledInstancesImpl(String str) throws IOException {
        addAllFromJson(str);
    }

    /**
     * Creates a dataset from an already parsed JSON tree.
     *
     * @param jsonNode JSON object with parallel {@code "instances"} and {@code "labels"} arrays
     */
    public SimpleLabeledInstancesImpl(JsonNode jsonNode) {
        addAllFromJson(jsonNode);
    }

    /**
     * Creates a dataset from a file containing its JSON text representation.
     *
     * @param file file whose content is the dataset JSON
     * @throws IOException if the file cannot be read or its content cannot be parsed
     */
    public SimpleLabeledInstancesImpl(File file) throws IOException {
        addAllFromJson(file);
    }

    /**
     * Appends an instance after checking that its column count matches the dataset's.
     *
     * @param labeledInstance the instance to append
     * @return the result of {@link ArrayList#add(Object)}
     * @throws IllegalArgumentException if the instance's column count differs from that of the
     *     instances already in the dataset
     */
    @Override
    public boolean add(LabeledInstance<String> labeledInstance) {
        int columns = labeledInstance.getNumberOfColumns();
        if (this.numColumns < 0) {
            // First instance fixes the column count for the whole dataset.
            this.numColumns = columns;
        } else if (this.numColumns != columns) {
            throw new IllegalArgumentException("Cannot add " + columns + "-valued instance to dataset with " + this.numColumns + " columns.");
        }
        this.occurringLabels.add(labeledInstance.getLabel());
        return super.add(labeledInstance);
    }

    /** @return the number of instances in this dataset */
    @Override
    public int getNumberOfRows() {
        return size();
    }

    /** @return the shared column count, or {@code -1} if no instance has been added yet */
    @Override
    public int getNumberOfColumns() {
        return this.numColumns;
    }

    /**
     * Serializes this dataset to JSON with parallel {@code "instances"} and {@code "labels"} arrays.
     *
     * @return the JSON text, or {@code null} if serialization fails (the error is logged)
     */
    @Override
    public String toJson() {
        ObjectNode root = MAPPER.createObjectNode();
        ArrayNode instancesNode = root.putArray("instances");
        ArrayNode labelsNode = root.putArray("labels");
        for (LabeledInstance<String> instance : this) {
            ArrayNode valuesNode = instancesNode.addArray();
            Iterator<?> values = instance.iterator();
            while (values.hasNext()) {
                // NOTE(review): assumes instance values are Doubles, as the original cast did.
                valuesNode.add((Double) values.next());
            }
            labelsNode.add(instance.getLabel());
        }
        try {
            return MAPPER.writeValueAsString(root);
        } catch (JsonProcessingException e) {
            logger.error(LoggerUtil.getExceptionInfo(e));
            return null;
        }
    }

    /** @return a fresh, mutable list of the distinct labels seen so far (order unspecified) */
    @Override
    public ArrayList<String> getOccurringLabels() {
        return new ArrayList<>(this.occurringLabels);
    }

    /**
     * Parses JSON text and appends all instances it describes.
     *
     * @param str JSON text with parallel {@code "instances"} and {@code "labels"} arrays
     * @throws IOException if {@code str} cannot be parsed as JSON
     */
    @Override
    public void addAllFromJson(String str) throws IOException {
        addAllFromJson(MAPPER.readTree(str));
    }

    /**
     * Appends every instance described by the given JSON tree.
     *
     * <p>Each element of {@code "instances"} is read as a sequence of doubles and paired with
     * the text of the {@code "labels"} element at the same index.
     *
     * @param jsonNode JSON object with parallel {@code "instances"} and {@code "labels"} arrays
     * @throws IllegalArgumentException if either array is missing, their sizes differ, or an
     *     instance's column count differs from the dataset's
     */
    public void addAllFromJson(JsonNode jsonNode) {
        JsonNode instancesNode = jsonNode.get("instances");
        JsonNode labelsNode = jsonNode.get("labels");
        if (labelsNode == null) {
            throw new IllegalArgumentException("No labels provided in the dataset!");
        }
        if (instancesNode == null) {
            throw new IllegalArgumentException("No instances provided in the dataset!");
        }
        if (instancesNode.size() != labelsNode.size()) {
            throw new IllegalArgumentException("Number of labels does not match the number of instances!");
        }
        int row = 0;
        for (JsonNode instanceNode : instancesNode) {
            SimpleLabeledInstanceImpl instance = new SimpleLabeledInstanceImpl();
            for (JsonNode valueNode : instanceNode) {
                instance.add(Double.valueOf(valueNode.asDouble()));
            }
            instance.setLabel(labelsNode.get(row++).asText());
            add((LabeledInstance<String>) instance);
        }
    }

    /**
     * Reads a file containing dataset JSON and appends all instances it describes.
     *
     * @param file file whose content is the dataset JSON
     * @throws IOException if the file cannot be read or its content cannot be parsed
     */
    @Override
    public void addAllFromJson(File file) throws IOException {
        addAllFromJson(FileUtil.readFileAsString(file));
    }

    @Override
    public int hashCode() {
        return (31 * ((31 * super.hashCode()) + this.numColumns)) + this.occurringLabels.hashCode();
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        // Cheap class check first; the list comparison in super.equals is O(n).
        if (obj == null || getClass() != obj.getClass() || !super.equals(obj)) {
            return false;
        }
        SimpleLabeledInstancesImpl other = (SimpleLabeledInstancesImpl) obj;
        return this.numColumns == other.numColumns && this.occurringLabels.equals(other.occurringLabels);
    }
}
