Untitled diff

Created Diff never expires
6 removals
62 lines
11 additions
66 lines
#Modified from https://github.com/amrit-das/Custom-Model-Training-PyTorch/blob/master/predict.py
import torch
import torch
import torch.nn as nn
import torch.nn as nn
from torchvision.models import resnet18
#from torchvision.models import resnet18
from torchvision.transforms import transforms
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import numpy as np
import numpy as np
from torch.autograd import Variable
from torch.autograd import Variable
import torch.functional as F
import torch.functional as F
from PIL import Image
from PIL import Image
import os
import os
import sys
import sys
import argparse
import argparse
from prune import *
from finetune import *


# Command-line interface; every option is mandatory.
parser = argparse.ArgumentParser(description='To Predict from a trained model')
parser.add_argument('-i', '--image', dest='image_name', required=True,
                    help='Path to the image file')
parser.add_argument('-m', '--model', dest='model_name', required=True,
                    help='Path to the model')
# type=int rejects a non-numeric class count at parse time instead of later.
# NOTE(review): the full-model loading path does not read num_classes; it is
# kept required so existing command lines keep parsing unchanged.
parser.add_argument('-n', '--num_class', dest='num_classes', required=True,
                    type=int, help='Number of training classes')
# Parsed at import time: importing this module reads sys.argv.
args = parser.parse_args()


# Legacy name kept for compatibility; loading below uses args.model_name
# directly ("./" + an absolute path would yield a broken ".//abs/..." path).
path_to_model = "./" + args.model_name

# The model file is loaded as a whole pickled nn.Module (not a state_dict),
# so the classes it references must be importable in this module.
# SECURITY NOTE: torch.load unpickles arbitrary objects -- only load model
# files from trusted sources.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
# which rejects full pickled modules; pass weights_only=False there if the
# file is trusted.
# Use the GPU when present, otherwise fall back to the CPU instead of
# crashing on an unconditional .cuda(); map_location lets a GPU-saved model
# be loaded on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load(args.model_name, map_location=device)
model = model.to(device)
model.eval()


def predict_image(image_path):
    """Classify one image with the module-level ``model``.

    Args:
        image_path: Filesystem path of the image to classify.

    Returns:
        int: Index of the highest-scoring model output (the class index).
    """
    print("prediction in progress")
    # convert("RGB") guarantees the 3 channels the 3-value Normalize below
    # expects (grayscale / RGBA / palette images would otherwise fail);
    # the `with` closes the underlying file once the copy is made.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # Deterministic preprocessing: random crop/flip are training-time
    # augmentations and made repeated predictions on the same image differ.
    transformation = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Add a leading batch dimension of size 1.
    image_tensor = transformation(image).float().unsqueeze(0)

    # Move the input to wherever the model's weights live. Tensor.cuda() is
    # not in-place, so the previous bare `image_tensor.cuda()` call left the
    # input on the CPU while the model sat on the GPU.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        image_tensor = image_tensor.to(first_param.device)

    with torch.no_grad():  # inference only: no autograd graph needed
        output = model(image_tensor)

    # .cpu() is required before .numpy() when the output is a CUDA tensor.
    return int(output.cpu().numpy().argmax())


def class_mapping(index, mapping_path='class_mapping.txt'):
    """Translate a predicted class index into its class name.

    The mapping file holds one ``name~index`` pair per line.

    Args:
        index: Class index, e.g. as returned by ``predict_image`` (int or str).
        mapping_path: Mapping file to read. Defaults to 'class_mapping.txt'
            relative to the current working directory.

    Returns:
        str: The class name registered for ``index``.

    Raises:
        KeyError: If ``index`` is not listed in the file.
        IndexError: If a non-blank line has no '~' separator.
    """
    class_map = {}
    # `with` closes the file; the handle used to be left open.
    with open(mapping_path, 'r', encoding='utf-8') as mapping:
        for line in mapping:
            line = line.strip('\n')
            if not line:
                # Skip empty lines (e.g. an extra blank line at the end of
                # the file) instead of failing on the missing '~' field.
                continue
            fields = line.split('~')
            class_map[fields[1]] = fields[0]
    return class_map[str(index)]


if __name__ == "__main__":
    # -i is documented as "Path to the image file": honor it as a path when
    # it names an existing file, otherwise fall back to the ./test/Lemon/
    # folder that was previously always prefixed (keeps old invocations
    # working).
    imagepath = args.image_name
    if not os.path.isfile(imagepath):
        imagepath = os.path.join("./test/Lemon/", args.image_name)
    prediction = predict_image(imagepath)
    name = class_mapping(prediction)
    print("Predicted Class: ", name)