package com.google.research.ic.gesture.visualgesture.trainer;

import com.google.research.ic.gesture.recognition.Instance;
import com.google.research.ic.gesture.visualgesture.classifier.CursiveTrainer;
import com.google.research.ic.gesture.visualgesture.classifier.Dictionary;
import com.google.research.ic.gesture.visualgesture.classifier.LinearBinarySVM;
import com.google.research.ic.gesture.visualgesture.classifier.LinearSVM;
import com.google.research.ic.gesture.visualgesture.classifier.SymbolResult;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;

/* loaded from: classes.dex */
/**
 * Trains and evaluates a cursive gesture model.
 *
 * <p>Labelled gesture datasets are loaded (optionally filtered to letters/digits), split into a
 * training and a test portion, and used to train a {@link LinearSVM}-backed {@link CursiveTrainer}
 * for {@link #maxAlignmentIterations} rounds. After each round the classifier is saved and the
 * recognizer is evaluated on the test portion, printing per-category accuracy to stdout.
 */
public class CursiveModelTrainer {
    /** Shared number format; not referenced within this class (may be used elsewhere in the package). */
    static DecimalFormat DF = new DecimalFormat("#.###");

    /** When true, labels may contain digits in addition to letters (see {@link #isValidString}). */
    public boolean includeDigits;

    /** One filtered instance list per dataset passed to {@link #addDataset} / {@link #loadDataset}. */
    List<List<Instance>> dataSources = new ArrayList<List<Instance>>();

    /** Number of train / save / evaluate rounds performed by {@link #trainCursiveModel}. */
    public int maxAlignmentIterations = 6;

    /** Value passed to {@code CursiveTrainer.setLiftSegment} (and its constructor) for training. */
    public boolean liftSegmentTrain = false;

    /** Value passed to {@code CursiveTrainer.setLiftSegment} before evaluation. */
    public boolean liftSegmentTest = false;

    /** Classifier most recently trained by, or successfully loaded into, this trainer. */
    private LinearSVM letterRecognizer = null;

    /** Recognizer created by the most recent {@link #trainCursiveModel} call. */
    private CursiveTrainer recognizer = null;

    /** Whether {@link #loadDataset} shuffles instances; nothing in this class sets it to true. */
    private boolean randomize = false;

    /**
     * Creates a trainer.
     *
     * @param z whether digit characters are accepted in labels (see {@link #includeDigits})
     */
    public CursiveModelTrainer(boolean z) {
        this.includeDigits = z;
    }

