/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.face.model.facerec.criteria;

import ai.djl.Device;
import ai.djl.modality.cv.Image;
import ai.djl.repository.zoo.Criteria;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Translator;
import ai.djl.util.Progress;
import cn.smartjavaai.common.enums.DeviceEnum;
import cn.smartjavaai.face.config.FaceRecConfig;
import cn.smartjavaai.face.enums.FaceRecModelEnum;
import cn.smartjavaai.face.model.facerec.translator.FaceFeatureTranslator;
import cn.smartjavaai.face.model.facerec.translator.FaceNetRecTranslator;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;

public class FaceRecCriteriaFactory {

    /** Download location of the FaceNet feature model, used when no local model path is configured. */
    private static final String FACENET_MODEL_URL =
            "https://resources.djl.ai/test-models/pytorch/face_feature.zip";

    /** Message used when a model that can only be loaded locally has no path configured
     *  (Chinese for "please specify the model path"). */
    private static final String MODEL_PATH_REQUIRED = "\u8bf7\u6307\u5b9a\u6a21\u578b\u8def\u5f84";

    /**
     * Builds the DJL {@link Criteria} for the face-recognition model selected in {@code config}.
     *
     * <p>All supported models run on the PyTorch engine, map an {@link Image} to a {@code float[]}
     * feature vector, and report loading progress with a {@link ProgressBar}.
     *
     * @param config recognition configuration; supplies the model choice, an optional local model
     *     path and the target device
     * @return the criteria for the configured model, or {@code null} if the model enum is
     *     {@code null} or is not one of the models handled here
     * @throws NullPointerException if {@code config} is {@code null}
     * @throws IllegalArgumentException if the selected model requires a local model path and none
     *     (or a blank one) is configured
     */
    public static Criteria<Image, float[]> createCriteria(FaceRecConfig config) {
        Objects.requireNonNull(config, "config");
        Device device = resolveDevice(config);
        FaceRecModelEnum model = config.getModelEnum();
        if (model == FaceRecModelEnum.FACENET_MODEL) {
            return faceNetCriteria(config.getModelPath(), device);
        }
        if (model == FaceRecModelEnum.INSIGHT_FACE_MOBILE_FACENET_MODEL
                || model == FaceRecModelEnum.INSIGHT_FACE_IRSE50_MODEL
                || model == FaceRecModelEnum.ELASTIC_FACE_MODEL) {
            return localFeatureCriteria(config.getModelPath(), device);
        }
        // Kept for backward compatibility: unhandled (or null) model choices yield no criteria.
        return null;
    }

    /** Maps the configured device enum to a DJL device; {@code null} lets DJL pick its default. */
    private static Device resolveDevice(FaceRecConfig config) {
        if (Objects.isNull(config.getDevice())) {
            return null;
        }
        return config.getDevice() == DeviceEnum.CPU
                ? Device.cpu()
                : Device.gpu((int) config.getGpuId());
    }

    /**
     * Criteria for the FaceNet model: loads from {@code modelPath} when it is non-blank, otherwise
     * downloads from {@link #FACENET_MODEL_URL}.
     */
    private static Criteria<Image, float[]> faceNetCriteria(String modelPath, Device device) {
        boolean hasLocalPath = StringUtils.isNotBlank(modelPath);
        return Criteria.builder()
                .setTypes(Image.class, float[].class)
                .optModelName("face_feature")
                // Exactly one of URL / path is non-null, mirroring the original ternaries.
                .optModelUrls(hasLocalPath ? null : FACENET_MODEL_URL)
                .optModelPath(hasLocalPath ? Paths.get(modelPath) : null)
                .optTranslator(new FaceNetRecTranslator())
                .optDevice(device)
                .optEngine("PyTorch")
                .optProgress(new ProgressBar())
                .build();
    }

    /**
     * Criteria shared by the InsightFace (MobileFaceNet, IR-SE50) and ElasticFace models, which can
     * only be loaded from a local path and all use {@link FaceFeatureTranslator}.
     */
    private static Criteria<Image, float[]> localFeatureCriteria(String modelPath, Device device) {
        if (StringUtils.isBlank(modelPath)) {
            throw new IllegalArgumentException(MODEL_PATH_REQUIRED);
        }
        return Criteria.builder()
                .setTypes(Image.class, float[].class)
                .optModelPath(Paths.get(modelPath))
                .optTranslator(new FaceFeatureTranslator())
                .optEngine("PyTorch")
                .optDevice(device)
                .optProgress(new ProgressBar())
                .build();
    }
}

