ThreadedCorpusBatchStep.java
package io.outofprintmagazine.corpus.batch;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
public class ThreadedCorpusBatchStep extends CorpusBatchStep implements ICorpusBatchStep {
private static final Logger logger = LogManager.getLogger(ThreadedCorpusBatchStep.class);
@SuppressWarnings("unused")
private Logger getLogger() {
return logger;
}
public ThreadedCorpusBatchStep() {
super();
}
@Override
public ArrayNode runOne(ObjectNode input) throws Exception {
// TODO Auto-generated method stub
return null;
}
protected int maxThreads = 50;
protected String taskClass = "io.outofprintmagazine.corpus.batch.impl.PostgreSQLCoreNLPLoader";
protected List<Thread> threads = new ArrayList<Thread>();
protected Map<ObjectNode, ThreadedCorpusBatchStepTask> tasks = new HashMap<ObjectNode, ThreadedCorpusBatchStepTask>();
public ArrayNode run(ArrayNode input) {
//getLogger().debug("run");
maxThreads = getData().getProperties().get("maxThreads").asInt();
taskClass = getData().getProperties().get("taskClass").asText();
push(input);
execute();
return pop(input);
}
protected void execute() {
int currentIdx = 0;
while (currentIdx<threads.size()) {
for (int i=0;i<maxThreads&&i+currentIdx<threads.size();i++) {
threads.get(currentIdx+i).start();
//getLogger().debug("starting: " + (i+currentIdx));
}
for (int i=0;i<maxThreads&&i+currentIdx<threads.size();i++) {
try {
threads.get(currentIdx+i).join();
}
catch (Exception e) {
getLogger().error(e);
}
}
//getLogger().debug("joined");
currentIdx=currentIdx+maxThreads;
}
}
protected void push(ArrayNode input) {
int count = 0;
for (JsonNode inputItem : input) {
if (getData().getProperties().has("maxInput") && getData().getProperties().get("maxInput").asInt() < count ) {
break;
}
count++;
boolean foundInputItem = false;
if (!(getData().getProperties().has("noCache") && getData().getProperties().get("noCache").asBoolean())) {
for (JsonNode existingInputItem : getData().getInput()) {
if (existingInputItem.equals(inputItem)) {
foundInputItem = true;
break;
}
}
}
if (!foundInputItem) {
try {
if (!(getData().getProperties().has("noCache") && getData().getProperties().get("noCache").asBoolean())) {
getData().getInput().add(inputItem);
}
//getLogger().debug("pushing: " + getDocID((ObjectNode)(inputItem)));
Object task = Class.forName(taskClass).getConstructor().newInstance();
ThreadedCorpusBatchStepTask currentBatchStep = (ThreadedCorpusBatchStepTask) task;
currentBatchStep.setData(getData());
currentBatchStep.setStorage(getStorage());
currentBatchStep.setParameterStore(getParameterStore());
currentBatchStep.setInput((ObjectNode)inputItem);
tasks.put((ObjectNode)inputItem, currentBatchStep);
threads.add(new Thread(currentBatchStep, getDocID((ObjectNode)inputItem)));
}
catch (Throwable t) {
t.printStackTrace();
getLogger().error(t);
}
}
}
}
public ArrayNode pop(ArrayNode input) {
for (JsonNode inputItem : input) {
if (tasks.get((ObjectNode)inputItem) != null) {
try {
ArrayNode generatedOutput = tasks.get((ObjectNode)inputItem).getOutput();
if (generatedOutput == null) {
getLogger().debug("no generated output for: " + getDocID((ObjectNode)inputItem));
}
else {
for (JsonNode generatedOutputItem : generatedOutput) {
getData().getOutput().add(generatedOutputItem);
}
}
}
catch (Throwable t) {
t.printStackTrace();
getLogger().error(t);
}
}
}
return getData().getOutput();
}
}