    /**
     * Returns true when every character of {@code str} is a letter, or a digit if
     * {@link #includeDigits} is set. The empty string is accepted.
     */
    private boolean isValidString(String str) {
        for (int k = 0; k < str.length(); k++) {
            char c = str.charAt(k);
            boolean allowed = Character.isLetter(c) || (this.includeDigits && Character.isDigit(c));
            if (!allowed) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds a recognizer around a classifier deserialized from {@code str}.
     *
     * <p>The classifier is loaded first and then handed to the {@link CursiveTrainer}; previously
     * the trainer was always constructed with {@code null}, so the loaded model was never used.
     * If {@code str} is null or loading fails, the recognizer is built without a classifier (as
     * before) and the error is reported on stderr.
     *
     * @param str path of a file written by {@link #saveRecognizer}, or null
     * @return a new recognizer constructed with {@link #liftSegmentTrain}
     */
    private CursiveTrainer loadRecognizer(String str) {
        LinearSVM loaded = null;
        if (str != null) {
            try {
                DataInputStream in =
                        new DataInputStream(new BufferedInputStream(new FileInputStream(str)));
                try {
                    loaded = LinearSVM.deserialize(in);
                    this.letterRecognizer = loaded;
                } finally {
                    // Close even when deserialization throws, so the file handle is not leaked.
                    in.close();
                }
            } catch (IOException e) {
                System.err.println("Failed to load model from " + str);
                e.printStackTrace();
            }
        }
        return new CursiveTrainer(loaded, this.liftSegmentTrain);
    }

    /**
     * Trains the English model from the datasets under the directory given as the first argument,
     * writing snapshots named {@code <dir>/english_svm_<round>}.
     */
    public static void main(String[] strArr) {
        if (strArr.length < 1) {
            System.err.println("Usage: CursiveModelTrainer <data-directory>");
            return;
        }
        String str = strArr[0] + "/";
        final boolean includeDigits = true;
        CursiveModelTrainer cursiveModelTrainer = new CursiveModelTrainer(includeDigits);
        cursiveModelTrainer.liftSegmentTrain = false;
        cursiveModelTrainer.liftSegmentTest = false;
        cursiveModelTrainer.loadDataset(str + "letters", 0);
        cursiveModelTrainer.loadDataset(str + "words", 0);
        if (includeDigits) {
            cursiveModelTrainer.loadDataset(str + "pendigits", 0);
            cursiveModelTrainer.loadDataset(str + "numbers", 0);
        }
        cursiveModelTrainer.trainCursiveModel(str + "english_svm", 1.0f);
    }

    /**
     * Serializes {@link #letterRecognizer} to {@code str}. Failures are reported on stderr and do
     * not abort the caller (best-effort snapshot). The stream is closed even if serialization throws.
     */
    private void saveRecognizer(String str) {
        try {
            DataOutputStream out =
                    new DataOutputStream(new BufferedOutputStream(new FileOutputStream(str)));
            try {
                this.letterRecognizer.serialize(out);
                out.flush();
            } finally {
                out.close();
            }
        } catch (IOException e) {
            System.err.println("Failed to save model to " + str);
            e.printStackTrace();
        }
    }

    /**
     * Evaluates {@code cursiveTrainer} on {@code list} and prints accuracy per category.
     *
     * <p>A fresh {@link Dictionary} containing every (lower-cased) label in {@code list} is
     * installed on the beam search first. An instance is a "letter" or "digit" when its label has
     * exactly one character (classified by {@link Character#isDigit}), and a "word" otherwise;
     * the same classification is used for both totals and correct counts. A prediction is correct
     * when its top label equals the instance label ignoring case (locale-independent). Instances
     * with no strokes are reported and skipped. A category with no instances prints {@code NaN}.
     * "Time" is the elapsed milliseconds divided by the size of {@code list}.
     */
    private void test(CursiveTrainer cursiveTrainer, List<Instance> list) {
        cursiveTrainer.getBeamSearch().setDictionary(new Dictionary());
        Dictionary dictionary = cursiveTrainer.getBeamSearch().getDictionary();
        for (Instance instance : list) {
            dictionary.add(instance.label.toLowerCase(Locale.ROOT));
        }

        long startMillis = System.currentTimeMillis();
        int letterCorrect = 0;
        int digitCorrect = 0;
        int wordCorrect = 0;
        int letterTotal = 0;
        int digitTotal = 0;
        int wordTotal = 0;
        for (Instance instance : list) {
            if (instance.gesture.getStrokesCount() == 0) {
                System.out.println("ZERO_STROKE_GESTURE " + instance.label);
                continue;
            }
            String label = instance.label;
            boolean isWord = label.length() != 1;
            boolean isDigit = !isWord && Character.isDigit(label.charAt(0));
            if (isWord) {
                wordTotal++;
            } else if (isDigit) {
                digitTotal++;
            } else {
                letterTotal++;
            }

            List<SymbolResult> predictions = cursiveTrainer.predict(instance.gesture);
            if (predictions.isEmpty()) {
                continue;
            }
            String predicted = predictions.get(0).label;
            if (predicted.toLowerCase(Locale.ROOT).equals(label.toLowerCase(Locale.ROOT))) {
                if (isWord) {
                    wordCorrect++;
                } else if (isDigit) {
                    digitCorrect++;
                } else {
                    letterCorrect++;
                }
            }
        }
        System.out.println("Letters: " + ((float) letterCorrect / letterTotal));
        System.out.println("Digits: " + ((float) digitCorrect / digitTotal));
        System.out.println("Words: " + ((float) wordCorrect / wordTotal));
        long elapsedMillis = System.currentTimeMillis() - startMillis;
        System.out.println("Time: " + ((float) elapsedMillis / list.size()));
    }

    /**
     * Adds the instances of {@code list} whose labels pass {@link #isValidString} as a new data
     * source. The input list is not modified.
     */
    public void addDataset(List<Instance> list) {
        List<Instance> valid = new ArrayList<Instance>();
        for (Instance instance : list) {
            if (isValidString(instance.label)) {
                valid.add(instance);
            }
        }
        this.dataSources.add(valid);
    }

    /**
     * Loads a dataset directory with {@link DataLoader} and adds it via {@link #addDataset}.
     *
     * @param str directory path passed to {@code DataLoader.loadDirectory}
     * @param i integer argument forwarded unchanged to {@code DataLoader.loadDirectory}
     *     (meaning defined by DataLoader — TODO confirm)
     */
    public void loadDataset(String str, int i) {
        DataLoader dataLoader = new DataLoader();
        dataLoader.loadDirectory(str, true, i);
        List<Instance> instances = dataLoader.getInstances();
        if (this.randomize) {
            Collections.shuffle(instances);
        }
        addDataset(instances);
    }

    /**
     * Loads the classifier saved at {@code str} and evaluates it on {@code list}, using the same
     * evaluation settings as {@link #trainCursiveModel} (case-insensitive, {@link #liftSegmentTest}).
     */
    public void testModel(List<Instance> list, String str) {
        CursiveTrainer loaded = loadRecognizer(str);
        configureForEvaluation(loaded);
        test(loaded, list);
    }

    /**
     * Trains a new model on every data source and evaluates it after each round.
     *
     * <p>For each source of size {@code n}, the first {@code (int)(n * f)} instances are used for
     * training. The instances from index {@code (int)(n * min(f, 0.8))} onward are used for testing.
     * The test portion therefore overlaps the training portion when {@code f > 0.8}. Each of the
     * {@link #maxAlignmentIterations} rounds trains, saves {@code str + "_" + round}, and evaluates.
     * Note: this sets the global {@code LinearBinarySVM.MAX_ITER} to 20.
     *
     * @param str path prefix for the saved snapshots
     * @param f fraction of each source used for training, in [0, 1]
     * @throws IllegalArgumentException if {@code f} is NaN or outside [0, 1]
     */
    public void trainCursiveModel(String str, float f) {
        if (!(f >= 0.0f && f <= 1.0f)) {
            throw new IllegalArgumentException("Training fraction must be in [0, 1]: " + f);
        }
        LinearBinarySVM.MAX_ITER = 20;
        List<Instance> trainSet = new ArrayList<Instance>();
        List<Instance> testSet = new ArrayList<Instance>();
        for (List<Instance> source : this.dataSources) {
            int size = source.size();
            int trainEnd = (int) (size * f);
            int testStart = (int) (size * Math.min(f, 0.8f));
            trainSet.addAll(source.subList(0, trainEnd));
            testSet.addAll(source.subList(testStart, size));
        }
        this.letterRecognizer = new LinearSVM();
        this.recognizer = new CursiveTrainer(this.letterRecognizer, this.liftSegmentTrain);
        this.recognizer.useCache = false;
        for (int round = 0; round < this.maxAlignmentIterations; round++) {
            configureForTraining(this.recognizer);
            // Lowest possible float value for the train() threshold argument.
            this.recognizer.train(trainSet, -Float.MAX_VALUE);
            saveRecognizer(str + "_" + round);
            configureForEvaluation(this.recognizer);
            test(this.recognizer, testSet);
        }
    }

    /** Applies training settings: case-sensitive, {@link #liftSegmentTrain}, no dictionary. */
    private void configureForTraining(CursiveTrainer trainer) {
        trainer.setCaseSensitive(true);
        trainer.setLiftSegment(this.liftSegmentTrain);
        trainer.getBeamSearch().setDictionary(null);
    }

    /** Applies evaluation settings: case-insensitive and {@link #liftSegmentTest}. */
    private void configureForEvaluation(CursiveTrainer trainer) {
        trainer.setCaseSensitive(false);
        trainer.setLiftSegment(this.liftSegmentTest);
    }
}
