package net.amygdalum.testrecorder.util;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import net.amygdalum.testrecorder.ByteCode;
import net.amygdalum.testrecorder.SnapshotInput;
import net.amygdalum.testrecorder.SnapshotOutput;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.InsnList;
import org.objectweb.asm.tree.InsnNode;
import org.objectweb.asm.tree.JumpInsnNode;
import org.objectweb.asm.tree.LabelNode;
import org.objectweb.asm.tree.LdcInsnNode;
import org.objectweb.asm.tree.LocalVariableNode;
import org.objectweb.asm.tree.MethodInsnNode;
import org.objectweb.asm.tree.MethodNode;
import org.objectweb.asm.tree.TypeInsnNode;

/* loaded from: input_file:net/amygdalum/testrecorder/util/IORecorderClassLoader.class */
/**
 * Class loader that rewrites selected classes so that methods annotated with {@link SnapshotInput}
 * obtain their result from an {@link InputProvider}, and methods annotated with
 * {@link SnapshotOutput} report their arguments to an {@link OutputListener} before running.
 *
 * <p>Loading rules, in order (see {@link #loadClass(String)}):
 * <ol>
 * <li>classes whose name starts with the root class name are resolved by this loader itself,</li>
 * <li>classes already instrumented by this loader are returned as loaded,</li>
 * <li>classes not listed in {@code classes} are delegated to {@code super.loadClass},</li>
 * <li>classes listed in {@code classes} are instrumented and defined by this loader.</li>
 * </ol>
 */
public class IORecorderClassLoader extends AbstractInstrumentedClassLoader {
    private static final String Class_name = Type.getInternalName(Class.class);
    private static final String IORecorderClassLoader_name = Type.getInternalName(IORecorderClassLoader.class);
    private static final String InputProvider_name = Type.getInternalName(InputProvider.class);
    private static final String OutputListener_name = Type.getInternalName(OutputListener.class);
    private static final String SnapshotInput_descriptor = Type.getDescriptor(SnapshotInput.class);
    private static final String SnapshotOutput_descriptor = Type.getDescriptor(SnapshotOutput.class);
    private static final String Class_getClassLoader_descriptor = ByteCode.methodDescriptor(Class.class, "getClassLoader", new Class[0]);
    private static final String IORecorderClassLoader_getOut_descriptor = ByteCode.methodDescriptor(IORecorderClassLoader.class, "getOut", new Class[0]);
    private static final String IORecorderClassLoader_getIn_descriptor = ByteCode.methodDescriptor(IORecorderClassLoader.class, "getIn", new Class[0]);
    private static final String InputProvider_requestInput_descriptor = ByteCode.methodDescriptor(InputProvider.class, "requestInput", Class.class, String.class, Object[].class);
    private static final String OutputListener_notifyOutput_descriptor = ByteCode.methodDescriptor(OutputListener.class, "notifyOutput", Class.class, String.class, Object[].class);

    private final InputProvider in;
    private final OutputListener out;
    private final String root;
    private final Set<String> classes;

    /**
     * Creates a loader whose parent is the class loader of {@code rootClass}.
     *
     * @param rootClass the root class; every class whose name starts with its name is resolved
     *     by this loader instead of being delegated
     * @param inputProvider provider queried by instrumented {@code @SnapshotInput} methods
     *     (instrumented code falls back to a default value when this is {@code null})
     * @param outputListener listener notified by instrumented {@code @SnapshotOutput} methods
     *     (instrumented code skips the notification when this is {@code null})
     * @param classes binary names of the classes to instrument when loaded
     */
    public IORecorderClassLoader(Class<?> rootClass, InputProvider inputProvider, OutputListener outputListener, Set<String> classes) {
        super(rootClass.getClassLoader());
        this.root = rootClass.getName();
        this.in = inputProvider;
        this.out = outputListener;
        // must be assigned before adoptInstrumentations, which reads it
        this.classes = classes;
        adoptInstrumentations(rootClass.getClassLoader());
    }

    /**
     * Defines in this loader the class files that {@code classLoader} has already instrumented,
     * except for classes this loader will instrument itself and classes it has already loaded.
     */
    private void adoptInstrumentations(ClassLoader classLoader) {
        if (!(classLoader instanceof ClassInstrumenting)) {
            return;
        }
        for (Map.Entry<String, byte[]> entry : ((ClassInstrumenting) classLoader).getInstrumentations().entrySet()) {
            String name = entry.getKey();
            if (!classes.contains(name) && findLoadedClass(name) == null) {
                define(name, entry.getValue());
            }
        }
    }

    /** Returns the input provider read by instrumented code; may be {@code null}. */
    public InputProvider getIn() {
        return in;
    }

    /** Returns the output listener notified by instrumented code; may be {@code null}. */
    public OutputListener getOut() {
        return out;
    }

    /**
     * Loads {@code name} following the rules described on this class.
     *
     * @throws ClassNotFoundException if delegation fails, or if reading, instrumenting or
     *     defining a class from {@code classes} throws anything (the cause is preserved)
     */
    @Override
    public Class<?> loadClass(String name) throws ClassNotFoundException {
        if (name.startsWith(root)) {
            Class<?> loaded = findLoadedClass(name);
            return loaded != null ? loaded : findClass(name);
        }
        if (isInstrumented(name)) {
            // NOTE(review): assumes isInstrumented implies the class is already defined here;
            // otherwise this returns null instead of throwing. Confirm in AbstractInstrumentedClassLoader.
            return findLoadedClass(name);
        }
        if (!classes.contains(name)) {
            return super.loadClass(name);
        }
        try {
            return define(name, instrument(name));
        } catch (Throwable t) {
            // any failure while reading/rewriting/defining is reported as a load failure
            throw new ClassNotFoundException(t.getMessage(), t);
        }
    }

    /**
     * Reads the class file of {@code name} through ASM's {@code ClassReader(String)} and
     * instruments it.
     *
     * @param name binary class name
     * @return the instrumented class file bytes
     * @throws IOException if the class file cannot be read
     */
    public byte[] instrument(String name) throws IOException {
        return instrument(new ClassReader(name));
    }

    /**
     * Rewrites all {@code @SnapshotInput} and {@code @SnapshotOutput} methods of the class
     * provided by {@code classReader}.
     *
     * @return the rewritten class file, with max stack/locals and stack map frames recomputed
     */
    public byte[] instrument(ClassReader classReader) {
        ClassNode classNode = new ClassNode();
        classReader.accept(classNode, 0);
        instrumentInputMethods(classNode);
        instrumentOutputMethods(classNode);
        ClassWriter classWriter = new ClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
        classNode.accept(classWriter);
        return classWriter.toByteArray();
    }

    /** Replaces the whole body of every {@code @SnapshotInput} method with a call to the input provider. */
    private void instrumentInputMethods(ClassNode classNode) {
        for (MethodNode methodNode : methodsAnnotatedWith(classNode, SnapshotInput_descriptor)) {
            methodNode.instructions.clear();
            // The original body is gone, so its exception handlers point at labels that are no
            // longer part of the method. If they were kept, frame computation in ClassWriter
            // would fail.
            if (methodNode.tryCatchBlocks != null) {
                methodNode.tryCatchBlocks.clear();
            }
            methodNode.instructions.add(readInput(classNode, methodNode));
        }
    }

    /** Prepends an output notification to the body of every {@code @SnapshotOutput} method. */
    private void instrumentOutputMethods(ClassNode classNode) {
        for (MethodNode methodNode : methodsAnnotatedWith(classNode, SnapshotOutput_descriptor)) {
            methodNode.instructions.insert(notifyOutput(classNode, methodNode));
        }
    }

    /** Returns the methods of {@code classNode} carrying a visible annotation with the given descriptor. */
    private static List<MethodNode> methodsAnnotatedWith(ClassNode classNode, String annotationDescriptor) {
        return classNode.methods.stream()
            .filter(methodNode -> isAnnotatedWith(methodNode, annotationDescriptor))
            .collect(Collectors.toList());
    }

