pytorch controlnet image-colorization image-to-image

Model Card for ColorizeNet


This model is a ControlNet trained to colorize black-and-white images.

Model Details

Model Description


ColorizeNet is an image colorization model based on ControlNet, trained on top of the pre-trained Stable Diffusion 2.1 model released by Stability AI.

Model Sources


Training Data


The model was trained on COCO, using all the images in the dataset. Each image was converted to grayscale, and the grayscale version was used as the conditioning input of the ControlNet.
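The card does not say exactly how the grayscale conditioning images were produced. As a minimal sketch, assuming the standard BT.601 luma weights (the same ones OpenCV uses for RGB-to-gray conversion) and a 3-channel conditioning image (the inference code below also feeds a 3-channel image), one training pair could be built like this. `make_training_pair` is a hypothetical helper, not part of the ColorizeNet code.

```python
import numpy as np

# BT.601 luma weights, as used by OpenCV for RGB -> GRAY.
# ASSUMPTION: the card does not state which grayscale conversion was used.
LUMA_WEIGHTS = np.array([0.299, 0.587, 0.114], dtype=np.float32)


def make_training_pair(rgb: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper returning (condition, target) for one image.

    rgb: uint8 array of shape (H, W, 3) in RGB order.
    condition: grayscale image replicated to 3 channels, uint8 (H, W, 3).
    target: the original color image, unchanged.
    """
    if rgb.dtype != np.uint8 or rgb.ndim != 3 or rgb.shape[2] != 3:
        raise ValueError("expected a uint8 RGB image of shape (H, W, 3)")
    # Weighted sum over the channel axis gives an (H, W) luma image
    gray = rgb.astype(np.float32) @ LUMA_WEIGHTS
    gray = np.clip(np.rint(gray), 0, 255).astype(np.uint8)
    # Replicate to 3 channels so the condition has the same layout as the target
    condition = np.repeat(gray[:, :, None], 3, axis=2)
    return condition, rgb


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    image = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)
    cond, target = make_training_pair(image)
    print(cond.shape, cond.dtype)  # (64, 64, 3) uint8
```

The color image itself stays the diffusion target, so the ControlNet learns to map the grayscale structure back to plausible colors.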


Run the model

Instantiate the model and load its configuration and weights

import random

import cv2
import einops
import numpy as np
import torch
from pytorch_lightning import seed_everything

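# NOTE: the module name is missing from this card; HWC3, apply_color and
# resize_image are image helpers from the same code base as utils.ddim and utils.model
from import HWC3, apply_color, resize_image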
from utils.ddim import DDIMSampler
from utils.model import create_model, load_state_dict

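# Build the ControlNet + Stable Diffusion 2.1 model from its config, then load the trained weights
model = create_model('./models/cldm_v21.yaml').cpu()
model.load_state_dict(load_state_dict(
    'lightning_logs/version_6/checkpoints/colorizenet-sd21.ckpt', location='cuda'))
model = model.cuda()
ddim_sampler = DDIMSampler(model)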
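The DDIM sampler does not run all the diffusion steps the model was trained with: with `ddim_steps = 20` (set further down) it visits only 20 of them. As a minimal sketch, assuming the "uniform" discretization used by default in the latent-diffusion DDIM sampler and the 1000 training timesteps of Stable Diffusion 2.1, the visited timesteps can be computed as follows; `uniform_ddim_timesteps` is a hypothetical name.

```python
import numpy as np


def uniform_ddim_timesteps(ddim_steps: int, num_train_timesteps: int = 1000) -> np.ndarray:
    """Evenly spaced subset of the training timesteps (sketch of 'uniform' DDIM)."""
    stride = num_train_timesteps // ddim_steps
    # Every `stride`-th timestep, shifted by one as in the latent-diffusion code
    return np.arange(0, num_train_timesteps, stride) + 1


timesteps = uniform_ddim_timesteps(20)
print(len(timesteps), timesteps[0], timesteps[-1])  # 20 1 951
```

Sampling walks this list in reverse, from the noisiest timestep (951) down to 1, which is why 20 steps are enough for a usable result at a fraction of the cost of full sampling.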

Read the image to be colorized

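# OpenCV loads the image as a 3-channel BGR uint8 array; for a grayscale photo
# all channels are equal, so the channel order does not matter here
input_image = cv2.imread("sample_data/sample1_bw.jpg")
input_image = HWC3(input_image)                 # make sure the image is HxWx3
img = resize_image(input_image, resolution=512)
H, W, C = img.shape

# Build the control tensor: scale to [0, 1], add a batch axis, switch to NCHW
num_samples = 1
control = torch.from_numpy(img.copy()).float().cuda() / 255.0
control = torch.stack([control for _ in range(num_samples)], dim=0)
control = einops.rearrange(control, 'b h w c -> b c h w').clone()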
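The card does not document `resize_image`. Assuming it follows the convention of the original ControlNet helpers (scale the short side to `resolution`, then round both sides to the nearest multiple of 64), the resized image size and the latent shape used later (`shape = (4, H // 8, W // 8)`) can be worked out as below. `target_size` is a hypothetical name used only for illustration.

```python
def target_size(height: int, width: int, resolution: int = 512) -> tuple[int, int]:
    """Hypothetical helper: (H, W) after a ControlNet-style resize."""
    k = resolution / min(height, width)          # scale factor for the short side
    new_h = int(round(height * k / 64.0)) * 64   # round each side to a multiple of 64
    new_w = int(round(width * k / 64.0)) * 64
    return new_h, new_w


h, w = target_size(480, 640)
latent_shape = (4, h // 8, w // 8)               # VAE latents are 8x smaller
print((h, w), latent_shape)  # (512, 704) (4, 64, 88)
```

Rounding to 64 rather than 8 is a design choice: the VAE divides the image size by 8, and the denoising U-Net halves the latent three more times, so multiples of 64 keep every feature map at an integer size.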

Prepare the input and parameters of the model

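seed = 1294574436
seed_everything(seed)   # fix the random seeds so that sampling is reproducible
prompt = "Colorize this image"
n_prompt = ""           # negative prompt, used for the unconditional branch
guess_mode = False
strength = 1.0          # overall weight of the ControlNet residuals
eta = 0.0               # 0.0 gives deterministic DDIM updates (no extra noise per step)
ddim_steps = 20         # number of DDIM denoising steps
scale = 9.0             # classifier-free guidance scale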
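`scale` is the classifier-free guidance weight and `un_cond` (built next) is the unconditional conditioning. With classifier-free guidance, each denoising step evaluates the model twice, once with `cond` and once with `un_cond`, and extrapolates from the unconditional prediction toward the conditional one. The sketch below shows only that combination rule on toy arrays; the function name is hypothetical.

```python
import numpy as np


def classifier_free_guidance(pred_uncond: np.ndarray, pred_cond: np.ndarray,
                             scale: float) -> np.ndarray:
    """Combine unconditional and conditional model predictions.

    scale == 1 returns the conditional prediction; larger values push the
    result further in the direction the conditioning suggests.
    """
    return pred_uncond + scale * (pred_cond - pred_uncond)


pred_uncond = np.array([0.10, -0.20])
pred_cond = np.array([0.30, -0.10])
print(classifier_free_guidance(pred_uncond, pred_cond, 9.0))
```

Note that when `guess_mode` is `False`, both branches receive the control image, so the guidance amplifies only the difference between the prompt "Colorize this image" and the empty negative prompt.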

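# Conditional branch: control image plus the embedded prompt
cond = {"c_concat": [control], "c_crossattn": [
    model.get_learned_conditioning([prompt] * num_samples)]}
# Unconditional branch for classifier-free guidance (no control image in guess mode)
un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [
    model.get_learned_conditioning([n_prompt] * num_samples)]}
# Latent shape: 4 channels at 1/8 of the image resolution
shape = (4, H // 8, W // 8)

# One weight per ControlNet residual (13 in total); guess mode damps the shallow ones
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else (
    [strength] * 13)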
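The `control_scales` expression assigns one weight to each of the 13 residuals the ControlNet adds to the Stable Diffusion U-Net (12 from the encoder blocks and 1 from the middle block, in that order). The sketch below restates the same formula as a standalone function so the two modes can be compared; `control_scales` as a function name is hypothetical.

```python
def control_scales(strength: float, guess_mode: bool, n: int = 13) -> list[float]:
    """Per-residual ControlNet weights, same formula as the line above."""
    if guess_mode:
        # Geometric ramp: the shallowest residual gets strength * 0.825**12 (about 0.1x),
        # the last one (middle block) keeps the full strength.
        return [strength * (0.825 ** float(n - 1 - i)) for i in range(n)]
    # Normal mode: every residual uses the same weight
    return [strength] * n


print([round(s, 3) for s in control_scales(1.0, guess_mode=True)])
```

In normal mode every residual is weighted equally by `strength`; in guess mode the ramp lets the deeper, more semantic features dominate, which is meant for running without a useful text prompt.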

Sample and post-process the results

samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                             shape, cond, verbose=False, eta=eta,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=un_cond)

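# Decode the latents to images in [-1, 1], then map them to uint8 [0, 255] in HWC layout
x_samples = model.decode_first_stage(samples)
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c')
             * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

# Post-process each generated sample together with the resized input image
results = [x_samples[i] for i in range(num_samples)]
colored_results = [apply_color(img, result) for result in results]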
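The card does not describe what `apply_color` does. A common colorization post-processing step, sketched here purely as an assumption, is to keep the luminance of the original photo and take only the color information from the generated image, so fine detail comes from the input even where the diffusion output drifts. The hypothetical `keep_input_luma` below does this in full-range BT.601 YCbCr; the actual `apply_color` may use a different color space or method.

```python
import numpy as np

# Full-range BT.601 (JPEG/JFIF) RGB <-> YCbCr matrices; chroma is kept centred on 0.
RGB_TO_YCBCR = np.array([
    [0.299, 0.587, 0.114],           # Y  (luma)
    [-0.168736, -0.331264, 0.5],     # Cb (blue-difference chroma)
    [0.5, -0.418688, -0.081312],     # Cr (red-difference chroma)
])
YCBCR_TO_RGB = np.array([
    [1.0, 0.0, 1.402],
    [1.0, -0.344136, -0.714136],
    [1.0, 1.772, 0.0],
])


def keep_input_luma(input_rgb: np.ndarray, generated_rgb: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for apply_color: input luma + generated chroma.

    Both arguments are uint8 arrays of shape (H, W, 3); returns uint8 (H, W, 3).
    """
    if input_rgb.shape != generated_rgb.shape:
        raise ValueError("input and generated images must have the same shape")
    src = input_rgb.astype(np.float64) @ RGB_TO_YCBCR.T      # (H, W, 3): Y, Cb, Cr
    gen = generated_rgb.astype(np.float64) @ RGB_TO_YCBCR.T
    # Luma from the original photo, chroma from the diffusion output
    ycc = np.concatenate([src[..., :1], gen[..., 1:]], axis=-1)
    rgb = ycc @ YCBCR_TO_RGB.T
    return np.clip(np.rint(rgb), 0, 255).astype(np.uint8)
```

It would be called the same way as the original line, e.g. `[keep_input_luma(img, result) for result in results]`. Because the input photo is grayscale, the BGR channel order returned by `cv2.imread` does not affect its luma.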


[Figure: example results in two columns, "BW Input" and "Colorized", showing six black-and-white images and their colorized outputs.]