Training Transformer Networks in Scikit-Learn?!

Have you ever wanted to use handy scikit-learn functionality with your neural networks, but couldn’t because TensorFlow models are not compatible with the scikit-learn API? I’m excited to introduce one-line wrappers for TensorFlow/Keras models that let you use them within scikit-learn workflows, with features like Pipeline, GridSearchCV, and more.

Transformers are extremely popular for modeling text nowadays, with GPT-3, ChatGPT, Bard, PaLM, and FLAN excelling at conversational AI, and other Transformers like T5 and BERT excelling at text classification. Scikit-learn offers a broadly useful suite of features for classifier models, but these are hard to use with Transformers. Not so with the wrappers we developed, which require changing only one line of code to make your existing TensorFlow/Keras model compatible with scikit-learn’s rich ecosystem!

All you have to do is swap out `keras.Model` → `KerasWrapperModel`, or `keras.Sequential` → `KerasSequentialWrapper`. The wrapper objects have all the same methods as their Keras counterparts, plus you can use them with tons of awesome scikit-learn methods.
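To see why that one-line swap unlocks scikit-learn's tooling, here is a minimal sketch of the estimator contract such a wrapper must satisfy (`fit` returning `self`, `predict`, and constructor-exposed hyperparameters). The `ToyWrapper` class below is a hypothetical stand-in for a wrapped neural network, not the real `KerasWrapperModel` implementation — it only illustrates why a compliant wrapper works with tools like `GridSearchCV`:

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import GridSearchCV


class ToyWrapper(BaseEstimator, ClassifierMixin):
    """Stand-in for a wrapped model: a simple threshold classifier."""

    def __init__(self, threshold=0.5):
        # Hyperparameters live on the constructor so scikit-learn's
        # get_params/set_params (inherited from BaseEstimator) can tune them.
        self.threshold = threshold

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        return self  # fit must return self for scikit-learn compatibility

    def predict(self, X):
        return (X[:, 0] > self.threshold).astype(int)


X = np.array([[0.1], [0.2], [0.8], [0.9]])
y = np.array([0, 0, 1, 1])

# Because ToyWrapper follows the estimator API, GridSearchCV can clone it,
# tune it, and cross-validate it out of the box — the same mechanism that
# the Keras wrappers rely on.
search = GridSearchCV(ToyWrapper(), {"threshold": [0.3, 0.5, 0.7]}, cv=2)
search.fit(X, y)
print(search.best_params_)
```

The real wrappers additionally forward all the usual Keras methods, so the wrapped object behaves like the original model everywhere else in your code.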

Blogpost with demonstration:

Jupyter notebook showing how to make a HuggingFace Transformer (BERT) model sklearn-compatible: