DocumentWord2Vec.java
/*******************************************************************************
* Copyright (C) 2020 Ram Sadasiv
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package io.outofprintmagazine.corpus.batch.impl;

import java.io.File;
import java.io.FileInputStream;
import java.util.ArrayList;
import java.util.List;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;

import io.outofprintmagazine.corpus.batch.CorpusBatchStep;
import io.outofprintmagazine.corpus.batch.ICorpusBatchStep;
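/**
 * Batch step that trains three small per-document Word2Vec models from
 * the pipeline's oopNLP annotations, one per token serialization: surface
 * tokens, lemmas, and lemma_POS pairs. Each trained model is serialized
 * and written to scratch storage.
 * <p>
 * A minimal configuration sketch (property names are the ones defined in
 * {@link #getDefaultProperties()}; how an override reaches the running
 * step is framework-specific):
 * <pre>{@code
 * DocumentWord2Vec step = new DocumentWord2Vec();
 * ObjectNode properties = step.getDefaultProperties();
 * properties.put("layerSize", 100); // hypothetical override: wider embeddings
 * }</pre>
 */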
public class DocumentWord2Vec extends CorpusBatchStep implements ICorpusBatchStep {

    private static final Logger logger = LogManager.getLogger(DocumentWord2Vec.class);

    @SuppressWarnings("unused")
    private Logger getLogger() {
        return logger;
    }

    public DocumentWord2Vec() {
        super();
    }
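    /**
     * Default Word2Vec hyperparameters: the minimum corpus frequency for
     * a word to receive a vector (minWordFrequency), the number of
     * training iterations (iterations), the embedding dimensionality
     * (layerSize), and the context window size (windowSize).
     */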
    @Override
    public ObjectNode getDefaultProperties() {
        ObjectNode properties = getMapper().createObjectNode();
        properties.put("minWordFrequency", 2);
        properties.put("iterations", 100);
        properties.put("layerSize", 50);
        properties.put("windowSize", 5);
        return properties;
    }
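    /**
     * Reads the document's oopNLP annotations from storage and trains one
     * Word2Vec model per token serialization, recording each model's
     * storage location on the output step item.
     */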
    @Override
    public ArrayNode runOne(ObjectNode inputStepItem) throws Exception {
        ArrayNode retval = getMapper().createArrayNode();
        ObjectNode outputStepItem = copyInputToOutput(inputStepItem);
        ObjectNode oop = (ObjectNode) getJsonNodeFromStorage(inputStepItem, "oopNLPStorage");
        // One model per serialization: surface form, lemma, lemma_POS.
        storeModel(outputStepItem, oop, new NodeToToken(), "Tokens");
        storeModel(outputStepItem, oop, new NodeToLemma(), "Lemmas");
        storeModel(outputStepItem, oop, new NodeToLemma_POS(), "Lemmas_POS");
        retval.add(outputStepItem);
        return retval;
    }
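    /**
     * Trains a Word2Vec model over the document's filtered sentences,
     * serializes it to a temporary file, copies that file into scratch
     * storage, and records the location on the output step item under
     * "Word2Vec" + prefix + "Storage".
     * <p>
     * A stored model can later be reloaded with DL4J, sketched here
     * assuming the stored stream has first been copied back to a local
     * file:
     * <pre>{@code
     * Word2Vec model = WordVectorSerializer.readWord2VecModel(localModelFile);
     * Collection<String> similar = model.wordsNearest("love", 10);
     * }</pre>
     */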
    protected void storeModel(ObjectNode outputStepItem, ObjectNode oop, NodeToString serializer, String prefix) throws Exception {
        // Build a skip-gram model over the document's cleaned sentences,
        // using the hyperparameters from this step's properties.
        Word2Vec vec = new Word2Vec.Builder()
            .minWordFrequency(getData().getProperties().get("minWordFrequency").asInt())
            .iterations(getData().getProperties().get("iterations").asInt())
            .layerSize(getData().getProperties().get("layerSize").asInt())
            .seed(42)
            .windowSize(getData().getProperties().get("windowSize").asInt())
            .iterate(
                new CollectionSentenceIterator(
                    docToTokenStrings(oop, serializer)
                )
            )
            .tokenizerFactory(new DefaultTokenizerFactory())
            .build();
        vec.fit();
        // Serialize the model to a temp file, copy the file into scratch
        // storage, and delete the temp file when done.
        File f = File.createTempFile(getDocID(outputStepItem), ".word2vec");
        String storageLocation = null;
        try {
            WordVectorSerializer.writeWord2VecModel(vec, f);
            try (FileInputStream fin = new FileInputStream(f)) {
                storageLocation = getStorage().storeScratchFileStream(
                    getData().getCorpusId(),
                    getOutputScratchFilePath(prefix + "_" + getDocID(outputStepItem), "word2vec"),
                    fin
                );
            }
        }
        finally {
            f.delete();
        }
        if (storageLocation != null) {
            outputStepItem.put(
                "Word2Vec" + prefix + "Storage",
                storageLocation
            );
        }
    }
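    /*
     * Serialization strategies: each renders a token node as the string
     * Word2Vec will see, either the surface form, the lemma, or the lemma
     * tagged with its part of speech (e.g. "run_VB").
     */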
    abstract class NodeToString {
        abstract String nodeToString(ObjectNode node);
    }

    class NodeToToken extends NodeToString {
        @Override
        public String nodeToString(ObjectNode node) {
            return node.get("TokensAnnotation").get("word").asText();
        }
    }

    class NodeToLemma extends NodeToString {
        @Override
        public String nodeToString(ObjectNode node) {
            return node.get("TokensAnnotation").get("lemma").asText();
        }
    }

    class NodeToLemma_POS extends NodeToString {
        @Override
        public String nodeToString(ObjectNode node) {
            return node.get("TokensAnnotation").get("lemma").asText()
                + "_"
                + node.get("TokensAnnotation").get("pos").asText();
        }
    }
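    /**
     * Flattens the document into one whitespace-joined string per sentence,
     * keeping all verbs and otherwise dropping possessive endings (POS),
     * superfluous punctuation (NFP), punctuation marks, and function words.
     */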
    protected List<String> docToTokenStrings(ObjectNode doc, NodeToString serializer) {
        ArrayNode sentences = (ArrayNode) doc.get("sentences");
        List<String> cleanedSentences = new ArrayList<>();
        for (JsonNode sentenceNode : sentences) {
            StringBuilder buf = new StringBuilder();
            ArrayNode tokensNode = (ArrayNode) sentenceNode.get("tokens");
            for (JsonNode token : tokensNode) {
                ObjectNode tokenNode = (ObjectNode) token;
                String pos = tokenNode.get("TokensAnnotation").get("pos").asText();
                // Keep all the verbs.
                if (
                    tokenNode.has("OOPActionlessVerbsAnnotation")
                    || tokenNode.has("OOPVerbsAnnotation")
                ) {
                    buf.append(serializer.nodeToString(tokenNode));
                    buf.append(" ");
                }
                // Otherwise skip possessive endings, superfluous punctuation,
                // punctuation marks, and function words.
                else if (
                    !pos.equals("POS")
                    && !pos.equals("NFP")
                    && !tokenNode.has("OOPPunctuationMarkAnnotation")
                    //&& !stopWords.contains(tokenNode.get("TokensAnnotation").get("lemma").asText().toLowerCase())
                    //&& !tokenNode.has("OOPCommonWordsAnnotation")
                    && !tokenNode.has("OOPFunctionWordsAnnotation")
                ) {
                    buf.append(serializer.nodeToString(tokenNode));
                    buf.append(" ");
                }
            }
            // Trim the trailing separator before adding the sentence.
            cleanedSentences.add(buf.toString().trim());
        }
        return cleanedSentences;
    }
}