get_model
get_model(model_name: str = 'gpt2', device: Union[str, torch.types.Device] = 'cuda') -> AutoModelForCausalLM
Loads and returns a model and tokenizer from the modified Hugging Face Transformers library.
Parameters:
- model_name (str, default: 'gpt2') – The name of the pre-trained model.
- device (Union[str, Device], default: 'cuda') – The device to train on.
Examples:
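The following is a minimal sketch of what a loader with this signature might look like. It is written against the standard Hugging Face Transformers API, not the modified library this function actually uses, so it only illustrates the general shape. The name `get_model_sketch` is hypothetical, and returning a `(model, tokenizer)` tuple is an assumption based on the description above.

```python
from typing import Any, Tuple


def get_model_sketch(model_name: str = "gpt2", device: str = "cuda") -> Tuple[Any, Any]:
    """Hypothetical stand-in for get_model, built on stock Transformers.

    Assumption: the real function returns both a model and a tokenizer.
    """
    # Heavy dependencies are imported lazily so that defining this function
    # does not require torch or transformers to be installed.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Fetch the tokenizer and pre-trained weights by name
    # (downloaded from the Hugging Face Hub on first use, then cached).
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Move the model parameters to the requested device.
    # torch.device accepts a string such as "cuda" or "cpu", or a torch.device.
    model = model.to(torch.device(device))

    return model, tokenizer
```

Calling `get_model_sketch("gpt2", device="cpu")` would load GPT-2 on the CPU, which may be useful on machines without a GPU. The real `get_model` may differ in what it returns and in how it hooks the model.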
Returns:
- AutoModelForCausalLM – The light-weight hooked model and tokenizer.