core.model.transformer.models

View Source
from abc import abstractmethod

import core.utils.const as const

from transformers import (
	BertTokenizer,
	BertForNextSentencePrediction,
	BertForSequenceClassification,
	BertForMultipleChoice,
	BertForQuestionAnswering
)
from core.model.transformer import (
	NextSentenceDataset,
	ClassificationDataset,
	MultipleChoiceDataset,
	QuestionTextAnswerSCDDataset, QuestionSCDAnswerTextDataset
)
from core.model.transformer.model import TransformerModel

class BertModel(TransformerModel):
	'''
		General BERT model

		This class is abstract, use subclasses `core.model.transformer.models.IsNextSCDBert`,
		`core.model.transformer.models.IsSCDBert`, `core.model.transformer.models.SelectSCDBert`,
		`core.model.transformer.models.GivenTextFindSCDBert`, or
		`core.model.transformer.models.GivenSCDFindTextBert`.

		Make sure to run `core.download.init_transformers()` to
		download needed resources.

		**This model uses a GPU if found via CUDA, else it will use multiple CPU cores.**  
		*However, a GPU will be much faster.*
	'''

	def _default_pretrained_model(self):
		''' Return `const.BERT_MODEL_DEFAULT`, the default pretrained model for all BERT variants. '''
		return const.BERT_MODEL_DEFAULT

	def _tokenizer_class(self):
		''' Return the tokenizer class shared by all BERT variants (`BertTokenizer`). '''
		return BertTokenizer

	@abstractmethod
	def _model_class(self):
		''' Return the `transformers` model class for the concrete task; provided by subclasses. '''

	@abstractmethod
	def _dataset_class(self):
		''' Return the dataset class for the concrete task; provided by subclasses. '''

	@abstractmethod
	def _parse_prediction(self, result, input, tokens):
		'''
			Convert a raw model output into the task specific answer; provided by subclasses.

			`result` is the model output, `input` the original (untokenized) input and
			`tokens` the tokenized input. The return shape differs per subclass.
		'''

class IsNextSCDBert(BertModel):
	'''
		BERT NextSentence model  
		=> "Is next sentence a SCD?"
	'''

	def _model_class(self):
		''' Model with a next sentence prediction head. '''
		return BertForNextSentencePrediction

	def _dataset_class(self):
		''' Dataset class paired with the next sentence task. '''
		return NextSentenceDataset

	def _parse_prediction(self, result, input, tokens):
		'''
			Turn the model output into `(next_is_scd, message)`.

			Class index 0 of `result.logits` is read as "the next sentence matches";
			`input` and `tokens` are not used here.
		'''
		# NOTE(review): int() needs a single-element tensor, so this presumably
		# expects exactly one example per call - confirm against the caller.
		winning_class = int(result.logits.argmax(dim=1))
		next_is_scd = winning_class == 0
		verdict = "seems" if next_is_scd else "does not seem"
		return next_is_scd, "Next sentence " + verdict + " to be matching scd!"

class IsSCDBert(BertModel):
	'''
		BERT SentenceClassification model  
		=> "Is the (current) sentence a SCD?"
	'''

	def _model_class(self):
		''' Model with a sequence classification head. '''
		return BertForSequenceClassification

	def _dataset_class(self):
		''' Dataset class paired with the classification task. '''
		return ClassificationDataset

	def _parse_prediction(self, result, input, tokens):
		'''
			Turn the model output into `(is_scd, message)`.

			Class index 1 of `result.logits` is read as "is a SCD";
			`input` and `tokens` are not used here.
		'''
		# NOTE(review): int() needs a single-element tensor, so this presumably
		# expects exactly one example per call - confirm against the caller.
		winning_class = int(result.logits.argmax(dim=1))
		is_scd = winning_class == 1
		article = "a" if is_scd else "no"
		return is_scd, "Seems to be " + article + " scd!"

class SelectSCDBert(BertModel):
	'''
		BERT MultipleChoice model  
		=> "Given a sentence and a selection of SCDs, select the best."
	'''

	def _model_class(self):
		''' Model with a multiple choice head. '''
		return BertForMultipleChoice

	def _dataset_class(self):
		''' Dataset class paired with the multiple choice task. '''
		return MultipleChoiceDataset

	def _parse_prediction(self, result, input, tokens):
		'''
			Turn the model output into `(choice_index, message)`.

			The index of the highest logit selects the entry of `input[1]` reported
			as SCD; `input[0][0]` is reported as the text. `tokens` is not used here.
		'''
		# NOTE(review): int() needs a single-element tensor, so this presumably
		# expects exactly one example per call - confirm against the caller.
		choice = int(result.logits.argmax(dim=1))
		sentence = input[0][0]
		chosen_scd = input[1][choice]
		return choice, "Text: '" + sentence + "'; SCD: '" + chosen_scd + "'"

