Training code for the with-attribute (left) and without-attribute (right) variants. The two scripts are identical except where marked: lines shared by both appear once below, and where they differ, the with-attribute version is shown with the without-attribute alternative as a "# wo-attr:" comment.
import os
import torch
import torch.nn as nn
import numpy as np
import time
import json
import math
from tqdm import tqdm
from pandas import DataFrame
import sys
sys.path.append("..")
sys.path.append("../..")
from torchmeta.utils.data import BatchMetaDataLoader
from torchmeta.utils.prototype import get_prototypes, prototypical_loss
from model import ProtoNetAGAM
# wo-attr: from model import ProtoNetAGAM, ProtoNetAGAMwoAttr
from utils import get_dataset, get_proto_accuracy, get_addition_loss
from global_utils import Averager, Averager_with_interval, get_outputs_c_h, get_inputs_and_outputs, get_semantic_size
def save_model(model, args, tag):
    model_path = os.path.join(args.record_folder, ('_'.join([args.model_name, args.train_data, args.test_data, args.backbone, tag]) + '.pt'))
    if args.multi_gpu:
        model = model.module
    with open(model_path, 'wb') as f:
        torch.save(model.state_dict(), f)
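# Note: the weights saved above can later be restored with, e.g.,
#   model.load_state_dict(torch.load(model_path))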
def save_checkpoint(args, model, train_log, optimizer, global_task_count, tag):
    if args.multi_gpu:
        model = model.module
    state = {
        'args': args,
        'state_dict': model.state_dict(),
        'train_log': train_log,
        'val_acc': train_log['max_acc'],
        'optimizer': optimizer.state_dict(),
        'global_task_count': global_task_count
    }
    checkpoint_path = os.path.join(args.record_folder, ('_'.join([args.model_name, args.train_data, args.test_data, args.backbone, tag]) + '_checkpoint.pt.tar'))
    with open(checkpoint_path, 'wb') as f:
        torch.save(state, f)
if __name__ == '__main__':
    import argparse
    torch.backends.cudnn.enabled = False    # present in only one of the two variants
    # should be updated in different models
    parser = argparse.ArgumentParser('Prototypical Networks with AGAM')
    parser.add_argument('--model-name', type=str, default='protonet_agam',
        help='Name of the model.')
    # experimental settings
    parser.add_argument('--data-folder', type=str, default='../../datasets',
        help='Path to the folder the data is downloaded to.')
    parser.add_argument('--train-data', type=str, default='cub',
        choices=['cub', 'sun'],
        help='Name of the dataset used in meta-train phase.')
    parser.add_argument('--test-data', type=str, default='cub',
        choices=['cub', 'sun'],
        help='Name of the dataset used in meta-test phase.')
    parser.add_argument('--backbone', type=str, default='conv4',
        choices=['conv4', 'resnet12'],
        help='Name of the CNN backbone.')
    parser.add_argument('--lr', type=float, default=0.001,
        help='Initial learning rate (default: 0.001).')
    parser.add_argument('--num-shots', type=int, default=5, choices=[1, 5, 10],
        help='Number of examples per class (k in "k-shot", default: 5).')
    parser.add_argument('--num-ways', type=int, default=5, choices=[5, 20],
        help='Number of classes per task (N in "N-way", default: 5).')
    parser.add_argument('--test-shots', type=int, default=15,
        help='Number of query examples per class (k in "k-shot", default: 15).')
    parser.add_argument('--batch-tasks', type=int, default=4,
        help='Number of tasks in a mini-batch of tasks (default: 4).')
    # follows Closer_Look, 40000 for 5 shots and 60000 for 1 shot
    parser.add_argument('--train-tasks', type=int, default=40000,
        help='Number of tasks in the training phase (default: 40000).')
    parser.add_argument('--val-tasks', type=int, default=600,
        help='Number of tasks in the validation phase (default: 600).')
    parser.add_argument('--test-tasks', type=int, default=10000,
        help='Number of tasks in the testing phase (default: 10000).')
    # note: argparse's type=bool treats any non-empty string as True, so
    # passing '--augment False' still evaluates to True
    parser.add_argument('--augment', type=bool, default=True,
        help='Augment the training dataset (default: True).')
    parser.add_argument('--schedule', type=int, nargs='+', default=[15000, 30000, 45000, 60000],
        help='Decrease the learning rate at these task counts.')
    parser.add_argument('--gamma', type=float, default=0.1,
        help='Learning rate decay ratio (default: 0.1).')
    parser.add_argument('--valid-every-tasks', type=int, default=1000,
        help='Run validation every this many tasks (default: 1000).')
    # arguments of program
    parser.add_argument('--num-workers', type=int, default=2,
        help='Number of workers for data loading (default: 2).')
    # wo-attr: default=16 (its help string still says 2)
    parser.add_argument('--download', action='store_true',
        help='Download the dataset in the data folder.')
    # same type=bool caveat as --augment above
    parser.add_argument('--use-cuda', type=bool, default=True,
        help='Use CUDA if available.')
    parser.add_argument('--multi-gpu', action='store_true',
        help='Use multiple GPUs if set; otherwise use a single GPU.')
    # arguments for resume (i.e. checkpoint)
    parser.add_argument('--resume', action='store_true',
        help='Resume training from a saved checkpoint.')
    parser.add_argument('--resume-folder', type=str, default=None,
        help='Path to the folder the checkpoint was saved to.')
    # special arguments for AGAM
    parser.add_argument('--ca-trade-off', type=float, default=1.0,
        help='Trade-off weight of the channel-attention similarity term in the loss function (default: 1.0).')
    parser.add_argument('--sa-trade-off', type=float, default=0.1,
        help='Trade-off weight of the spatial-attention similarity term in the loss function (default: 0.1).')
    parser.add_argument('--addition-loss', type=str, default='norm_softmargin',
        choices=['norm_softmargin', 'softmargin'],
        help='Type of the attention alignment loss.')
    # arguments of semantic
    parser.add_argument('--semantic-type', type=str, nargs='+',
        choices=['class_attributes', 'image_attributes'],
        help='Semantic type.')
#     parser.add_argument('--attsize', type=float, default=None,
#         help='use 6 or 50 or 156')
    args = parser.parse_args()
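    # Example invocation (the script name here is a placeholder; use this file's
    # actual name):
    #   python train.py --data-folder ../../datasets --train-data cub \
    #       --test-data cub --backbone conv4 --num-ways 5 --num-shots 5 \
    #       --semantic-type class_attributes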
    # make folder and tensorboard writer to save model and results
    cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
    args.record_folder = './{}_{}_{}_{}_{}'.format(args.train_data, args.test_data, args.model_name, args.backbone, cur_time)
    # writer = SummaryWriter(args.record_folder)
    os.makedirs(args.record_folder, exist_ok=True)
    if args.use_cuda and torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
    elif args.use_cuda:
        raise RuntimeError('You are using GPU mode, but GPUs are not available!')
    # construct model and optimizer
    assert (args.train_data == args.test_data)
    args.image_len = 84
    args.semantic_size = get_semantic_size(args)
    args.out_channels, args.feature_h = get_outputs_c_h(args.backbone, args.image_len)
    model = ProtoNetAGAM(args.backbone, args.semantic_size, args.out_channels)
    # wo-attr: model = ProtoNetAGAMwoAttr(args.backbone, args.semantic_size, args.out_channels)
    if args.use_cuda:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        num_gpus = torch.cuda.device_count()
        if args.multi_gpu:
            model = nn.DataParallel(model)
        model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0005)
    # training from the checkpoint
    if args.resume and args.resume_folder is not None:
        # load checkpoint
        checkpoint_path = os.path.join(args.resume_folder, ('_'.join([args.model_name, args.train_data, args.test_data, args.backbone, 'max_acc']) + '_checkpoint.pt.tar'))    # tag='max_acc' can be changed
        state = torch.load(checkpoint_path)
        resumed_state = state['state_dict']
        if args.multi_gpu:
            model.module.load_state_dict(resumed_state)
        else:
            model.load_state_dict(resumed_state)
        train_log = state['train_log']
        optimizer.load_state_dict(state['optimizer'])
        initial_lr = optimizer.param_groups[0]['lr']
        global_task_count = state['global_task_count']
        print('global_task_count: {}, initial_lr: {}'.format(str(global_task_count), str(initial_lr)))
    # training from scratch
    else:
        train_log = {}
        train_log['args'] = vars(args)
        train_log['train_loss'] = []
        train_log['train_acc'] = []
        train_log['val_loss'] = []
        train_log['val_acc'] = []
        train_log['max_acc'] = 0.0
        train_log['max_acc_i_task'] = 0
        initial_lr = args.lr
        global_task_count = 0
    # save the args into .json file
    with open(os.path.join(args.record_folder, 'args.json'), 'w') as f:
        json.dump(vars(args), f)
    # get datasets and dataloaders
    train_dataset = get_dataset(args, dataset_name=args.train_data, phase='train')
    val_dataset = get_dataset(args, dataset_name=args.test_data, phase='val')
    test_dataset = get_dataset(args, dataset_name=args.test_data, phase='test')
    train_loader = BatchMetaDataLoader(train_dataset,
                                     batch_size=args.batch_tasks,
                                     shuffle=True,
                                     num_workers=args.num_workers,
                                     pin_memory=True)
    val_loader = BatchMetaDataLoader(val_dataset,
                                     batch_size=args.batch_tasks,
                                     shuffle=False,
                                     num_workers=args.num_workers,
                                     pin_memory=True)
    test_loader = BatchMetaDataLoader(test_dataset,
                                     batch_size=args.batch_tasks,
                                     shuffle=False,
                                     num_workers=args.num_workers,
                                     pin_memory=True)
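    # each batch from BatchMetaDataLoader is a dict whose 'train' and 'test'
    # entries hold the support and query sets of args.batch_tasks tasks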
    # training
    with tqdm(train_loader, total=int(args.train_tasks/args.batch_tasks), initial=int(global_task_count / args.batch_tasks)) as pbar:
        for i_train_batch, train_batch in enumerate(pbar, int(global_task_count / args.batch_tasks)+1):
            if i_train_batch > (args.train_tasks / args.batch_tasks):
                break
            model.train()
            # check if lr should decrease as in schedule
            if (i_train_batch * args.batch_tasks) in args.schedule:
                initial_lr *= args.gamma
                for param_group in optimizer.param_groups:
                    param_group['lr'] = initial_lr
            global_task_count += args.batch_tasks
            support_inputs, support_targets, support_semantics = get_inputs_and_outputs(args, train_batch['train'])
            query_inputs, query_targets, _ = get_inputs_and_outputs(args, train_batch['test'])
            # with-attr: the model also returns the self-guided weights (sca/ssa),
            # which are aligned with the attribute-guided ones (ca/sa) below
            support_embeddings, ca_weights, sca_weights, sa_weights, ssa_weights = model(support_inputs, semantics=support_semantics, output_weights=True)
            addition_loss = get_addition_loss(ca_weights, sca_weights, sa_weights, ssa_weights, args)
            # wo-attr: support_embeddings, ca_weights, sa_weights = model(support_inputs, semantics=support_semantics, output_weights=True)
            query_embeddings = model(query_inputs)
            prototypes = get_prototypes(support_embeddings, support_targets,
                                        train_dataset.num_classes_per_task)
            train_loss = prototypical_loss(prototypes, query_embeddings, query_targets) + addition_loss
            # wo-attr: train_loss = prototypical_loss(prototypes, query_embeddings, query_targets)
            train_acc = get_proto_accuracy(prototypes, query_embeddings, query_targets)
            del ca_weights, sca_weights, sa_weights, ssa_weights
            # wo-attr: del ca_weights, sa_weights
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
            pbar.set_postfix(train_acc='{0:.4f}'.format(train_acc.item()))
            # validation
            if global_task_count % args.valid_every_tasks == 0:
                val_loss_averager = Averager()
                val_acc_averager = Averager_with_interval()
                model.eval()
                with torch.no_grad():
                    for i_val_batch, val_batch in enumerate(val_loader, 1):
                        if i_val_batch > (args.val_tasks / args.batch_tasks):
                            break
                        support_inputs, support_targets, support_semantics = get_inputs_and_outputs(args, val_batch['train'])
                        query_inputs, query_targets, _ = get_inputs_and_outputs(args, val_batch['test'])
                        support_embeddings, ca_weights, sca_weights, sa_weights, ssa_weights = model(support_inputs, semantics=support_semantics, output_weights=True)
                        addition_loss = get_addition_loss(ca_weights, sca_weights, sa_weights, ssa_weights, args)
                        # wo-attr: support_embeddings, ca_weights, sa_weights = model(support_inputs, semantics=support_semantics, output_weights=True)
                        query_embeddings = model(query_inputs)
                        prototypes = get_prototypes(support_embeddings, support_targets,
                                                    val_dataset.num_classes_per_task)
                        # note: the alignment term is also included in the reported val loss
                        val_loss = prototypical_loss(prototypes, query_embeddings, query_targets) + addition_loss
                        # wo-attr: val_loss = prototypical_loss(prototypes, query_embeddings, query_targets)
                        val_acc = get_proto_accuracy(prototypes, query_embeddings, query_targets)
                        del ca_weights, sca_weights, sa_weights, ssa_weights
                        # wo-attr: del ca_weights, sa_weights
                        val_loss_averager.add(val_loss.item())
                        val_acc_averager.add(val_acc.item())
                # record
                val_acc_mean = val_acc_averager.item()
                # print('global_task_count: {}, val_acc_mean: {}'.format(str(global_task_count), str(val_acc_mean)))
                if val_acc_mean > train_log['max_acc']:
                    train_log['max_acc'] = val_acc_mean
                    train_log['max_acc_i_task'] = global_task_count
                    save_model(model, args, tag='max_acc')
                train_log['train_loss'].append(train_loss.item())
                train_log['train_acc'].append(train_acc.item())
                train_log['val_loss'].append(val_loss_averager.item())
                train_log['val_acc'].append(val_acc_mean)
                save_checkpoint(args, model, train_log, optimizer, global_task_count, tag='max_acc')
                del val_loss_averager, val_acc_averager
    # testing
    test_loss_averager = Averager()
    test_acc_averager = Averager_with_interval()
    model.eval()
    with torch.no_grad():
        with tqdm(test_loader, total=int(args.test_tasks/args.batch_tasks)) as pbar:
            for i_test_batch, test_batch in enumerate(pbar, 1):
                if i_test_batch > (args.test_tasks / args.batch_tasks):
                    break
                support_inputs, support_targets, support_semantics = get_inputs_and_outputs(args, test_batch['train'])
                query_inputs, query_targets, _ = get_inputs_and_outputs(args, test_batch['test'])
                support_embeddings, ca_weights, sca_weights, sa_weights, ssa_weights = model(support_inputs, semantics=support_semantics, output_weights=True)
                addition_loss = get_addition_loss(ca_weights, sca_weights, sa_weights, ssa_weights, args)
                # wo-attr: support_embeddings, ca_weights, sa_weights = model(support_inputs, semantics=support_semantics, output_weights=True)
                query_embeddings = model(query_inputs)
                prototypes = get_prototypes(support_embeddings, support_targets,
                                            test_dataset.num_classes_per_task)
                test_loss = prototypical_loss(prototypes, query_embeddings, query_targets) + addition_loss
                # wo-attr: test_loss = prototypical_loss(prototypes, query_embeddings, query_targets)
                test_acc = get_proto_accuracy(prototypes, query_embeddings, query_targets)
                del ca_weights, sca_weights, sa_weights, ssa_weights
                # wo-attr: del ca_weights, sa_weights
                pbar.set_postfix(test_acc='{0:.4f}'.format(test_acc.item()))
                test_loss_averager.add(test_loss.item())
                test_acc_averager.add(test_acc.item())
        # record
        index_values = [
            'test_acc',
            'best_i_task',       # task count at which the best val_acc was reached
            'best_train_acc',    # train_acc logged at that point
            'best_train_loss',   # train_loss logged at that point
            'best_val_acc',
            'best_val_loss'
        ]
        best_index = int(train_log['max_acc_i_task'] / args.valid_every_tasks) - 1
        test_record = {}
        test_record_data = [
            test_acc_averager.item(return_str=True),
            str(train_log['max_acc_i_task']),
            str(train_log['train_acc'][best_index]),
            str(train_log['train_loss'][best_index]),
            str(train_log['max_acc']),
            str(train_log['val_loss'][best_index]),
        ]
        test_record[args.record_folder] = test_record_data
        test_record_file = os.path.join(args.record_folder, 'record_{}_{}_{}shot.csv'.format(args.train_data, args.test_data, args.num_shots))
        DataFrame(test_record, index=index_values).to_csv(test_record_file)
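The attention-alignment term imported as get_addition_loss is the key runtime difference between the two variants, and its implementation is not shown above. Below is a minimal sketch of what a 'norm_softmargin' alignment between the attribute-guided weights (ca/sa) and the self-guided ones (sca/ssa) could look like, assuming each weight tensor flattens to shape (batch, features); the function name, signature, and normalization details are assumptions, not the repo's actual code.

import torch
import torch.nn.functional as F

def addition_loss_sketch(ca, sca, sa, ssa, args):
    # Align each attribute-guided attention map with its self-guided counterpart.
    def align(w_attr, w_self):
        w_attr = w_attr.flatten(1)
        w_self = w_self.flatten(1)
        if args.addition_loss == 'norm_softmargin':
            # L2-normalize so the dot product below is a cosine similarity
            w_attr = F.normalize(w_attr, dim=-1)
            w_self = F.normalize(w_self, dim=-1)
        sim = (w_attr * w_self).sum(dim=-1)
        # soft-margin loss with target +1: log(1 + exp(-sim))
        return F.softplus(-sim).mean()
    # weight the channel- and spatial-attention terms by their trade-offs
    return args.ca_trade_off * align(ca, sca) + args.sa_trade_off * align(sa, ssa)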