get_model
get_model(model_name: str = 'gpt2', device: Union[str, torch.types.Device] = 'cuda') -> AutoModelForCausalLM
Loads and returns a model and tokenizer from the modified Hugging Face Transformers library.
Parameters:
- model_name (str, default: 'gpt2') – The name of the pre-trained model to load.
- device (Union[str, Device], default: 'cuda') – The device on which to load and train the model.
Examples:
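For comparison, a loader written against the standard (unmodified) Hugging Face Transformers API might look like the minimal sketch below. It is not this library's `get_model`: the name `load_model_and_tokenizer` is hypothetical, and it returns a plain model and tokenizer rather than the hooked model described here. It assumes `torch` and `transformers` are installed when the function is called. The Transformers import is deferred into the function body, so defining the function does not require either package.

```python
from __future__ import annotations  # annotations stay unevaluated strings

from typing import Tuple, Union


def load_model_and_tokenizer(
    model_name: str = "gpt2",
    device: Union[str, torch.device] = "cuda",
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
    """Hypothetical sketch: load a causal LM and its tokenizer with plain Transformers.

    This is NOT the hooked model returned by get_model; it only mirrors its
    parameters (model_name, device) and their defaults.
    """
    # Deferred import: only needed when the function is actually called.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Load the tokenizer that matches the checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Load (or read from the local cache) the pre-trained weights,
    # then move the model to the requested device; .to() returns the model.
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    return model, tokenizer
```

On a machine without a GPU, pass the device explicitly, e.g. `model, tokenizer = load_model_and_tokenizer("gpt2", device="cpu")`. With the `'cuda'` default, moving the model fails when no CUDA device is available.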
Returns:
- AutoModelForCausalLM – The lightweight hooked model and tokenizer.