Making a function to load LoRA
Let’s add one more list to store the keys that have already been visited, and put all the preceding code together into a function named load_lora:
def load_lora(
    pipeline,
    lora_path,
    lora_weight = 0.5,
    device = 'cpu'
):
    state_dict = load_file(lora_path, device=device)
    LORA_PREFIX_UNET = 'lora_unet'
    LORA_PREFIX_TEXT_ENCODER = 'lora_te'

    alpha = lora_weight
    visited = []

    # directly update the weights in the diffusers model
    for key in state_dict:
        # alpha has already been set above, and visited keys
        # have already been processed, so skip both
        if '.alpha' in key or key in visited:
            ...
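To see why the visited list matters: LoRA files usually store each layer as a pair of keys (an up matrix and a down matrix), so handling one key also handles its partner. The following is a minimal sketch of that bookkeeping only, assuming the common `lora_up`/`lora_down` naming convention. The helper `pair_lora_keys` and the sample key names are hypothetical and are not part of load_lora.

```python
def pair_lora_keys(keys):
    """Group LoRA keys into (up, down) pairs, handling each key once."""
    visited = []
    pairs = []
    for key in keys:
        # alpha entries hold scaling info rather than weights; skip them,
        # and skip any key already handled as part of a pair
        if '.alpha' in key or key in visited:
            continue
        if 'lora_down' in key:
            down_key = key
            up_key = key.replace('lora_down', 'lora_up')
        else:
            up_key = key
            down_key = key.replace('lora_up', 'lora_down')
        pairs.append((up_key, down_key))
        # mark both keys so the partner is skipped when the loop reaches it
        visited.extend([up_key, down_key])
    return pairs


# hypothetical key names, for illustration only
sample_keys = [
    'lora_unet_block_attn_to_q.alpha',
    'lora_unet_block_attn_to_q.lora_down.weight',
    'lora_unet_block_attn_to_q.lora_up.weight',
    'lora_te_layer_0_mlp_fc1.lora_up.weight',
    'lora_te_layer_0_mlp_fc1.lora_down.weight',
]
pairs = pair_lora_keys(sample_keys)
```

Without the visited check, each pair would be processed twice, once for the up key and once for the down key, and in a real loader the LoRA delta would be added to the same layer two times.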