    /** Tells whether {@code methodNode} has a visible annotation with the given descriptor. */
    private static boolean isAnnotatedWith(MethodNode methodNode, String annotationDescriptor) {
        return methodNode.visibleAnnotations != null
            && methodNode.visibleAnnotations.stream()
                .anyMatch(annotation -> annotationDescriptor.equals(annotation.desc));
    }

    /**
     * Builds a method body that asks the class's {@link IORecorderClassLoader} input provider
     * for the result and returns it. If no such provider is reachable, the body returns the
     * default value of the return type (0, 0L, 0f, 0d, false or null).
     */
    private InsnList readInput(ClassNode classNode, MethodNode methodNode) {
        Type returnType = Type.getReturnType(methodNode.desc);
        LabelNode noProvider = new LabelNode();
        InsnList insns = new InsnList();
        insns.add(loadRecorderComponent(classNode, "getIn", IORecorderClassLoader_getIn_descriptor, noProvider));
        insns.add(pushInvocationArguments(classNode, methodNode));
        insns.add(new MethodInsnNode(Opcodes.INVOKEINTERFACE, InputProvider_name, "requestInput", InputProvider_requestInput_descriptor, true));
        insns.add(returnValue(returnType));
        insns.add(noProvider);
        // every jump to noProvider leaves exactly one value on the stack
        insns.add(new InsnNode(Opcodes.POP));
        insns.add(returnDefaultValue(returnType));
        return insns;
    }

    /**
     * Returns the value on top of the stack as {@code type}. The value is expected to be an
     * Object and is unboxed via {@code ByteCode.unboxPrimitives} for primitive types. For
     * {@code void}, a plain {@code RETURN} is emitted.
     */
    private InsnList returnValue(Type type) {
        InsnList insns = new InsnList();
        if (type.getSize() == 0) {
            insns.add(new InsnNode(Opcodes.RETURN));
            return insns;
        }
        insns.add(ByteCode.unboxPrimitives(type));
        switch (type.getSort()) {
            case Type.BOOLEAN:
            case Type.CHAR:
            case Type.BYTE:
            case Type.SHORT:
            case Type.INT:
                insns.add(new InsnNode(Opcodes.IRETURN));
                break;
            case Type.FLOAT:
                insns.add(new InsnNode(Opcodes.FRETURN));
                break;
            case Type.LONG:
                insns.add(new InsnNode(Opcodes.LRETURN));
                break;
            case Type.DOUBLE:
                insns.add(new InsnNode(Opcodes.DRETURN));
                break;
            default: // Type.ARRAY, Type.OBJECT
                insns.add(new TypeInsnNode(Opcodes.CHECKCAST, type.getInternalName()));
                insns.add(new InsnNode(Opcodes.ARETURN));
                break;
        }
        return insns;
    }

    /** Returns the JVM default value of {@code type}, or nothing for {@code void}. */
    private InsnList returnDefaultValue(Type type) {
        InsnList insns = new InsnList();
        if (type.getSize() == 0) {
            insns.add(new InsnNode(Opcodes.RETURN));
            return insns;
        }
        switch (type.getSort()) {
            case Type.BOOLEAN:
            case Type.CHAR:
            case Type.BYTE:
            case Type.SHORT:
            case Type.INT:
                insns.add(new InsnNode(Opcodes.ICONST_0));
                insns.add(new InsnNode(Opcodes.IRETURN));
                break;
            case Type.FLOAT:
                insns.add(new InsnNode(Opcodes.FCONST_0));
                insns.add(new InsnNode(Opcodes.FRETURN));
                break;
            case Type.LONG:
                insns.add(new InsnNode(Opcodes.LCONST_0));
                insns.add(new InsnNode(Opcodes.LRETURN));
                break;
            case Type.DOUBLE:
                insns.add(new InsnNode(Opcodes.DCONST_0));
                insns.add(new InsnNode(Opcodes.DRETURN));
                break;
            default: // Type.ARRAY, Type.OBJECT
                insns.add(new InsnNode(Opcodes.ACONST_NULL));
                insns.add(new TypeInsnNode(Opcodes.CHECKCAST, type.getInternalName()));
                insns.add(new InsnNode(Opcodes.ARETURN));
                break;
        }
        return insns;
    }

