# train.py
  1. """Example of training spaCy's named entity recognizer, starting off with an
  2. existing model or a blank model.
  3. For more details, see the documentation:
  4. * Training: https://spacy.io/usage/training
  5. * NER: https://spacy.io/usage/linguistic-features#named-entities
  6. Compatible with: spaCy v2.0.0+
  7. """
  8. import random
  9. import fr_core_news_md
  10. from path import Path
  11. import spacy
  12. # training data
  13. TRAIN_DATA = [
  14. ('Qui est Georges Brassens?', { 'entities': [(8, 24, 'PERSON')] }),
  15. ("J'aime Strasbourg et Avignon.", { 'entities': [(7, 17, 'LOC'), (21, 28, 'LOC')] }),
  16. ("J'aime Strasbourg et Avignon.", { 'entities': [(7, 17, 'LOC'), (21, 28, 'LOC')] }),
  17. ]
# Directory next to this script that holds the spaCy model to load and
# where the updated model is written back after training.
MODEL_DIR = Path(__file__).parent / "data"
  19. def main(n_iter=100):
  20. """Load the model, set up the pipeline and train the entity recognizer."""
  21. nlp = spacy.load(MODEL_DIR) # load existing spaCy model @UndefinedVariable
  22. print("Model loaded")
  23. # create the built-in pipeline components and add them to the pipeline
  24. # nlp.create_pipe works for built-ins that are registered with spaCy
  25. if 'ner' not in nlp.pipe_names:
  26. ner = nlp.create_pipe('ner')
  27. nlp.add_pipe(ner, last=True)
  28. # otherwise, get it so we can add labels
  29. else:
  30. ner = nlp.get_pipe('ner')
  31. # add labels
  32. for _, annotations in TRAIN_DATA:
  33. for ent in annotations.get('entities'):
  34. ner.add_label(ent[2])
  35. # get names of other pipes to disable them during training
  36. other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'ner']
  37. with nlp.disable_pipes(*other_pipes): # only train NER
  38. optimizer = nlp.begin_training()
  39. for _ in range(n_iter):
  40. random.shuffle(TRAIN_DATA)
  41. losses = {}
  42. for text, annotations in TRAIN_DATA:
  43. nlp.update(
  44. [text], # batch of texts
  45. [annotations], # batch of annotations
  46. drop=0.5, # dropout - make it harder to memorise data
  47. sgd=optimizer, # callable to update weights
  48. losses=losses)
  49. print(losses)
  50. # test the trained model
  51. for text, _ in TRAIN_DATA:
  52. doc = nlp(text)
  53. print('Entities', [(ent.text, ent.label_) for ent in doc.ents])
  54. print('Tokens', [(t.text, t.ent_type_, t.ent_iob) for t in doc])
  55. # save model to output directory
  56. if not MODEL_DIR.exists():
  57. MODEL_DIR.mkdir()
  58. nlp.to_disk(MODEL_DIR)
  59. print("Saved model to", MODEL_DIR)
  60. if __name__ == '__main__':
  61. main()
# Expected output:
# Entities [('Georges Brassens', 'PERSON')]
# Tokens [('Qui', '', 2), ('est', '', 2), ('Georges', 'PERSON', 3), ('Brassens', 'PERSON', 1), ('?', '', 2)]
# Entities [('Strasbourg', 'LOC'), ('Avignon', 'LOC')]
# Tokens [("J'", '', 2), ('aime', '', 2), ('Strasbourg', 'LOC', 3), ('et', '', 2), ('Avignon', 'LOC', 3), ('.', '', 2)]