class GivenSCDFindTextBert(BertModel):
	'''
		BERT QuestionAnswer model  
		=> "Given a scd and text which sentence from text matches scd?"
	'''

	def _model_class(self):
		''' Model with a question answering (span) head. '''
		return BertForQuestionAnswering

	def _dataset_class(self):
		''' Dataset class with the SCD as question and the text as context. '''
		return QuestionSCDAnswerTextDataset

	def _parse_prediction(self, result, input, tokens):
		'''
			Decode the predicted answer span and return it in a message string.

			The span runs between the argmax of `result.start_logits` and
			`result.end_logits` (both inclusive) within the first row of
			`tokens['input_ids']`; `input[0]` is reported as the SCD.
		'''
		begin = result.start_logits.argmax(dim=1)
		finish = result.end_logits.argmax(dim=1)
		# Both boundaries are picked independently, so they may come out reversed;
		# order them instead of producing an empty slice.
		# NOTE(review): the comparison presumably expects one example per call - confirm.
		lower, upper = (finish, begin) if finish < begin else (begin, finish)

		span_ids = tokens['input_ids'][0, lower:upper + 1]
		text = self.tokenizer.decode(span_ids)
		return "SCD: '" + input[0] + "'; Text: '" + text + "'"

class GivenTextFindSCDBert(BertModel):
	'''
		BERT QuestionAnswer model  
		=> "Given a sentence from text and scds which scd matches sentence?"
	'''

	def _model_class(self):
		''' Model with a question answering (span) head. '''
		return BertForQuestionAnswering

	def _dataset_class(self):
		''' Dataset class with the text sentence as question and the SCDs as context. '''
		return QuestionTextAnswerSCDDataset

	def _parse_prediction(self, result, input, tokens):
		'''
			Decode the predicted answer span and return it in a message string.

			The span runs between the argmax of `result.start_logits` and
			`result.end_logits` (both inclusive) within the first row of
			`tokens['input_ids']`; `input[0]` is reported as the text.
		'''
		begin = result.start_logits.argmax(dim=1)
		finish = result.end_logits.argmax(dim=1)
		# Both boundaries are picked independently, so they may come out reversed;
		# order them instead of producing an empty slice.
		# NOTE(review): the comparison presumably expects one example per call - confirm.
		lower, upper = (finish, begin) if finish < begin else (begin, finish)

		span_ids = tokens['input_ids'][0, lower:upper + 1]
		scd = self.tokenizer.decode(span_ids)
		return "Text: '" + input[0] + "'; SCD: '" + scd + "'"
View Source
class BertModel(TransformerModel):
	'''
		General BERT model

		This class is abstract, use subclasses `core.model.transformer.models.IsNextSCDBert`,
		`core.model.transformer.models.IsSCDBert`, `core.model.transformer.models.SelectSCDBert`,
		`core.model.transformer.models.GivenTextFindSCDBert`, or
		`core.model.transformer.models.GivenSCDFindTextBert`.

		Make sure to run `core.download.init_transformers()` to
		download needed resources.

		**This model uses a GPU if found via CUDA, else it will use multiple CPU cores.**  
		*However, a GPU will be much faster.*
	'''

	def _default_pretrained_model(self):
		''' Return `const.BERT_MODEL_DEFAULT`, the default pretrained model for all BERT variants. '''
		return const.BERT_MODEL_DEFAULT

	def _tokenizer_class(self):
		''' Return the tokenizer class shared by all BERT variants (`BertTokenizer`). '''
		return BertTokenizer

	@abstractmethod
	def _model_class(self):
		''' Return the `transformers` model class for the concrete task; provided by subclasses. '''

	@abstractmethod
	def _dataset_class(self):
		''' Return the dataset class for the concrete task; provided by subclasses. '''

	@abstractmethod
	def _parse_prediction(self, result, input, tokens):
		'''
			Convert a raw model output into the task specific answer; provided by subclasses.

			`result` is the model output, `input` the original (untokenized) input and
			`tokens` the tokenized input. The return shape differs per subclass.
		'''

General BERT model

This class is abstract, use subclasses core.model.transformer.models.IsNextSCDBert, core.model.transformer.models.IsSCDBert, core.model.transformer.models.SelectSCDBert, core.model.transformer.models.GivenTextFindSCDBert, or core.model.transformer.models.GivenSCDFindTextBert.

Make sure to run core.download.init_transformers() to download needed resources.

This model uses a GPU if found via CUDA, else it will use multiple CPU cores.
However, a GPU will be much faster.

#   class IsNextSCDBert(BertModel):
View Source
class IsNextSCDBert(BertModel):
	'''
		BERT NextSentence model  
		=> "Is next sentence a SCD?"
	'''

	def _model_class(self):
		''' Model with a next sentence prediction head. '''
		return BertForNextSentencePrediction

	def _dataset_class(self):
		''' Dataset class paired with the next sentence task. '''
		return NextSentenceDataset

	def _parse_prediction(self, result, input, tokens):
		'''
			Turn the model output into `(next_is_scd, message)`.

			Class index 0 of `result.logits` is read as "the next sentence matches";
			`input` and `tokens` are not used here.
		'''
		# NOTE(review): int() needs a single-element tensor, so this presumably
		# expects exactly one example per call - confirm against the caller.
		winning_class = int(result.logits.argmax(dim=1))
		next_is_scd = winning_class == 0
		verdict = "seems" if next_is_scd else "does not seem"
		return next_is_scd, "Next sentence " + verdict + " to be matching scd!"