    /**
     * Builds code that passes the invocation (declaring class, method name, arguments) to the
     * output listener of the class's {@link IORecorderClassLoader}. If no listener is reachable,
     * it does nothing. Either way, execution continues with the original method body and the
     * operand stack is left empty.
     */
    private InsnList notifyOutput(ClassNode classNode, MethodNode methodNode) {
        LabelNode noListener = new LabelNode();
        LabelNode proceed = new LabelNode();
        InsnList insns = new InsnList();
        insns.add(loadRecorderComponent(classNode, "getOut", IORecorderClassLoader_getOut_descriptor, noListener));
        insns.add(pushInvocationArguments(classNode, methodNode));
        insns.add(new MethodInsnNode(Opcodes.INVOKEINTERFACE, OutputListener_name, "notifyOutput", OutputListener_notifyOutput_descriptor, true));
        insns.add(new JumpInsnNode(Opcodes.GOTO, proceed));
        insns.add(noListener);
        // every jump to noListener leaves exactly one value on the stack
        insns.add(new InsnNode(Opcodes.POP));
        insns.add(proceed);
        return insns;
    }

    /**
     * Emits code that loads {@code classNode}'s class loader and, if it is an
     * {@link IORecorderClassLoader}, calls {@code getter} on it.
     *
     * <p>On fall-through, the non-null component is on top of the stack. Each jump to
     * {@code skip} happens when:
     * <ul>
     * <li>the loader is null,</li>
     * <li>the loader is not an {@link IORecorderClassLoader}, or</li>
     * <li>the component is null.</li>
     * </ul>
     * Every such jump leaves exactly one value on the stack, so the code at {@code skip} must pop it.
     */
    private InsnList loadRecorderComponent(ClassNode classNode, String getter, String getterDescriptor, LabelNode skip) {
        InsnList insns = new InsnList();
        insns.add(new LdcInsnNode(Type.getObjectType(classNode.name)));
        insns.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, Class_name, "getClassLoader", Class_getClassLoader_descriptor, false));
        insns.add(new InsnNode(Opcodes.DUP));
        insns.add(new JumpInsnNode(Opcodes.IFNULL, skip));
        insns.add(new InsnNode(Opcodes.DUP));
        insns.add(new TypeInsnNode(Opcodes.INSTANCEOF, IORecorderClassLoader_name));
        insns.add(new JumpInsnNode(Opcodes.IFEQ, skip));
        insns.add(new TypeInsnNode(Opcodes.CHECKCAST, IORecorderClassLoader_name));
        insns.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, IORecorderClassLoader_name, getter, getterDescriptor, false));
        insns.add(new InsnNode(Opcodes.DUP));
        insns.add(new JumpInsnNode(Opcodes.IFNULL, skip));
        return insns;
    }

    /**
     * Pushes the three arguments shared by {@code requestInput} and {@code notifyOutput}: the
     * declaring class, the method name, and the method's parameters boxed into an
     * {@code Object[]}.
     */
    private InsnList pushInvocationArguments(ClassNode classNode, MethodNode methodNode) {
        // instance methods keep 'this' in local slot 0, so their parameters start at slot 1
        int firstArgument = (methodNode.access & Opcodes.ACC_STATIC) == 0 ? 1 : 0;
        Type[] argumentTypes = Type.getArgumentTypes(methodNode.desc);
        List<LocalVariableNode> arguments = ByteCode.range(methodNode.localVariables, firstArgument, argumentTypes.length);
        InsnList insns = new InsnList();
        insns.add(new LdcInsnNode(Type.getObjectType(classNode.name)));
        insns.add(new LdcInsnNode(methodNode.name));
        insns.add(ByteCode.pushAsArray(arguments, argumentTypes));
        return insns;
    }
}
