CorpusWord2Vec.java

  1. /*******************************************************************************
  2.  * Copyright (C) 2020 Ram Sadasiv
  3.  *
  4.  * This program is free software: you can redistribute it and/or modify
  5.  * it under the terms of the GNU General Public License as published by
  6.  * the Free Software Foundation, either version 3 of the License, or
  7.  * (at your option) any later version.
  8.  *
  9.  * This program is distributed in the hope that it will be useful,
  10.  * but WITHOUT ANY WARRANTY; without even the implied warranty of
  11.  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  12.  * GNU General Public License for more details.
  13.  *
  14.  * You should have received a copy of the GNU General Public License
  15.  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
  16.  ******************************************************************************/
  17. package io.outofprintmagazine.corpus.batch.impl;

  18. import java.io.File;
  19. import java.io.FileInputStream;
  20. import java.util.ArrayList;
  21. import java.util.Collection;
  22. import java.util.HashMap;
  23. import java.util.Iterator;
  24. import java.util.List;
  25. import java.util.Map;

  26. import org.apache.logging.log4j.LogManager;
  27. import org.apache.logging.log4j.Logger;
  28. import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
  29. import org.deeplearning4j.models.word2vec.Word2Vec;
  30. import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator;
  31. import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;

  32. import com.fasterxml.jackson.databind.JsonNode;
  33. import com.fasterxml.jackson.databind.node.ArrayNode;
  34. import com.fasterxml.jackson.databind.node.ObjectNode;

  35. import io.outofprintmagazine.corpus.batch.CorpusBatchStep;
  36. import io.outofprintmagazine.corpus.batch.ICorpusBatchStep;

  37. public class CorpusWord2Vec extends CorpusBatchStep implements ICorpusBatchStep {
  38.    
  39.     private static final Logger logger = LogManager.getLogger(CorpusWord2Vec.class);

  40.     @SuppressWarnings("unused")
  41.     private Logger getLogger() {
  42.         return logger;
  43.     }
  44.    
  45.     private Map<String, Collection<String>> sentenceCollections = new HashMap<String, Collection<String>>();
  46.    
  47.     protected Map<String, Collection<String>> getSentenceCollections() {
  48.         return sentenceCollections;
  49.     }
  50.    
  51.     protected Collection<String> getSentenceCollection(String name) {
  52.         return getSentenceCollections().get(name);
  53.     }
  54.    
  55.     public CorpusWord2Vec() {
  56.         super();
  57.         getSentenceCollections().put("Tokens", new ArrayList<String>());
  58.         getSentenceCollections().put("Lemmas", new ArrayList<String>());
  59.         getSentenceCollections().put("Lemmas_POS", new ArrayList<String>());
  60.     }
  61.    
  62.     @Override
  63.     public ObjectNode getDefaultProperties() {
  64.         ObjectNode properties = getMapper().createObjectNode();
  65.         properties.put("minWordFrequency", 2);
  66.         properties.put("iterations", 100);
  67.         properties.put("layerSize", 50);
  68.         properties.put("windowSize", 5);
  69.         return properties;
  70.     }
  71.    
  72.     @Override
  73.     public ArrayNode run(ArrayNode input)  {
  74.         ArrayNode retval = super.run(input);
  75.         try {
  76.             storeModel("Tokens");
  77.             storeModel("Lemmas");
  78.             storeModel("Lemmas_POS");
  79.         }
  80.         catch (Exception e) {
  81.             e.printStackTrace();
  82.             getLogger().error(e);
  83.         }
  84.         return retval;
  85.     }
  86.    

  87.    
  88.     @Override
  89.     public ArrayNode runOne(ObjectNode inputStepItem) throws Exception {
  90.         ArrayNode retval = getMapper().createArrayNode();
  91.         ObjectNode outputStepItem = copyInputToOutput(inputStepItem);
  92.         ObjectNode oop = (ObjectNode) getJsonNodeFromStorage(inputStepItem, "oopNLPStorage");
  93.         getSentenceCollection("Tokens").addAll(docToTokenStrings(oop, new NodeToToken()));
  94.         getSentenceCollection("Lemmas").addAll(docToTokenStrings(oop, new NodeToLemma()));
  95.         getSentenceCollection("Lemmas_POS").addAll(docToTokenStrings(oop, new NodeToLemma_POS()));
  96.         outputStepItem.put(
  97.                 "oopNLPCorpusAggregatesWord2Vec_TokensStorage",
  98.                 getStorage().getScratchFilePath(
  99.                         getData().getCorpusBatchId(),
  100.                         getData().getCorpusBatchStepId(),
  101.                         "CORPUS_AGGREGATES_WORD2VEC_Tokens.word2vec"
  102.                 )
  103.         );
  104.         outputStepItem.put(
  105.                 "oopNLPCorpusAggregatesWord2Vec_LemmasStorage",
  106.                 getStorage().getScratchFilePath(
  107.                         getData().getCorpusBatchId(),
  108.                         getData().getCorpusBatchStepId(),
  109.                         "CORPUS_AGGREGATES_WORD2VEC_Lemmas.word2vec"
  110.                 )
  111.         );
  112.         outputStepItem.put(
  113.                 "oopNLPCorpusAggregatesWord2Vec_Lemmas_POSStorage",
  114.                 getStorage().getScratchFilePath(
  115.                         getData().getCorpusBatchId(),
  116.                         getData().getCorpusBatchStepId(),
  117.                         "CORPUS_AGGREGATES_WORD2VEC_Lemmas_POS.word2vec"
  118.                 )
  119.         );      
  120.         retval.add(outputStepItem);
  121.         return retval;
  122.     }
  123.        
  124.     protected void storeModel(String prefix) throws Exception {
  125.         Word2Vec vec = new Word2Vec.Builder()
  126.                 .minWordFrequency(getData().getProperties().get("minWordFrequency").asInt())
  127.                 .iterations(getData().getProperties().get("iterations").asInt())
  128.                 .layerSize(getData().getProperties().get("layerSize").asInt())
  129.                 .seed(42)
  130.                 .windowSize(getData().getProperties().get("windowSize").asInt())
  131.                 .iterate(
  132.                         new CollectionSentenceIterator(
  133.                                 getSentenceCollection(prefix)
  134.                         )
  135.                 )
  136.                 .tokenizerFactory(new DefaultTokenizerFactory())
  137.                 .build();
  138.         vec.fit();
  139.         File f = File.createTempFile(prefix, "word2vec");
  140.         WordVectorSerializer.writeWord2VecModel(vec, f);
  141.         FileInputStream fin = null;
  142.         try {
  143.             fin = new FileInputStream(f);
  144.             getStorage().storeScratchFileStream(
  145.                 getData().getCorpusId(),
  146.                 getOutputScratchFilePath("CORPUS_AGGREGATES_WORD2VEC_" + prefix, "word2vec"),
  147.                 fin
  148.             );
  149.         }
  150.         finally {
  151.             if (fin != null) {
  152.                 fin.close();
  153.             }
  154.         }      
  155.     }

  156.     abstract class NodeToString {
  157.         abstract String nodeToString(ObjectNode node);
  158.     }
  159.    
  160.     class NodeToToken extends NodeToString {
  161.         public String nodeToString(ObjectNode node) {
  162.             return node.get("TokensAnnotation").get("word").asText();
  163.         }
  164.     }
  165.    
  166.     class NodeToLemma extends NodeToString {
  167.         public String nodeToString(ObjectNode node) {
  168.             return node.get("TokensAnnotation").get("lemma").asText();
  169.         }
  170.     }
  171.    
  172.     class NodeToLemma_POS extends NodeToString {
  173.         public String nodeToString(ObjectNode node) {
  174.             return node.get("TokensAnnotation").get("lemma").asText() + "_" + node.get("TokensAnnotation").get("pos").asText();
  175.         }
  176.     }
  177.    
  178.     protected List<String> docToTokenStrings(ObjectNode doc, NodeToString serializer) {
  179.         ArrayNode sentences = (ArrayNode) doc.get("sentences");
  180.         List<String> cleanedSentences = new ArrayList<String>();
  181.         Iterator<JsonNode> sentencesIter = sentences.iterator();
  182.         while (sentencesIter.hasNext()) {
  183.             JsonNode sentenceNode = sentencesIter.next();
  184.             StringBuffer buf = new StringBuffer();
  185.             ArrayNode tokensNode = (ArrayNode) sentenceNode.get("tokens");
  186.             Iterator<JsonNode> tokensIter = tokensNode.iterator();
  187.             while (tokensIter.hasNext()) {
  188.                 ObjectNode tokenNode = (ObjectNode) tokensIter.next();
  189.                 //keep all the verbs
  190.                 if (
  191.                         tokenNode.has("OOPActionlessVerbsAnnotation")
  192.                         || tokenNode.has("OOPVerbsAnnotation")
  193.                     ) {
  194.                     buf.append(serializer.nodeToString(tokenNode));
  195.                     buf.append(" ");
  196.                 }
  197.                 else if (
  198.                         !tokenNode.get("TokensAnnotation").get("pos").asText().equals("POS")
  199.                         && !tokenNode.get("TokensAnnotation").get("pos").asText().equals("NFP")
  200.                         && !tokenNode.has("OOPPunctuationMarkAnnotation")
  201.                         //&& !stopWords.contains(tokenNode.get("TokensAnnotation").get("lemma").asText().toLowerCase())
  202.                         //&& !tokenNode.has("OOPCommonWordsAnnotation")
  203.                         && !tokenNode.has("OOPFunctionWordsAnnotation")
  204.                     ) {
  205.                     buf.append(serializer.nodeToString(tokenNode));
  206.                     buf.append(" ");
  207.                 }
  208.             }
  209.            
  210.             cleanedSentences.add(buf.toString());
  211.         }
  212.         return cleanedSentences;
  213.     }

  214. }