BERT NextSentence model
=> "Is next sentence a SCD?"

#   class IsSCDBert(BertModel):
View Source
class IsSCDBert(BertModel):
	'''
		BERT SentenceClassification model  
		=> "Is the (current) sentence a SCD?"
	'''

	def _model_class(self):
		''' Model with a sequence classification head. '''
		return BertForSequenceClassification

	def _dataset_class(self):
		''' Dataset class paired with the classification task. '''
		return ClassificationDataset

	def _parse_prediction(self, result, input, tokens):
		'''
			Turn the model output into `(is_scd, message)`.

			Class index 1 of `result.logits` is read as "is a SCD";
			`input` and `tokens` are not used here.
		'''
		# NOTE(review): int() needs a single-element tensor, so this presumably
		# expects exactly one example per call - confirm against the caller.
		winning_class = int(result.logits.argmax(dim=1))
		is_scd = winning_class == 1
		article = "a" if is_scd else "no"
		return is_scd, "Seems to be " + article + " scd!"

BERT SentenceClassification model
=> "Is the (current) sentence a SCD?"

#   class SelectSCDBert(BertModel):
View Source
class SelectSCDBert(BertModel):
	'''
		BERT MultipleChoice model  
		=> "Given a sentence and a selection of SCDs, select the best."
	'''

	def _model_class(self):
		''' Model with a multiple choice head. '''
		return BertForMultipleChoice

	def _dataset_class(self):
		''' Dataset class paired with the multiple choice task. '''
		return MultipleChoiceDataset

	def _parse_prediction(self, result, input, tokens):
		'''
			Turn the model output into `(choice_index, message)`.

			The index of the highest logit selects the entry of `input[1]` reported
			as SCD; `input[0][0]` is reported as the text. `tokens` is not used here.
		'''
		# NOTE(review): int() needs a single-element tensor, so this presumably
		# expects exactly one example per call - confirm against the caller.
		choice = int(result.logits.argmax(dim=1))
		sentence = input[0][0]
		chosen_scd = input[1][choice]
		return choice, "Text: '" + sentence + "'; SCD: '" + chosen_scd + "'"

BERT MultipleChoice model
=> "Given a sentence and a selection of SCDs, select the best."

#   class GivenSCDFindTextBert(BertModel):
View Source
class GivenSCDFindTextBert(BertModel):
	'''
		BERT QuestionAnswer model  
		=> "Given a scd and text which sentence from text matches scd?"
	'''

	def _model_class(self):
		''' Model with a question answering (span) head. '''
		return BertForQuestionAnswering

	def _dataset_class(self):
		''' Dataset class with the SCD as question and the text as context. '''
		return QuestionSCDAnswerTextDataset

	def _parse_prediction(self, result, input, tokens):
		'''
			Decode the predicted answer span and return it in a message string.

			The span runs between the argmax of `result.start_logits` and
			`result.end_logits` (both inclusive) within the first row of
			`tokens['input_ids']`; `input[0]` is reported as the SCD.
		'''
		begin = result.start_logits.argmax(dim=1)
		finish = result.end_logits.argmax(dim=1)
		# Both boundaries are picked independently, so they may come out reversed;
		# order them instead of producing an empty slice.
		# NOTE(review): the comparison presumably expects one example per call - confirm.
		lower, upper = (finish, begin) if finish < begin else (begin, finish)

		span_ids = tokens['input_ids'][0, lower:upper + 1]
		text = self.tokenizer.decode(span_ids)
		return "SCD: '" + input[0] + "'; Text: '" + text + "'"

BERT QuestionAnswer model
=> "Given a scd and text which sentence from text matches scd?"

#   class GivenTextFindSCDBert(BertModel):
View Source
class GivenTextFindSCDBert(BertModel):
	'''
		BERT QuestionAnswer model  
		=> "Given a sentence from text and scds which scd matches sentence?"
	'''

	def _model_class(self):
		''' Model with a question answering (span) head. '''
		return BertForQuestionAnswering

	def _dataset_class(self):
		''' Dataset class with the text sentence as question and the SCDs as context. '''
		return QuestionTextAnswerSCDDataset

	def _parse_prediction(self, result, input, tokens):
		'''
			Decode the predicted answer span and return it in a message string.

			The span runs between the argmax of `result.start_logits` and
			`result.end_logits` (both inclusive) within the first row of
			`tokens['input_ids']`; `input[0]` is reported as the text.
		'''
		begin = result.start_logits.argmax(dim=1)
		finish = result.end_logits.argmax(dim=1)
		# Both boundaries are picked independently, so they may come out reversed;
		# order them instead of producing an empty slice.
		# NOTE(review): the comparison presumably expects one example per call - confirm.
		lower, upper = (finish, begin) if finish < begin else (begin, finish)

		span_ids = tokens['input_ids'][0, lower:upper + 1]
		scd = self.tokenizer.decode(span_ids)
		return "Text: '" + input[0] + "'; SCD: '" + scd + "'"

BERT QuestionAnswer model
=> "Given a sentence from text and scds which scd matches sentence?"