/*
 * Decompiled with CFR 0.152.
 */
package cn.smartjavaai.face.model.facedect.criterial;

import ai.djl.Device;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Translator;
import ai.djl.util.Progress;
import cn.smartjavaai.common.enums.DeviceEnum;
import cn.smartjavaai.face.config.FaceDetConfig;
import cn.smartjavaai.face.constant.RetinaFaceConstant;
import cn.smartjavaai.face.constant.UltraLightFastGenericFaceConstant;
import cn.smartjavaai.face.enums.FaceDetModelEnum;
import cn.smartjavaai.face.translator.FaceDetectionTranslator;
import java.nio.file.Paths;
import java.util.Objects;
import org.apache.commons.lang3.StringUtils;

public class FaceDetCriteriaFactory {
    public static Criteria<Image, DetectedObjects> createCriteria(FaceDetConfig config) {
        Device device = null;
        if (!Objects.isNull(config.getDevice())) {
            device = config.getDevice() == DeviceEnum.CPU ? Device.cpu() : Device.gpu((int)config.getGpuId());
        }
        Criteria criteria = null;
        if (config.getModelEnum() == FaceDetModelEnum.RETINA_FACE) {
            FaceDetectionTranslator translator = new FaceDetectionTranslator(config.getConfidenceThreshold(), config.getNmsThresh(), RetinaFaceConstant.variance, 5000, RetinaFaceConstant.scales, RetinaFaceConstant.steps);
            criteria = Criteria.builder().setTypes(Image.class, DetectedObjects.class).optModelUrls(StringUtils.isNotBlank((CharSequence)config.getModelPath()) ? null : "https://resources.djl.ai/test-models/pytorch/retinaface.zip").optModelPath(StringUtils.isNotBlank((CharSequence)config.getModelPath()) ? Paths.get(config.getModelPath(), new String[0]) : null).optModelName("retinaface").optTranslator((Translator)translator).optDevice(device).optProgress((Progress)new ProgressBar()).optEngine("PyTorch").build();
        } else if (config.getModelEnum() == FaceDetModelEnum.ULTRA_LIGHT_FAST_GENERIC_FACE) {
            FaceDetectionTranslator translator = new FaceDetectionTranslator(config.getConfidenceThreshold(), config.getNmsThresh(), UltraLightFastGenericFaceConstant.variance, 5000, UltraLightFastGenericFaceConstant.scales, UltraLightFastGenericFaceConstant.steps);
            criteria = Criteria.builder().setTypes(Image.class, DetectedObjects.class).optModelUrls(StringUtils.isNotBlank((CharSequence)config.getModelPath()) ? null : "https://resources.djl.ai/test-models/pytorch/ultranet.zip").optModelPath(StringUtils.isNotBlank((CharSequence)config.getModelPath()) ? Paths.get(config.getModelPath(), new String[0]) : null).optTranslator((Translator)translator).optProgress((Progress)new ProgressBar()).optDevice(device).optEngine("PyTorch").build();
        }
        return criteria;
    }
}

