# Categories to classify. Each one becomes a sub-folder of `path`, and the folder
# name later serves as the label (see `get_y=parent_label`).
reptile_types = 'crocodile', 'alligator plush', 'renekton'
path = Path('reptiles')
path.mkdir(exist_ok=True)
for o in reptile_types:
    dest = path/o
    # Check per category rather than once for the whole `reptiles/` folder, so a
    # category added to `reptile_types` later is still downloaded.
    if dest.exists():
        continue
    dest.mkdir()
    # `key` is presumably the Bing Search API key, defined in an earlier cell.
    results = search_images_bing(key, o)
    download_images(dest, urls=results.attrgot('content_url'))
fns = get_image_files(path)
fns
(#416) [Path('reptiles/alligator plush/00000000.jpg'),Path('reptiles/alligator plush/00000001.jpg'),Path('reptiles/alligator plush/00000002.jpg'),Path('reptiles/alligator plush/00000003.jpg'),Path('reptiles/alligator plush/00000004.jpeg'),Path('reptiles/alligator plush/00000005.jpg'),Path('reptiles/alligator plush/00000006.jpg'),Path('reptiles/alligator plush/00000007.jpg'),Path('reptiles/alligator plush/00000008.jpg'),Path('reptiles/alligator plush/00000009.jpg')...]
# Remove downloaded files that cannot be opened as images.
failed = verify_images(fns)   # the paths that failed verification
failed.map(Path.unlink)       # delete each of them from disk
(#0) []
class DataLoaders(GetAttr):
    """Hold several data loaders and expose the first two as `train` and `valid`."""

    # NOTE(review): nothing visible reads this class attribute. The Windows worker
    # error is actually avoided by passing `num_workers=0` to `dataloaders(...)`.
    # `DataBlock.dataloaders` also appears to build fastai's own DataLoaders rather
    # than this redefinition, so confirm before relying on this class.
    num_workers = 0

    def __init__(self, *loaders):
        self.loaders = loaders

    def __getitem__(self, i):
        return self.loaders[i]

    # add_props creates one property per index (0, 1): train -> self[0], valid -> self[1].
    train, valid = add_props(lambda idx, obj: obj[idx])
        
# Describe how to turn the downloaded folders into a labelled image dataset.
reptiles = DataBlock(
    blocks=(ImageBlock, CategoryBlock),               # input: image, target: one category
    get_items=get_image_files,                        # collect the image files under `path`
    splitter=RandomSplitter(valid_pct=0.2, seed=42),  # fixed seed -> reproducible 80/20 split
    get_y=parent_label,                               # label = name of the containing folder
    item_tfms=Resize(128))
# Pass num_workers=0 here as well. The later dataloaders call needs it to avoid the
# Windows worker error, and this call (plus show_batch) would presumably hit the
# same error.
dls = reptiles.dataloaders(path, num_workers=0)
dls.valid.show_batch(max_n=4, nrows=1)
# Training-time augmentation: random resized crops plus fastai's default batch augmentations.
reptiles = reptiles.new(
    item_tfms=RandomResizedCrop(224, min_scale=0.5),
    batch_tfms=aug_transforms(),
)
# num_workers=0 loads data in the main process, which avoids the Windows worker error.
dls = reptiles.dataloaders(path, num_workers=0)
learn = cnn_learner(dls, resnet18, metrics=error_rate)
learn.fine_tune(epochs=4)
epoch train_loss valid_loss error_rate time
0 1.455003 0.249432 0.084337 00:18
epoch train_loss valid_loss error_rate time
0 0.104521 0.049535 0.024096 00:18
1 0.068319 0.012980 0.012048 00:18
2 0.052283 0.011862 0.000000 00:19
3 0.041685 0.010840 0.000000 00:19
# Show which categories the trained model confuses with each other.
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

The result is very good: the validation error rate drops to 0 after fine-tuning.

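# Optional data cleaning with fastai's ImageClassifierCleaner (needs `from fastai.vision.widgets import *`):
# cleaner = ImageClassifierCleaner(learn)
# cleaner
# for idx in cleaner.delete(): cleaner.fns[idx].unlink() # delete the images marked for removal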

Let's test the model on some images of my own.

# Classify my drawing of a plush.
my_renek = PILImage.create("renek_plush.png")
display(my_renek.to_thumb(256))  # to_thumb(h) uses w=h, i.e. a 256x256 box
pred, pred_idx, probs = learn.predict(my_renek)  # label, its index, class probabilities
f'Prediction: {pred}; Probability: {probs[pred_idx]:.04f}'
'Prediction: alligator plush; Probability: 0.9391'

Very good. The prediction is accurate, clearly because my drawing of a plush is so realistic.

# Classify my drawing of Renekton.
renek = PILImage.create("renek_test.png")
display(renek.to_thumb(256))  # to_thumb(h) uses w=h, i.e. a 256x256 box
pred, pred_idx, probs = learn.predict(renek)  # label, its index, class probabilities
f'Prediction: {pred}; Probability: {probs[pred_idx]:.04f}'
'Prediction: renekton; Probability: 0.9834'

It easily recognizes my drawing of Renekton as well. I guess I'm an artist!

# Same Renekton drawing, but with the background removed.
renek_withoutbg = PILImage.create("renek_test1.png")
display(renek_withoutbg.to_thumb(256))  # to_thumb(h) uses w=h, i.e. a 256x256 box
pred, pred_idx, probs = learn.predict(renek_withoutbg)  # label, its index, class probabilities
f'Prediction: {pred}; Probability: {probs[pred_idx]:.04f}'
'Prediction: renekton; Probability: 0.9674'

I expected the model to predict "alligator plush" because I removed the background, but it's too smart. (In the dataset, many of the plush images had plain white backgrounds, in contrast to the many Renekton images with dark backgrounds.)

# Classify an image of me wearing a fake beard.
beard = PILImage.create("beard.jpg")
display(beard.to_thumb(200))  # to_thumb(h) uses w=h, i.e. a 200x200 box
pred, pred_idx, probs = learn.predict(beard)  # label, its index, class probabilities
f'Prediction: {pred}; Probability: {probs[pred_idx]:.04f}'
'Prediction: alligator plush; Probability: 0.8644'

Indeed, I am an alligator plush with my fake beard!

learn.export(fname='export.pkl')  # serialize the trained learner; load_learner reads it back below

Run the code below to make your own test (download the export.pkl file from my GitHub first).

from fastai.vision.widgets import *
btn_upload = widgets.FileUpload()
out_pl = widgets.Output()
lbl_pred = widgets.Label()
path = Path('')
# SECURITY NOTE: load_learner unpickles the file, which can run arbitrary code.
# Only load an export.pkl you trust (e.g. one you exported yourself).
learn_inf = load_learner(path/'export.pkl', cpu=True)


def _last_upload_bytes(upload):
    """Return the raw bytes of the most recently uploaded file, or None if there is none.

    Supports both ipywidgets 7.x, which exposes a `data` list of bytes, and
    ipywidgets 8.x, which removed `data`. In 8.x, `value` is a sequence of dicts
    whose 'content' entry holds the file contents.
    """
    data = getattr(upload, 'data', None)  # ipywidgets 7.x
    if data:
        return data[-1]
    files = upload.value
    if isinstance(files, dict):  # ipywidgets 7.x `value`: {filename: {'metadata': ..., 'content': ...}}
        files = list(files.values())
    if not files:
        return None
    # In 8.x 'content' is a memoryview; PILImage.create handles bytes, so convert.
    return bytes(files[-1]['content'])


def on_data_change(change):
    """Classify the newly uploaded image, then show its thumbnail and the prediction."""
    lbl_pred.value = ''
    img_bytes = _last_upload_bytes(btn_upload)
    if img_bytes is None:  # e.g. the upload was cleared
        return
    img = PILImage.create(img_bytes)
    out_pl.clear_output()
    with out_pl: display(img.to_thumb(128,128))
    pred,pred_idx,probs = learn_inf.predict(img)
    lbl_pred.value = f'Prediction: {pred}; Probability: {probs[pred_idx]:.04f}'


# Observe `value`, not `data`. `value` is updated on every upload in both ipywidgets
# 7.x and 8.x. `data` does not exist in 8.x, so a `data` observer would never fire there.
btn_upload.observe(on_data_change, names=['value'])

display(VBox([widgets.Label('Feed me a reptile photo!'), btn_upload, out_pl, lbl_pred]))